Stabilizing Adversarial Nets With Prediction Methods
Abstract
Adversarial neural networks solve many important problems in data science, but are notoriously difficult to train. These difficulties come from the fact that optimal weights for adversarial nets correspond to saddle points, and not minimizers, of the loss function. The alternating stochastic gradient methods typically used for such problems do not reliably converge to saddle points, and when convergence does happen it is often highly sensitive to learning rates. We propose a simple modification of stochastic gradient descent that stabilizes adversarial networks. We show, both in theory and practice, that the proposed method reliably converges to saddle points, and is stable with a wider range of training parameters than a non-prediction method. This makes adversarial networks less likely to “collapse,” and enables faster training with larger learning rates.
Abhay Yadav†, Sohil Shah†, Zheng Xu, David Jacobs, & Tom Goldstein (†equal contribution)

University of Maryland, College Park 
College Park, MD 20740, USA 
{jaiabhay, xuzh, tomg}@cs.umd.edu, sohilas@umd.edu, 
djacobs@umiacs.umd.edu 
1 Introduction
Adversarial networks play an important role in a variety of applications, including image generation (Zhang et al., 2017; Wang & Gupta, 2016), style transfer (Brock et al., 2017; Taigman et al., 2017; Wang & Gupta, 2016; Isola et al., 2017), domain adaptation (Taigman et al., 2017; Tzeng et al., 2017; Ganin & Lempitsky, 2015), imitation learning (Ho et al., 2016), privacy (Edwards & Storkey, 2016; Abadi & Andersen, 2016), fair representation (Mathieu et al., 2016; Edwards & Storkey, 2016), etc. One particularly motivating application of adversarial nets is their ability to form generative models, as opposed to the classical discriminative models (Goodfellow et al., 2014; Radford et al., 2016; Denton et al., 2015; Mirza & Osindero, 2014).
While adversarial networks have the power to attack a wide range of previously unsolved problems, they suffer from a major flaw: they are difficult to train. This is because adversarial nets try to accomplish two objectives simultaneously; weights are adjusted to maximize performance on one task while minimizing performance on another. Mathematically, this corresponds to finding a saddle point of a loss function: a point that is minimal with respect to one set of weights, and maximal with respect to another.
Conventional neural networks are trained by marching down a loss function until a minimizer is reached (Figure 1a). In contrast, adversarial training methods search for saddle points rather than a minimizer, which introduces the possibility that the training path “slides off” the objective function and the loss diverges (Figure 1b), resulting in “collapse” of the adversarial network. As a result, many authors suggest using early stopping, gradient/weight clipping (Arjovsky et al., 2017), or specialized objective functions (Goodfellow et al., 2014; Zhao et al., 2017; Arjovsky et al., 2017) to maintain stability.
In this paper, we present a simple “prediction” step that is easily added to many training algorithms for adversarial nets. We present theoretical analysis showing that the proposed prediction method is asymptotically stable for a class of saddle point problems. Finally, we use a wide range of experiments to show that prediction enables faster training of adversarial networks using large learning rates without the instability problems that plague conventional training schemes.
2 Proposed Method
Saddle-point optimization problems have the general form
$$\min_u \max_v \; \mathcal{L}(u, v) \qquad (1)$$
for some loss function $\mathcal{L}(u, v)$ and variables $u$ and $v$. Most authors use the alternating stochastic gradient method to solve saddle-point problems involving neural networks. This method alternates between updating $u$ with a stochastic gradient descent step, and then updating $v$ with a stochastic gradient ascent step. When simple/classical SGD updates are used, the steps of this method can be written
$$u^{k+1} = u^k - \alpha_k\, \mathcal{L}_u(u^k, v^k), \qquad v^{k+1} = v^k + \beta_k\, \mathcal{L}_v(u^{k+1}, v^k). \qquad (2)$$
Here, $\{\alpha_k\}$ and $\{\beta_k\}$ are learning rate schedules for the minimization and maximization steps, respectively. The vectors $\mathcal{L}_u$ and $\mathcal{L}_v$ denote (possibly stochastic) gradients of $\mathcal{L}$ with respect to $u$ and $v$. In practice, the gradient updates are often performed by an automated solver, such as the Adam optimizer (Kingma & Ba, 2015), and include momentum updates.
We propose to stabilize the training of adversarial networks by adding a prediction step. Rather than calculating $v^{k+1}$ using $u^{k+1}$, we first make a prediction, $\bar u^{k+1}$, about where the $u$ iterates will be in the future, and use this predicted value to obtain $v^{k+1}$:
Prediction Method
$$u^{k+1} = u^k - \alpha_k\, \mathcal{L}_u(u^k, v^k)$$
$$\bar u^{k+1} = 2u^{k+1} - u^k$$
$$v^{k+1} = v^k + \beta_k\, \mathcal{L}_v(\bar u^{k+1}, v^k) \qquad (3)$$
The prediction step (3) tries to estimate where $u$ is going to be in the future by assuming its trajectory remains the same as in the current iteration.
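As a concrete illustration (our sketch, not code from the paper), the updates (2) and (3) can be implemented in a few lines of NumPy. Here we apply them to the toy bilinear saddle $\mathcal{L}(u, v) = uv$, whose saddle point is the origin:

```python
import numpy as np

def alternating_sgd_step(u, v, grad_u, grad_v, alpha, beta, predict=True):
    """One step of alternating (S)GD (2), optionally with the prediction step (3)."""
    u_new = u - alpha * grad_u(u, v)                # gradient descent on u
    u_bar = 2.0 * u_new - u if predict else u_new   # predict where u is headed
    v_new = v + beta * grad_v(u_bar, v)             # gradient ascent on v at the prediction
    return u_new, v_new

# Toy bilinear saddle L(u, v) = u * v, with its saddle at (0, 0).
grad_u = lambda u, v: v  # dL/du
grad_v = lambda u, v: u  # dL/dv

u, v = 1.0, 1.0
for _ in range(2000):
    u, v = alternating_sgd_step(u, v, grad_u, grad_v, 0.1, 0.1, predict=True)
print(abs(u) + abs(v))  # with prediction, the iterates spiral into the saddle
```

With `predict=False` the same loop circles the saddle indefinitely, which is exactly the instability the prediction step is designed to remove.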
3 Background
3.1 Adversarial Networks as a Saddle-Point Problem
We now discuss a few common adversarial network problems and their saddle-point formulations. Generative Adversarial Networks (GANs) fit a generative model to a dataset using a game in which a generative model competes against a discriminator (Goodfellow et al., 2014). The generator, $G(z; \theta_g)$, takes random noise vectors $z$ as inputs, and maps them onto points in the target data distribution. The discriminator, $D(x; \theta_d)$, accepts a candidate point $x$ and tries to determine whether it is really drawn from the empirical distribution (in which case it outputs 1), or fabricated by the generator (output 0). During a training iteration, noise vectors from a Gaussian distribution are pushed through the generator network to form a batch of generated data samples denoted by $\{G(z_n; \theta_g)\}$. A batch of empirical samples, $\{x_n\}$, is also prepared. One then tries to adjust the weights of each network to solve a saddle point problem, which is popularly formulated as
$$\min_{\theta_g} \max_{\theta_d} \; \mathbb{E}_x\, f(D(x; \theta_d)) + \mathbb{E}_z\, f(1 - D(G(z; \theta_g); \theta_d)). \qquad (4)$$
Here $f$ is any monotonically increasing function. Initially, Goodfellow et al. (2014) proposed using $f(x) = \log(x)$.
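For concreteness, the objective in (4) is straightforward to evaluate; the sketch below (ours, not the authors' code) computes it for batches of discriminator outputs, using the original choice $f(x) = \log(x)$:

```python
import numpy as np

def gan_objective(d_real, d_fake, f=np.log):
    """Saddle-point objective (4): the generator minimizes this quantity while
    the discriminator maximizes it. d_real and d_fake hold the discriminator's
    outputs (probabilities in (0, 1)) on real and generated batches."""
    return np.mean(f(d_real)) + np.mean(f(1.0 - d_fake))

# A maximally confused discriminator outputs 0.5 everywhere, giving 2*log(1/2).
batch = np.full(8, 0.5)
print(gan_objective(batch, batch))  # -2 log(2) ≈ -1.386
```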
Domain Adversarial Networks (DANs) (Makhzani et al., 2016; Ganin & Lempitsky, 2015; Edwards & Storkey, 2016) take data collected from a “source” domain, and extract a feature representation that can be used to train models that generalize to another “target” domain. For example, in the domain adversarial neural network (DANN (Ganin & Lempitsky, 2015)), a set of feature layers maps data points into an embedded feature space, and a classifier is trained on these embedded features. Meanwhile, the adversarial discriminator tries to determine, using only the embedded features, whether the data points belong to the source or target domain. A good embedding yields a better task-specific objective on the target domain while fooling the discriminator, and is found by solving
$$\min_{\theta_f, \theta_c}\, \max_{\theta_d}\; \mathcal{L}_c(\theta_f, \theta_c) - \lambda\, \mathcal{L}_d(\theta_f, \theta_d). \qquad (5)$$
Here $\mathcal{L}_d$ is any adversarial discriminator loss function and $\mathcal{L}_c$ denotes the task-specific loss. $\theta_f$, $\theta_d$, and $\theta_c$ are the network parameters of the feature mapping, discriminator, and classification layers, respectively.
3.2 Stabilizing saddle-point solvers
It is well known that alternating stochastic gradient methods are unstable when using simple logarithmic losses. This has led researchers to explore multiple directions for stabilizing GANs: adding regularization terms (Arjovsky et al., 2017; Li et al., 2015; Che et al., 2017; Zhao et al., 2017), applying a myriad of training “hacks” (Salimans et al., 2016; Gulrajani et al., 2017), re-engineering network architectures (Zhao et al., 2017), and designing different solvers (Metz et al., 2017). Specifically, the Wasserstein GAN (WGAN) (Arjovsky et al., 2017) approach modifies the original objective by replacing $f(x) = \log(x)$ with $f(x) = x$. This leads to a training scheme in which the discriminator weights are “clipped.” However, as discussed in Arjovsky et al. (2017), WGAN training is unstable at high learning rates, or when used with popular momentum-based solvers such as Adam. Currently, it is known to work well only with RMSProp (Arjovsky et al., 2017).
The unrolled GAN (Metz et al., 2017) is a new solver that can stabilize training at the cost of more expensive gradient computations. Each generator update requires the computation of multiple extra discriminator updates, which are then discarded when the generator update is complete. While avoiding GAN collapse, this method requires increased computation and memory.
In the convex optimization literature, saddle-point problems are much better studied. One popular solver is the primal-dual hybrid gradient (PDHG) method (Zhu & Chan, 2008; Esser et al., 2009), which was popularized by Chambolle and Pock (Chambolle & Pock, 2011), and has been successfully applied to a range of machine learning and statistical estimation problems (Goldstein et al., 2015). PDHG relates closely to the method proposed here: it achieves stability using the same prediction step, although it uses a different type of gradient update and is only applicable to bilinear problems.
Stochastic methods for convex saddle-point problems can be roughly divided into two categories: stochastic coordinate descent (Dang & Lan, 2014; Lan & Zhou, 2015; Zhang & Lin, 2015; Zhu & Storkey, 2015; 2016; Wang & Xiao, 2017; Shibagaki & Takeuchi, 2017) and stochastic gradient descent (Chen et al., 2014; Qiao et al., 2016). Similar optimization algorithms have been studied for reinforcement learning (Wang & Chen, 2016; Du et al., 2017). Recently, a “doubly” stochastic method that randomizes both primal and dual updates was proposed for strongly convex bilinear saddle point problems (Yu et al., 2015). For general saddle point problems, “doubly” stochastic gradient descent methods are discussed in Nemirovski et al. (2009) and Palaniappan & Bach (2016), in which primal and dual variables are updated simultaneously based on the previous iterates and the current gradients.
4 Interpretations of the prediction step
We present three ways to explain the effect of prediction: an intuitive, non-mathematical perspective; a more analytical viewpoint involving dynamical systems; and finally a rigorous proof-based approach.
4.1 An intuitive viewpoint
The standard alternating SGD switches between minimization and maximization steps. In this algorithm, there is a risk that the minimization step can overpower the maximization step, in which case the iterates will “slide off” the edge of the saddle, leading to instability (Figure 1b). Conversely, an overpowering maximization step will dominate the minimization step, and drive the iterates to extreme values as well.
The effect of prediction is visualized in Figure 6. Suppose that a maximization step takes place starting at the red dot. Without prediction, the maximization step has no knowledge of the algorithm history, and will be the same regardless of whether the previous minimization update was weak (panel a) or strong (panel b). Prediction allows the maximization step to exploit information about the minimization step. If the previous minimization step was weak (panel a), the prediction step (dotted black arrow) stays close to the red dot, resulting in a weak predictive maximization step (white arrow). But if we arrived at the red dot using a strong minimization step (panel b), the prediction moves a long way down the loss surface, resulting in a stronger maximization step (white arrows) to compensate.
4.2 A more mathematical perspective
To get stronger intuition about prediction methods, let’s look at the behavior of the prediction method (3) on a simple bilinear saddle of the form
$$\mathcal{L}(u, v) = u^\top K v, \qquad (6)$$
where $K$ is a matrix. When exact (non-stochastic) gradient updates are used, the iterates follow the path of a simple dynamical system with closed-form solutions. We give here a sketch of this argument; a detailed derivation is provided in the Supplementary Material.
When the (non-predictive) gradient method (2) is applied to the linear problem (6), the resulting iterations can be written
$$u^{k+1} = u^k - \alpha K v^k, \qquad v^{k+1} = v^k + \beta K^\top u^{k+1}.$$
When the stepsize gets small, this behaves like a discretization of the system of differential equations
$$\dot u = -K v, \qquad \dot v = K^\top u,$$
where $\dot u$ and $\dot v$ denote the derivatives of $u$ and $v$ with respect to time. These equations describe a simple harmonic oscillator, and the closed-form solution for $u$ is
$$u(t) = B \cos(\Sigma t + \phi),$$
where $\Sigma$ is a diagonal matrix (containing the singular values of $K$), and the matrix $B$ and phase vector $\phi$ depend on the initialization. We can see that, for small values of $\alpha$ and $\beta$, the non-predictive algorithm (2) approximates undamped harmonic motion, and the solutions orbit around the saddle without converging.
The prediction step (3) improves convergence because it produces damped harmonic motion that sinks into the saddle point. When applied to the linearized problem (6), we get the dynamical system
$$\dot u = -K v, \qquad \dot v = K^\top (u + \dot u) = K^\top u - K^\top K v, \qquad (7)$$
which has solution
$$u(t) = B\, e^{-t\Sigma^2/2} \cos(\tilde\Sigma t + \phi),$$
where $\tilde\Sigma = \Sigma (I - \Sigma^2/4)^{1/2}$ contains the (damped) oscillation frequencies. From this analysis, we see that the damping caused by the prediction step causes the orbits to converge into the saddle point, and the error decays exponentially fast.
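Both dynamical systems are easy to verify numerically. The sketch below (our illustration, taking $K = 1$ so that $u$ and $v$ are scalars) integrates the two flows with a small-step forward-Euler scheme: without prediction, the trajectory orbits the saddle at a roughly constant distance, while the predictive flow (7) spirals into it.

```python
import numpy as np

def integrate(predict, steps=20000, dt=0.001, k=1.0):
    """Euler-integrate the continuous-time limit of the plain updates (2),
    or of the predicted updates (3) when predict=True, on L(u, v) = k*u*v."""
    u, v = 1.0, 1.0
    for _ in range(steps):
        du = -k * v                                 # u-dot = -K v
        dv = k * u + (k * du if predict else 0.0)   # v-dot = K(u + u-dot) under prediction
        u, v = u + dt * du, v + dt * dv
    return float(np.hypot(u, v))                    # distance to the saddle at (0, 0)

print(integrate(predict=False))  # undamped orbit: stays near its initial radius
print(integrate(predict=True))   # damped orbit: decays toward the saddle
```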
4.3 A rigorous perspective
While the arguments above are intuitive, they are also informal and do not address issues like stochastic gradients, non-constant stepsize sequences, and more complex loss functions. We now provide a rigorous convergence analysis that handles these issues.
We assume that the function $\mathcal{L}(u, v)$ is convex in $u$ and concave in $v$. We can then measure convergence using the “primal-dual” gap,
$$G(u, v) = \mathcal{L}(u, v^\star) - \mathcal{L}(u^\star, v),$$
where $(u^\star, v^\star)$ is a saddle. Note that $G(u, v) > 0$ for non-optimal $(u, v)$, and $G(u, v) = 0$ if $(u, v)$ is a saddle. Using these definitions, we formulate the following convergence result. The proof is in the supplementary material.
Theorem 1.
Suppose the function $\mathcal{L}(u, v)$ is convex in $u$, concave in $v$, and that the partial gradient $\mathcal{L}_u$ is uniformly Lipschitz smooth in $u$ (i.e., $\|\mathcal{L}_u(u_1, v) - \mathcal{L}_u(u_2, v)\| \le L\|u_1 - u_2\|$ for some constant $L$). Suppose further that the stochastic gradient approximations satisfy $\mathbb{E}\|\mathcal{L}_u\|^2 \le G_u^2$ and $\mathbb{E}\|\mathcal{L}_v\|^2 \le G_v^2$ for scalars $G_u$ and $G_v$, and that $\mathbb{E}\|u^k - u^\star\|^2 \le D_u^2$ and $\mathbb{E}\|v^k - v^\star\|^2 \le D_v^2$ for scalars $D_u$ and $D_v$.
If we choose decreasing learning rate parameters of the form $\alpha_k = \alpha/\sqrt{k}$ and $\beta_k = \beta/\sqrt{k}$, then the SGD method with prediction converges in expectation, and we have the error bound
$$\mathbb{E}\, G(\bar u^k, \bar v^k) \le \frac{C}{\sqrt{k}},$$
where $\bar u^k$ and $\bar v^k$ denote averages of the iterates, and $C$ is a constant depending on $\alpha$, $\beta$, $L$, $G_u$, $G_v$, $D_u$, and $D_v$.
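The theorem's setting can be simulated directly. In the toy sketch below (our construction, not from the paper), we run the predicted updates with noisy gradients and learning rates decaying like $1/\sqrt{k}$, as in the theorem, on the convex-concave function $\mathcal{L}(u, v) = \tfrac{1}{2}u^2 + uv - \tfrac{1}{2}v^2$, whose saddle is the origin. For this function, the primal-dual gap of a point $(u, v)$ is $\tfrac{1}{2}(u^2 + v^2)$, and the gap of the averaged iterates shrinks as predicted:

```python
import numpy as np

rng = np.random.default_rng(0)
noisy = lambda g, sigma=0.1: g + sigma * rng.standard_normal()  # stochastic gradient

u, v = 1.0, 1.0
u_avg = v_avg = 0.0
iters = 5000
for k in range(1, iters + 1):
    alpha = 0.5 / np.sqrt(k)          # decaying primal learning rate
    beta = 0.5 / np.sqrt(k)           # decaying dual learning rate
    u_new = u - alpha * noisy(u + v)  # stochastic gradient L_u = u + v
    u_bar = 2.0 * u_new - u           # prediction step
    v = v + beta * noisy(u_bar - v)   # stochastic gradient L_v = u - v, at the prediction
    u = u_new
    u_avg += u / iters                # running average of the iterates
    v_avg += v / iters

gap = 0.5 * (u_avg**2 + v_avg**2)     # primal-dual gap of the averaged iterate
print(gap)                            # small: the averaged iterates approach the saddle
```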
5 Experiments
We present a wide range of experiments to demonstrate the benefits of the proposed prediction step for adversarial nets. We consider a saddle-point problem on a toy dataset constructed using MNIST images, and then move on to consider state-of-the-art models for three tasks: GANs, domain adaptation, and learning of fair classifiers. Additional results, and additional experiments involving mixtures of Gaussians, are presented in the Appendix.
5.1 MNIST Toy problem
We consider the task of classifying MNIST digits as being even or odd. To make the problem interesting, we corrupt 70% of odd digits with salt-and-pepper noise, while corrupting only 30% of even digits. When we train a LeNet network (LeCun et al., 1998) on this problem, we find that the network encodes and uses information about the noise: when a noise-vs-no-noise classifier is trained on the deep features generated by LeNet, it gets 100% accuracy. The goal of this task is to force LeNet to ignore the noise when making decisions. We create an adversarial model of the form (5) in which $\mathcal{L}_c$ is a softmax loss for the even-vs-odd classifier, and $\mathcal{L}_d$ is a softmax loss for the task of discriminating whether the input sample is noisy or not. The classifier and discriminator were both pretrained using the default LeNet implementation in Caffe (Jia et al., 2014). Then the combined adversarial net was jointly trained both with and without prediction. For implementation details, see the Supplementary Material.
Figure 10 summarizes our findings. In this experiment, we considered applying prediction to both the classifier and the discriminator. Note that our task is to retain good classification accuracy while preventing the discriminator from doing better than the trivial strategy of classifying odd digits as noisy and even digits as non-noisy. This means that the discriminator accuracy should ideally be 70%, the accuracy of this trivial strategy given the corruption rates above. As shown in panel (a), the prediction step hardly makes any difference when training with a small learning rate. However, at higher rates, panels (b) and (c) show that the solvers with prediction are very stable, while the solver without prediction collapses very early (its blue solid line is flat). Panel (c) shows that the default learning rate of the Adam solver (0.001) is unstable unless prediction is used.
5.2 Generative Adversarial Networks
Next, we test the efficacy and stability of our proposed predictive step on generative adversarial networks (GANs), which are formulated as saddle-point problems (4) and are popularly solved using a heuristic approach (Goodfellow et al., 2014). We consider an image modeling task using CIFAR-10 (Krizhevsky, 2009) on the recently popular convolutional GAN architecture, DCGAN (Radford et al., 2016). We compare our predictive method with DCGAN and the unrolled GAN (Metz et al., 2017) using the training protocol described in Radford et al. (2016). Note that we compared against the unrolled GAN with the stop-gradient switch and a fixed number of unrolling steps; we found that the unrolled GAN without the stop-gradient switch, as well as with smaller numbers of unrolling steps, collapsed when used on the DCGAN architecture. All the approaches were trained for five random seeds and 100 epochs each.
We start by comparing all three methods using the default solver for DCGAN (the Adam optimizer, with learning rate 0.0002 and momentum $\beta_1 = 0.5$). Figure 17 compares the generated sample images (at the end of training) and the training loss curves for all approaches. The discriminator and generator loss curves in panel (e) show that, without prediction, the DCGAN collapses twice during training. Similarly, panel (f) shows that training of the unrolled GAN collapses in at least three instances. The training procedure using predictive steps never collapsed during any epoch. Qualitatively, the images generated using prediction are more diverse than the DCGAN and unrolled GAN images.
Figure 24 compares all approaches when trained with a higher learning rate of 0.001 (the default for the Adam solver). As observed in Radford et al. (2016), the standard and unrolled solvers are very unstable and collapse at this higher rate. However, as shown in panels (d) and (a), training remains stable when a predictive step is used, and generates images of reasonable quality. The training procedures for both DCGAN and unrolled GAN collapsed on all five random seeds. Results for various additional intermediate learning rates are in the Supplementary Material.
In the Supplementary Material, we present one additional comparison at a higher momentum of $\beta_1 = 0.9$ (with the learning rate held fixed). We observe that all the training approaches are stable in this setting. However, the quality of images generated using DCGAN is inferior to that of the predictive and unrolled methods.
Overall, across the 25 training settings we ran (five learning rates, each with five random seeds), the DCGAN training procedure collapsed in a substantial number of instances, and the unrolled GAN likewise collapsed in multiple experiments (not counting the multiple collapses within a single training setting). In contrast, we find that our simple predictive step method collapsed only once.
Note that prediction adds trivial cost to the training algorithm. Using a single Titan X Pascal, a training epoch of DCGAN takes 35 seconds; with prediction, an epoch takes 38 seconds. The unrolled GAN method, which requires extra gradient steps, takes 139 seconds per epoch.
Finally, we draw quantitative comparisons based on the inception score (Salimans et al., 2016), which is a widely used metric for the visual quality of generated images. For this purpose, we consider the current state-of-the-art Stacked GAN (Huang et al., 2017) architecture. Table 1 lists the inception scores computed on samples generated by Stacked GAN trained with and without prediction at different learning rates. The joint training of Stacked GAN collapses when trained at the default learning rate of the Adam solver (i.e., 0.001). However, reasonably good samples are generated if the same model is trained with prediction on both generator networks. The right end of Table 1 also lists the inception score measured after fewer epochs at the higher learning rate; this suggests that models trained with prediction are not only stable, but also converge faster when using higher learning rates. For reference, the inception score on real images from the CIFAR-10 dataset is approximately 11.24.
Learning rate |  | (40 epochs) | (20 epochs)
Stacked GAN (joint) |  |  |
Stacked GAN (joint) + prediction |  |  |
5.3 Domain Adaptation
We consider the domain adaptation task (Saenko et al., 2010; Ganin & Lempitsky, 2015; Tzeng et al., 2017), wherein the representation learned from source domain samples is altered so that it also generalizes to samples from the target distribution. We use the problem setup and hyperparameters described in Ganin & Lempitsky (2015) on the Office dataset (Saenko et al., 2010) (experimental details are in the Supplementary Material). In Table 2, comparisons are drawn with respect to target domain accuracy on six pairs of source-target domain tasks. We observe that the prediction step has mild benefits on the “easy” adaptation tasks with very similar source and target domain samples. However, on the transfer learning tasks of Amazon-to-Webcam, Webcam-to-Amazon, and Dslr-to-Amazon, which have noticeably distinct data samples, an extra prediction step gives an absolute improvement of up to 6.9% in predicting target domain labels.
Method | Amazon→Webcam | Webcam→Amazon | Dslr→Webcam | Webcam→Dslr | Amazon→Dslr | Dslr→Amazon
DANN (Ganin & Lempitsky, 2015) | 73.4 | 51.6 | 95.5 | 99.4 | 76.5 | 51.7
DANN + prediction | 74.7 | 58.5 | 96.1 | 99.0 | 73.5 | 57.6
5.4 Fair Classifier
Finally, we consider the task of learning fair feature representations (Mathieu et al., 2016; Edwards & Storkey, 2016; Louizos et al., 2016), such that the final learned classifier does not discriminate with respect to a sensitive variable. As proposed in Edwards & Storkey (2016), one way to measure fairness is using the discrimination,
$$y_D = \left| \frac{1}{N_0} \sum_{n \,:\, s_n = 0} \eta(x_n) \;-\; \frac{1}{N_1} \sum_{n \,:\, s_n = 1} \eta(x_n) \right|. \qquad (8)$$
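A minimal sketch of this metric (our code; `eta` holds the classifier's predictions and `s` the binary sensitive variable, following (8)):

```python
import numpy as np

def discrimination(eta, s):
    """Discrimination (8): absolute gap between the mean classifier output
    over the two groups defined by the binary sensitive variable s."""
    eta, s = np.asarray(eta, dtype=float), np.asarray(s)
    return abs(eta[s == 0].mean() - eta[s == 1].mean())

# A classifier whose outputs track the sensitive variable is maximally unfair.
print(discrimination([1, 1, 0, 0], [0, 0, 1, 1]))  # 1.0
```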
Here $s_n$ is a binary sensitive variable for the $n$-th data sample, $\eta(x_n)$ is the classifier's prediction, and $N_0$ and $N_1$ denote the total number of samples belonging to each sensitive class. Similar to the domain adaptation task, the learning of each classifier can be formulated as a minimax problem in (5) (Edwards & Storkey, 2016; Mathieu et al., 2016). Unlike the previous example, though, this task has a model selection component. From a pool of hundreds of randomly generated adversarial deep nets, for each value of $t$, one selects the model that maximizes the difference
$$t \cdot y_{Acc} - (1 - t) \cdot y_D, \qquad (9)$$
where $y_{Acc}$ denotes classification accuracy.
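Model selection by (9) then reduces to an argmax over the candidate pool. A sketch under the reconstructed criterion $t \cdot y_{Acc} - (1 - t) \cdot y_D$ (the candidate scores below are hypothetical, for illustration only):

```python
def select_model(candidates, t):
    """Pick the index of the (accuracy, discrimination) pair maximizing (9)."""
    return max(range(len(candidates)),
               key=lambda i: t * candidates[i][0] - (1 - t) * candidates[i][1])

# Hypothetical pool: (validation accuracy, validation discrimination) per model.
pool = [(0.90, 0.50), (0.80, 0.10), (0.60, 0.00)]
print(select_model(pool, t=0.5))  # balances accuracy against fairness -> 1
print(select_model(pool, t=1.0))  # pure accuracy -> 0
```

Sweeping `t` from 0 to 1 traces out the accuracy-fairness trade-off that the model selection step explores.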
The “Adult” dataset from the UCI machine learning repository is used. The task ($y$) is to classify whether a person earns more than $50K/year. The person’s gender is chosen to be the sensitive variable. Details are in the supplementary. To demonstrate the advantage of using prediction for model selection, we follow the protocol developed in Edwards & Storkey (2016). In this work, the search space is restricted to a class of models that consist of a fully connected autoencoder, one task-specific discriminator, and one adversarial discriminator. The encoder output from the autoencoder acts as input to both discriminators. In our experiment, 100 models are randomly selected. During the training of each adversarial model, $\mathcal{L}_d$ is a cross-entropy loss, while $\mathcal{L}_c$ is a linear combination of reconstruction and cross-entropy losses. Once all the models are trained, the best model for each value of $t$ is selected by evaluating (9) on the validation set.
Panel (a) of our results figure plots the results on the test set for the AFLR approach with and without prediction steps, using its default Adam solver. For each value of $t$, panels (b) and (c) also compare the number of layers in the selected encoder and discriminator networks. When using prediction during training, relatively stronger encoder models are produced and selected during validation, and hence the prediction results generalize better on the test set.
6 Conclusion
We present a simple modification to the alternating SGD method, called a prediction step, that improves the stability of adversarial networks. We present theoretical results showing that the prediction step is asymptotically stable for solving saddle point problems. We show, using a variety of test problems, that prediction steps prevent network collapse and enable training with a wider range of learning rates than plain SGD methods.
Acknowledgments
The work of T. Goldstein was supported by the US Office of Naval Research under grant N00014-17-1-2078, the US National Science Foundation (NSF) under grant CCF-1535902, and by the Sloan Foundation. A. Yadav and D. Jacobs were supported by the National Science Foundation under grant no. IIS-1526234 and by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA), via IARPA R&D Contract No. 201414071600012. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of the ODNI, IARPA, or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright annotation thereon.
References
 Abadi & Andersen (2016) Martín Abadi and David G Andersen. Learning to protect communications with adversarial neural cryptography. arXiv preprint arXiv:1610.06918, 2016.
 Arjovsky et al. (2017) Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In ICML, 2017.
 Brock et al. (2017) Andrew Brock, Theodore Lim, JM Ritchie, and Nick Weston. Neural photo editing with introspective adversarial networks. In ICLR, 2017.
 Chambolle & Pock (2011) Antonin Chambolle and Thomas Pock. A first-order primal-dual algorithm for convex problems with applications to imaging. Journal of Mathematical Imaging and Vision, 40(1):120–145, 2011.
 Che et al. (2017) Tong Che, Yanran Li, Athul Paul Jacob, Yoshua Bengio, and Wenjie Li. Mode regularized generative adversarial networks. In ICLR, 2017.
 Chen et al. (2014) Yunmei Chen, Guanghui Lan, and Yuyuan Ouyang. Optimal primal-dual methods for a class of saddle point problems. SIAM Journal on Optimization, 24(4):1779–1814, 2014.
 Dang & Lan (2014) Cong Dang and Guanghui Lan. Randomized first-order methods for saddle point optimization. arXiv preprint arXiv:1409.8625, 2014.
 Denton et al. (2015) Emily Denton, Soumith Chintala, Arthur Szlam, and Rob Fergus. Deep generative image models using a Laplacian pyramid of adversarial networks. In NIPS, 2015.
 Du et al. (2017) Simon S Du, Jianshu Chen, Lihong Li, Lin Xiao, and Dengyong Zhou. Stochastic variance reduction methods for policy evaluation. ICML, 2017.
 Edwards & Storkey (2016) Harrison Edwards and Amos Storkey. Censoring representations with an adversary. In ICLR, 2016.
 Esser et al. (2009) Ernie Esser, Xiaoqun Zhang, and Tony Chan. A general framework for a class of first order primal-dual algorithms for TV minimization. UCLA CAM Report, pp. 09–67, 2009.
 Ganin & Lempitsky (2015) Yaroslav Ganin and Victor Lempitsky. Unsupervised domain adaptation by backpropagation. In Proceedings of The 32nd International Conference on Machine Learning, pp. 1180–1189, 2015.
 Goldstein et al. (2015) Tom Goldstein, Min Li, and Xiaoming Yuan. Adaptive primaldual splitting methods for statistical learning and image processing. In NIPS, pp. 2089–2097, 2015.
 Goodfellow et al. (2014) Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In NIPS, pp. 2672–2680, 2014.
 Gulrajani et al. (2017) Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron Courville. Improved training of Wasserstein GANs. arXiv preprint arXiv:1704.00028, 2017.
 Ho et al. (2016) Jonathan Ho, Jayesh Gupta, and Stefano Ermon. Model-free imitation learning with policy optimization. In International Conference on Machine Learning, pp. 2760–2769, 2016.
 Huang et al. (2017) Xun Huang, Yixuan Li, Omid Poursaeed, John Hopcroft, and Serge Belongie. Stacked generative adversarial networks. In CVPR, 2017.
 Isola et al. (2017) Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, and Alexei A Efros. Image-to-image translation with conditional adversarial networks. In CVPR, 2017.
 Jia et al. (2014) Yangqing Jia, Evan Shelhamer, Jeff Donahue, Sergey Karayev, Jonathan Long, Ross Girshick, Sergio Guadarrama, and Trevor Darrell. Caffe: Convolutional architecture for fast feature embedding. In ACM Multimedia, pp. 675–678, 2014.
 Kingma & Ba (2015) Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.
 Krizhevsky (2009) Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, 2009.
 Lan & Zhou (2015) Guanghui Lan and Yi Zhou. An optimal randomized incremental gradient method. arXiv preprint arXiv:1507.02000, 2015.
 LeCun et al. (1998) Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
 Li et al. (2015) Yujia Li, Kevin Swersky, and Richard S Zemel. Generative moment matching networks. In ICML, pp. 1718–1727, 2015.
 Louizos et al. (2016) Christos Louizos, Kevin Swersky, Yujia Li, Max Welling, and Richard Zemel. The variational fair autoencoder. In ICLR, 2016.
 Makhzani et al. (2016) Alireza Makhzani, Jonathon Shlens, Navdeep Jaitly, Ian Goodfellow, and Brendan Frey. Adversarial autoencoders. In ICLR, 2016.
 Mathieu et al. (2016) Michael F Mathieu, Junbo Jake Zhao, Junbo Zhao, Aditya Ramesh, Pablo Sprechmann, and Yann LeCun. Disentangling factors of variation in deep representation using adversarial training. In NIPS, pp. 5041–5049, 2016.
 Metz et al. (2017) Luke Metz, Ben Poole, David Pfau, and Jascha Sohl-Dickstein. Unrolled generative adversarial networks. In ICLR, 2017.
 Mirza & Osindero (2014) Mehdi Mirza and Simon Osindero. Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784, 2014.
 Nemirovski et al. (2009) Arkadi Nemirovski, Anatoli Juditsky, Guanghui Lan, and Alexander Shapiro. Robust stochastic approximation approach to stochastic programming. SIAM Journal on optimization, 19(4):1574–1609, 2009.
 Odena et al. (2017) Augustus Odena, Christopher Olah, and Jonathon Shlens. Conditional image synthesis with auxiliary classifier gans. In ICLR, 2017.
 Palaniappan & Bach (2016) Balamurugan Palaniappan and Francis Bach. Stochastic variance reduction methods for saddlepoint problems. In NIPS, pp. 1408–1416, 2016.
 Qiao et al. (2016) Linbo Qiao, Tianyi Lin, Yu-Gang Jiang, Fan Yang, Wei Liu, and Xicheng Lu. On stochastic primal-dual hybrid gradient approach for compositely regularized minimization. In ECAI, 2016.
 Radford et al. (2016) Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR, 2016.
 Saenko et al. (2010) Kate Saenko, Brian Kulis, Mario Fritz, and Trevor Darrell. Adapting visual category models to new domains. ECCV, pp. 213–226, 2010.
 Salimans et al. (2016) Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training gans. In NIPS, pp. 2234–2242, 2016.
 Shibagaki & Takeuchi (2017) Atsushi Shibagaki and Ichiro Takeuchi. Stochastic primal dual coordinate method with non-uniform sampling based on optimality violations. arXiv preprint arXiv:1703.07056, 2017.
 Taigman et al. (2017) Yaniv Taigman, Adam Polyak, and Lior Wolf. Unsupervised cross-domain image generation. In ICLR, 2017.
 Tzeng et al. (2017) Eric Tzeng, Judy Hoffman, Kate Saenko, and Trevor Darrell. Adversarial discriminative domain adaptation. In ICLR Workshop, 2017.
 Wang & Xiao (2017) Jialei Wang and Lin Xiao. Exploiting strong convexity from data with primal-dual first-order algorithms. ICML, 2017.
 Wang & Chen (2016) Mengdi Wang and Yichen Chen. An online primal-dual method for discounted Markov decision processes. In Decision and Control (CDC), 2016 IEEE 55th Conference on, pp. 4516–4521. IEEE, 2016.
 Wang & Gupta (2016) Xiaolong Wang and Abhinav Gupta. Generative image modeling using style and structure adversarial networks. In ECCV, pp. 318–335, 2016.
 Yu et al. (2015) Adams Wei Yu, Qihang Lin, and Tianbao Yang. Doubly stochastic primal-dual coordinate method for empirical risk minimization and bilinear saddle-point problem. arXiv preprint arXiv:1508.03390, 2015.
 Zhang et al. (2017) Han Zhang, Tao Xu, Hongsheng Li, Shaoting Zhang, Xiaogang Wang, Xiaolei Huang, and Dimitris Metaxas. StackGAN: Text to photo-realistic image synthesis with stacked generative adversarial networks. In ICCV, 2017.
 Zhang & Lin (2015) Yuchen Zhang and Xiao Lin. Stochastic primal-dual coordinate method for regularized empirical risk minimization. In ICML, pp. 353–361, 2015.
 Zhao et al. (2017) Junbo Zhao, Michael Mathieu, and Yann LeCun. Energy-based generative adversarial network. In ICLR, 2017.
 Zhu & Chan (2008) Mingqiang Zhu and Tony Chan. An efficient primal-dual hybrid gradient algorithm for total variation image restoration. UCLA CAM Report, pp. 08–34, 2008.
 Zhu & Storkey (2015) Zhanxing Zhu and Amos J Storkey. Adaptive stochastic primal-dual coordinate descent for separable saddle point problems. In ECML-PKDD, pp. 645–658, 2015.
 Zhu & Storkey (2016) Zhanxing Zhu and Amos J Storkey. Stochastic parallel block coordinate descent for large-scale saddle point problems. In AAAI, 2016.
Appendix A Detailed derivation of the harmonic oscillator equation
Here, we provide a detailed derivation of the harmonic-oscillator behavior of Algorithm (3) on a simple bilinear saddle of the form

(6) $\min_u \max_v \; L(u,v) = u^\top K v,$

where $K$ is a matrix. Note that, within a small neighborhood of a saddle, all smooth weakly convex objective functions behave like (6). To see why, consider a smooth objective function $L(u,v)$ with a saddle point at $(u_0, v_0).$ Within a small neighborhood of the saddle, we can approximate the function to high accuracy using its Taylor approximation

$L(u,v) \approx L(u_0,v_0) + (u-u_0)^\top \nabla^2_{uv} L(u_0,v_0)\,(v-v_0),$

where $\nabla^2_{uv} L$ denotes the matrix of mixed partial derivatives with respect to $u$ and $v$. Note that the first-order terms have vanished from this Taylor approximation because the gradients are zero at a saddle point. The $\nabla^2_{uu}$ and $\nabla^2_{vv}$ terms vanish as well because the problem is assumed to be weakly convex around the saddle. Up to third-order error (which vanishes quickly near the saddle), this Taylor expansion has the form (6). For this reason, stability on saddles of the form (6) is a necessary condition for convergence of (3), and the analysis here describes the asymptotic behavior of the prediction method on any smooth problem for which the method converges.
We will show that, as the learning rate gets small, the iterates of the non-prediction method (2) rotate in orbits around the saddle without converging. In contrast, the iterates of the prediction method fall into the saddle and converge.
When the conventional gradient method (2) is applied to the bilinear problem (6), the resulting iterations can be written

$u^{k+1} = u^k - \alpha K v^k, \qquad v^{k+1} = v^k + \beta K^\top u^{k+1}.$

When the stepsize gets small, this behaves like a discretization of the differential equations

(10) $\dot u = -\alpha K v$
(11) $\dot v = \beta K^\top u,$

where $\dot u$ and $\dot v$ denote the derivatives of $u$ and $v$ with respect to time.
The differential equations (10,11) describe a harmonic oscillator. To see why, differentiate (10) and plug (11) into the result to get a differential equation in $u$ alone:

(12) $\ddot u + \alpha\beta K K^\top u = 0.$

We can decompose this into a system of independent single-variable problems by considering the eigenvalue decomposition $K K^\top = V \Lambda V^\top.$ We now multiply both sides of (12) by $V^\top$ and make the change of variables $x = V^\top u$ to get

$\ddot x + \alpha\beta \Lambda x = 0,$

where $\Lambda$ is diagonal. This is the standard equation for undamped harmonic motion, and its solution is $x(t) = C \cos(\sqrt{\alpha\beta\Lambda}\, t + \phi),$ where $\cos$ acts entrywise, and the diagonal matrix $C$ and vector $\phi$ are constants that depend only on the initialization. Changing back into the variable $u$, we get the solution

$u(t) = V C \cos(\sqrt{\alpha\beta\Lambda}\, t + \phi).$

We can see that, for small values of $\alpha$ and $\beta,$ the non-predictive algorithm (2) approximates an undamped harmonic motion, and the solutions orbit around the saddle without converging.
The prediction step (3) improves convergence because it produces damped harmonic motion that sinks into the saddle point. When applied to the linearized problem (6), the iterates of the predictive method (3) satisfy

$u^{k+1} = u^k - \alpha K v^k, \qquad \bar u^{k+1} = 2u^{k+1} - u^k, \qquad v^{k+1} = v^k + \beta K^\top \bar u^{k+1}.$

For small $\alpha$ and $\beta,$ this approximates the dynamical system

(13) $\dot u = -\alpha K v$
(14) $\dot v = \beta K^\top (u + \dot u).$
Like before, we differentiate (13) and use (14) to obtain

(15) $\ddot u + \alpha\beta K K^\top \dot u + \alpha\beta K K^\top u = 0.$
Finally, multiply both sides by $V^\top$ and perform the change of variables $x = V^\top u$ to get

$\ddot x + \alpha\beta\Lambda\, \dot x + \alpha\beta\Lambda\, x = 0.$

This equation describes a damped harmonic motion. For small $\alpha\beta,$ the solutions have the (approximate) form $x(t) = C e^{-\alpha\beta\Lambda t/2} \cos(\sqrt{\alpha\beta\Lambda}\, t + \phi).$ Changing back to the variable $u,$ we see that the iterates of the original method satisfy

$u(t) \approx V C e^{-\alpha\beta\Lambda t/2} \cos(\sqrt{\alpha\beta\Lambda}\, t + \phi),$

where $C$ and $\phi$ depend on the initialization.

From this analysis, we see that for small constant stepsizes $\alpha$ and $\beta,$ the orbits of the lookahead method converge into the saddle point, and the error decays exponentially fast.
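The two behaviors derived above are easy to check numerically. The sketch below is our own illustration, not code from the paper; the matrix K and the stepsizes are arbitrary choices. It runs the plain alternating updates and the predictive updates on the bilinear saddle L(u,v) = u^T K v and measures the final distance to the saddle at the origin:

```python
import numpy as np

def run(K, alpha=0.02, beta=0.02, steps=50000, predict=False, seed=0):
    """Alternating gradient descent/ascent on L(u, v) = u^T K v,
    optionally with the prediction (lookahead) step of Algorithm (3)."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(K.shape[0])
    v = rng.standard_normal(K.shape[1])
    for _ in range(steps):
        u_next = u - alpha * K @ v                       # primal descent step
        u_bar = 2.0 * u_next - u if predict else u_next  # prediction step
        v = v + beta * K.T @ u_bar                       # dual ascent step
        u = u_next
    return np.linalg.norm(u) + np.linalg.norm(v)         # distance to saddle (0, 0)

K = np.array([[1.0, 0.3],
              [0.0, 1.0]])
plain = run(K, predict=False)   # orbits the saddle: distance stays O(1)
pred = run(K, predict=True)     # damped oscillation: distance decays toward 0
```

Consistent with the analysis, the plain iterates trace closed orbits (the one-step update matrix has unit determinant), while the predicted iterates decay roughly like e^{-αβλt/2} per eigenmode of KK^T.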
Appendix A.1 Proof of Theorem 1
Assume the optimal solution $(u^*, v^*)$ exists; then $L(u^*, v) \le L(u^*, v^*) \le L(u, v^*)$ for all $u, v$. In the following proofs, we use $\hat g_u, \hat g_v$ to represent the stochastic approximations of the gradients, where $\mathbb{E}[\hat g_u] = \nabla_u L$ and $\mathbb{E}[\hat g_v] = \nabla_v L$. We show the convergence of the proposed stochastic primal-dual gradients for the primal-dual gap $L(u^k, v^*) - L(u^*, v^k)$. We prove the convergence rate in Theorem 1 by using Lemma 1 and Lemma 2, which present the contraction of the primal and dual updates, respectively.
Lemma 1.
Suppose $L(u,v)$ is convex in $u$ and $u^{k+1} = u^k - \alpha\, \hat g_u(u^k, v^k)$; then we have

(16) $\mathbb{E}\big[L(u^k,v^k) - L(u^*,v^k)\big] \le \frac{1}{2\alpha}\Big(\mathbb{E}\|u^k-u^*\|^2 - \mathbb{E}\|u^{k+1}-u^*\|^2\Big) + \frac{\alpha}{2}\,\mathbb{E}\|\hat g_u(u^k,v^k)\|^2.$
Proof.
From the primal update in (3), $\|u^{k+1}-u^*\|^2 = \|u^k-u^*\|^2 - 2\alpha\langle \hat g_u(u^k,v^k),\, u^k-u^*\rangle + \alpha^2\|\hat g_u(u^k,v^k)\|^2.$ Taking expectations, using $\mathbb{E}[\hat g_u] = \nabla_u L(u^k,v^k),$ and applying convexity, $\langle \nabla_u L(u^k,v^k),\, u^k-u^*\rangle \ge L(u^k,v^k) - L(u^*,v^k),$ then rearranging gives (16). ∎
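Lemma 1 is the standard stochastic-gradient contraction bound, and its deterministic form can be verified numerically. The sketch below is our own illustration; the quadratic objective, stepsize, and reference point are arbitrary choices. It checks that a single primal step satisfies the inequality in (16) with exact gradients:

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, alpha = 4, 3, 0.1
A = rng.standard_normal((n, n))
Q = A @ A.T + np.eye(n)                  # positive definite, so L is convex in u
K = rng.standard_normal((n, m))

def L(u, v):
    return 0.5 * u @ Q @ u + u @ K @ v   # convex in u, linear in v

def grad_u(u, v):
    return Q @ u + K @ v

u, v = rng.standard_normal(n), rng.standard_normal(m)
u_star = rng.standard_normal(n)          # the bound holds for any reference point
g = grad_u(u, v)                         # deterministic stand-in for the stochastic gradient
u_next = u - alpha * g                   # primal update from (3)

gap = L(u, v) - L(u_star, v)
bound = (np.sum((u - u_star)**2) - np.sum((u_next - u_star)**2)) / (2 * alpha) \
        + (alpha / 2) * np.sum(g**2)
# Expanding the squared norms shows bound == <g, u - u_star> exactly,
# and convexity guarantees gap <= <g, u - u_star>.
```

The check works for any stepsize and any reference point, since the right-hand side collapses algebraically to the inner product that convexity bounds from below.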
Lemma 2.
Suppose $L(u,v)$ is concave in $v$ and has Lipschitz gradients, $\|\nabla_v L(u_1,v) - \nabla_v L(u_2,v)\| \le L_{uv}\|u_1-u_2\|$; the stochastic gradients have bounded variance, $\mathbb{E}\|\hat g_u - \nabla_u L\|^2 \le \sigma_u^2$ and $\mathbb{E}\|\hat g_v - \nabla_v L\|^2 \le \sigma_v^2$; and the iterates remain bounded. Then we have
(21) 
Proof.
From the dual update in (3), we have
(22)  
(23) 
Take expectations on both sides of the equation, substitute, and apply the assumptions of the lemma to get
(24) 
Reorganize (24) to get
(25) 
The right hand side of (25) can be represented as
(26)  
(27)  
(28) 
where
(29)  
(30)  
(31)  
(32)  
(33)  
(34)  
(35) 
Lipschitz smoothness is used for (31); the prediction step in (3) is used for (32); the primal update in (3) is used for (33); the boundedness assumptions are used for (35).
Since $L(u,v)$ is concave in $v$, we have
(36) 
We now present the proof of Theorem 1.
Appendix B MNIST Toy example
Experimental details: We consider the classic MNIST digits dataset (LeCun et al., 1998), consisting of 60,000 training images and 10,000 testing images, each of size 28 x 28. For simplicity, consider the task (T1) of classifying images into odd- and even-numbered digits. Suppose that a fraction of the data instances is corrupted with salt-and-pepper noise of probability 0.2, and that this distortion process is biased: a smaller proportion of the even-numbered images is distorted than of the odd-numbered images. We observed that any feature representation network trained with the binary classification loss for task T1 also encodes this noise bias; we verified this by training an independent noise classifier on the learned features. This led us to design a simple adversarial network to “unlearn” the noise bias from the feature-learning pipeline. We formulate this using the minimax objective in (5).
In our model, the adversarial loss is a softmax loss for the task (T2) of classifying whether the input sample is noisy, and the primary loss is a softmax loss for task T1. A LeNet network (LeCun et al., 1998) is trained on task T1, while a two-layer MLP is trained on task T2. LeNet consists of two convolutional (conv) layers followed by two fully connected (FC) layers; the conv layers form the shared feature extractor, while the FC layers and the MLP carry the parameters specific to tasks T1 and T2, respectively. We train the network in three stages. First, LeNet is trained on task T1; next, the conv-layer weights are fixed and the MLP is trained independently on task T2. The default learning schedule of the LeNet implementation in Caffe (Jia et al., 2014) is followed for both tasks, with a fixed number of training iterations per task. Finally, after pretraining, the whole network is jointly fine-tuned using the adversarial approach: objective (5) is alternately minimized with respect to the feature and T1 parameters and maximized with respect to the adversary’s parameters. The predictive steps are used only during this fine-tuning phase.
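The alternating fine-tuning loop with the prediction step can be sketched as follows. This is a minimal sketch over flattened parameter vectors; `grad_primary` and `grad_adversary` are hypothetical stand-ins for the backpropagated gradients of the T1 and T2 losses:

```python
def adversarial_finetune(w, theta, grad_primary, grad_adversary,
                         alpha, beta, iters):
    """Alternating minimax updates with a prediction (lookahead) step on the
    primal variables, following Algorithm (3)."""
    for _ in range(iters):
        w_next = w - alpha * grad_primary(w, theta)          # minimize task loss
        w_bar = 2.0 * w_next - w                             # prediction step
        theta = theta + beta * grad_adversary(w_bar, theta)  # maximize adversarial loss
        w = w_next
    return w, theta
```

For instance, on the scalar saddle L(w, θ) = ½w² + wθ − ½θ², whose saddle point is at the origin, this loop converges for small stepsizes, whereas the same loop without the prediction step merely orbits.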
Our findings are summarized in Figure 10. In addition, Figure 29 provides a head-to-head comparison of two popular solvers, Adam and SGD, when using the predictive step. Not surprisingly, the Adam solver shows relatively better performance and convergence even with the additional predictive step. This also suggests that the default hyperparameters of the Adam solver can be retained when training these networks, without resorting to any further hyperparameter tuning (as is current practice).
Appendix C Domain Adaptation
Experimental details: To evaluate a domain adaptation task, we consider the Office dataset (Saenko et al., 2010). Office is a small-scale dataset consisting of images collected from three distinct domains: Amazon, DSLR, and Webcam. For such a small dataset, it is nontrivial to learn features from images of a single domain; for instance, the largest subset, Amazon, contains only 2,817 labeled images spread across 31 different categories. However, one can leverage domain adaptation to improve cross-domain accuracy. We follow the protocol of Ganin & Lempitsky (2015) and use the same network architecture, implemented in Caffe (Jia et al., 2014). The training procedure from Ganin & Lempitsky (2015) is kept intact except for the additional prediction step. In Table 2, comparisons are drawn with respect to target-domain accuracy on three pairs of source-target domain tasks. The test accuracy is reported at the end of 50,000 training iterations.
Appendix D Fair Classifier
Experimental details: The “Adult” dataset from the UCI machine learning repository is used, which consists of census records. The task is to classify whether a person earns more than $50K/year. The person’s gender is chosen as the sensitive variable. We binarize all the categorical attributes, giving us a total of 102 input features per sample. We randomly split the data into 35,000 samples for training, 5,000 for validation, and 5,000 for testing. The result reported here is an average over five such random splits.
Appendix E Generative Adversarial Networks
Toy Dataset: To illustrate the advantage of the prediction method, we experiment on a simple GAN architecture with fully connected layers using a toy dataset. The constructed toy example and its architecture are inspired by the one presented in Metz et al. (2017). The two-dimensional data are sampled from a mixture of eight Gaussians whose means are equally spaced around the unit circle centered at the origin, each with a small fixed standard deviation. The two-dimensional latent vector is sampled from a multivariate Gaussian distribution. The generator and discriminator networks each consist of two fully connected hidden layers with tanh activations. The final layer of the generator has a linear activation, while that of the discriminator has a sigmoid activation. The solver optimizes both the discriminator and the generator network using the objective in (4). We use the Adam solver with its default parameters and a fixed input batch size. The generated two-dimensional samples are plotted in Figure 30. The straightforward use of the Adam solver fails to recover all the modes of the underlying dataset, while both unrolled GAN and our method produce all the modes.
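The toy data described above can be generated as follows. This is our own sketch; the standard deviation and the sample count are illustrative assumptions rather than the paper's exact settings:

```python
import numpy as np

def sample_ring_mixture(n, modes=8, radius=1.0, std=0.02, seed=0):
    """Draw n 2-D points from a mixture of `modes` Gaussians whose means are
    equally spaced on a circle of the given radius centered at the origin."""
    rng = np.random.default_rng(seed)
    k = rng.integers(0, modes, size=n)                # pick a mode per sample
    angles = 2.0 * np.pi * k / modes
    means = radius * np.column_stack([np.cos(angles), np.sin(angles)])
    return means + std * rng.standard_normal((n, 2))  # add per-mode Gaussian noise

points = sample_ring_mixture(1024)
```

With a small standard deviation relative to the mode spacing, the modes are well separated, which is what makes mode dropping by the generator easy to see in a scatter plot.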
We further investigate the performance of GAN training algorithms on data sampled from a mixture of a larger number of Gaussians, equally spaced around a circle centered at the origin. We retain the same experimental settings as described above and train the GAN with two different input batch sizes, a small-batch and a large-batch setting. Figure 31 plots the generated samples of GANs trained (for a fixed number of epochs) under these settings using different training algorithms. Note that with the small batch size, both the default and the unrolled training of the GAN fail to recover the actual modes of the underlying dataset. We hypothesize that this is because the batch size is smaller than the number of input modes: when trained with a small batch, the GAN observes samples from only a few of the input modes at every iteration. This causes instability, leading to the failure of these training algorithms. This scenario is pertinent to real datasets, in which the number of modes is relatively high compared to the input batch size.
DCGAN Architecture details: For our experiments, we use the publicly available code for DCGAN (Radford et al., 2016) and their implementation for the Cifar10 dataset. The random noise vector is 100-dimensional, and the output of the generator network is a 32 x 32 image with 3 channels.
Additional DCGAN Results:
Experiments on ImageNet: In this section, we demonstrate the advantage of prediction methods for generating higher-resolution images of size 128 x 128. For this purpose, the state-of-the-art ACGAN (Odena et al., 2017) architecture is considered and conditionally trained using images of all 1000 classes from the ImageNet dataset. We used the publicly available code for ACGAN, and all parameters were set to their defaults as in Odena et al. (2017). Figure 60 plots the inception score measured at every training epoch of the ACGAN model with and without prediction. The score is averaged over five independent runs. From the figure, it is clear that even at higher resolution with a large number of classes, the prediction method is stable and speeds up training.