FIGR: Few-shot Image Generation with Reptile


Louis Clouâtre
École Polytechnique de Montréal
louis.clouatre@polymtl.ca
Marc Demers
McGill University
marc@cim.mcgill.ca
Abstract

Generative Adversarial Networks (GANs) boast an impressive capacity to generate realistic images. However, like much of the field of deep learning, they require an inordinate amount of data to produce results, which limits their usefulness for generating novelty. Meanwhile, recent advances in meta-learning have opened the door to many few-shot learning applications. In the present work, we propose Few-shot Image Generation using Reptile (FIGR), a GAN meta-trained with Reptile. Our model successfully generates novel images on both MNIST and Omniglot with as few as 4 images from an unseen class. We further contribute FIGR-8, a new dataset for few-shot image generation, which contains icons categorized in over classes. When trained on FIGR-8, our model shows initial signs of generalizing to more advanced concepts (such as “bird” and “knife”) from as few as 8 samples of a previously unseen class and as few as 10 training steps through those 8 images. This work demonstrates the potential of training a GAN for few-shot image generation and aims to set a new benchmark for future work in the domain.

1 Introduction

Generative Adversarial Networks [7] have helped bridge the gap between human and artificial intelligence with regard to understanding and manipulating images. GANs, however, require several orders of magnitude more data points than humans to successfully generate comprehensible images from a given class. This impairs their ability to generate novelty: in many cases, if data is abundant enough to successfully train a GAN, there is little point in generating more of it.

On the other hand, recent advances in meta-learning, such as the MAML [6] and Reptile [15] algorithms, have enabled models to perform well on novel tasks sampled from the same distribution as the training tasks. These meta-learning algorithms have seen direct applications in supervised and reinforcement learning, but not in image generation. Being very general, they may also be applicable to few-shot image generation. This paper defines the problem of few-shot image generation and introduces an approach to GAN training for Few-shot Image Generation with Reptile (FIGR). In addition, this paper introduces FIGR-8, a dataset of black-and-white pictograms, ideograms, icons, emoticons, and depictions of objects or concepts categorized in classes. We contribute this dataset as a challenging benchmark for one- and few-shot image generation approaches. After training, our approach is able to correctly generate images from a class with as few as 4 samples from the previously unseen class.

In summary, our main contributions are:

  • We develop a novel approach for training GANs for few-shot image generation.

  • We contribute a challenging dataset for that same task.

The applications of few-shot image generation are broad, but we mainly foresee this approach assisting creative processes. Artists or designers who lack the time or inspiration to produce multiple versions of an image could sketch a limited number of drawings and have the trained model generate multiple similar versions of those sketches.

2 Related work

2.1 Meta-learning

MAML is currently the most widely used approach for few-shot meta-learning, and several variants of the algorithm exist. They all have properties that make them ill-suited to meta-training a GAN. First, they assume that a lower loss indicates a better model; this assumption cannot be made for GANs, whose adversarial losses do not directly measure sample quality. Second, they rely on evaluating performance on a test set during training, and there is no clear way to do so for GANs.

2.2 Few-Shot Image Generation

To our knowledge, Lake et al. (2015) [13] provides the first successful attempt at one-shot or few-shot image generation. To achieve this on the Omniglot dataset introduced in the same paper, both the images and stroke data are used to train a Bayesian model through Bayesian Program Learning. It represents concepts, such as a pen stroke, as simple probabilistic programs and hierarchically combines them to generate images. This yields a model that can be trained on a single image of a previously unseen letter and generate novel samples of the same letter. It generates binary images.

Rezende et al. (2016) [17] uses a sequential generative model to achieve one-shot generation. The inference process uses an attention [3] module to have a Variational Autoencoder [12] sequentially attend to sections of the generated image. Unlike Lake et al. (2015), it trains on pure image data (without requiring stroke data), making this approach much more general. It generates binary images of size and on the Omniglot dataset with one-shot learning.

Bartunov and Vetrov (2018) [4] uses matching networks to achieve few-shot image generation. In essence, matching networks [18] are memory-assisted networks that leverage an external memory by employing an attention [3] module to quickly learn new concepts. This assumes that the stored concepts are somewhat similar to the new out-of-sample concepts. This approach is also trained on pure image data and does not require a lengthy sequential inference period. It generates binary images of size on the Omniglot dataset using few-shot learning.

Several issues can be found with the aforementioned approaches that no prior work seems to address:

  • The use of small binary images by all of these generative models suggests scalability issues.

  • They are limited to the Omniglot dataset for one- and few-shot image generation. This dataset has several issues that are expanded upon in Section 2.3.

  • None of the approaches uses an architecture that has shown the potential to generate highly realistic images, as GANs have.

2.3 Omniglot

The Omniglot dataset [13] is the current baseline dataset for the one- or few-shot image generation task. Details about the dataset can be found in Section 4.2. There are two main issues with using this dataset as a benchmark.

  • All classes within the dataset are very similar. They all represent roughly the same concept: a character.

  • The classes lack complexity. All classes in Omniglot are simple handwritten characters that can be explained and generated through the composition of learned pen strokes [13].

We believe that a proper image generation benchmark should encompass a greater variety of classes, and more complex ones, in order to be relevant to real-life applications and, eventually, to natural images.

3 Few-shot Image Generation with Reptile

Generative Adversarial Networks GANs are generative models that learn a generator network $G$ to map a random noise vector $z$ to an image $\hat{x}$, such that $G(z) = \hat{x}$. To accomplish this, we use a discriminator network $D$ and real images $x \sim p_{\text{data}}$ from the distribution we want to generate from. $D$ is trained on both $x$ and $G(z)$ to distinguish the "fake" images from the "real" images, while $G$ is trained to fool $D$. This adversarial game played between the two models leads to $G$ being able to generate images that resemble the ones from $p_{\text{data}}$ [7].

Few-shot image generation We define the few-shot image generation problem with the help of the meta-learning problem setup found in Finn et al. (2017) [6] and Nichol et al. (2018) [15]. In this problem, we assume access to a set of tasks $\mathcal{T}$ containing multiple tasks $\tau$, where each individual task $\tau$ is an image generation problem with one class of images and a loss $L_{\tau}$. We define $L_{\tau}$ as the ability of a human to discriminate between a group of generated images and a group of real images sampled from task $\tau$, as described in Lake et al. (2015) [13]. We do not conduct human benchmarking in this paper, as this will be part of follow-up work. We nevertheless keep it in the task description because we believe a proper metric is essential.

The aim is to find, through meta-training, parameters $\phi$ that can quickly (that is, with little data and little training) converge on a randomly sampled task $\tau$ to minimize its associated loss $L_{\tau}$.

In essence, we want to:

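$$\underset{\phi}{\text{minimize}}\;\; \mathbb{E}_{\tau}\left[ L_{\tau}\left( U_{\tau}^{K}(\phi) \right) \right] \qquad (1)$$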

where $U_{\tau}^{K}$ is the operator that updates $\phi$ $K$ times using a total of $n$ data points sampled from $\tau$ [15].
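To make Equation (1) concrete, the following is a minimal Python sketch of the meta-objective on a deliberately simple toy problem rather than on GANs. The assumptions are ours: each hypothetical task is a scalar target `c_tau`, its loss on a data point `x` is the squared error `0.5 * (w - x) ** 2` (a stand-in for the human-discrimination loss, which cannot be computed in code), and `U` applies `k` plain SGD steps using `n` sampled points. All function names are illustrative only.

```python
import random
import statistics

# Toy stand-in (assumption): a task tau is a scalar target c_tau, its data points
# are noisy copies of c_tau, and the per-point loss is 0.5 * (w - x) ** 2.


def sample_data(c_tau, n, rng):
    """Draw n noisy data points from the task centred at c_tau."""
    return [rng.gauss(c_tau, 0.1) for _ in range(n)]


def task_loss(w, data):
    """Empirical estimate of L_tau(w) on a batch of data points."""
    return statistics.fmean(0.5 * (w - x) ** 2 for x in data)


def U(phi, data, k, lr=0.1):
    """U_tau^k: apply k SGD steps, starting from phi, using the n points in data."""
    w = phi
    for _ in range(k):
        grad = statistics.fmean(w - x for x in data)  # derivative of task_loss w.r.t. w
        w -= lr * grad
    return w


def meta_objective(phi, task_centres, k, n, rng, n_tasks=200):
    """Monte Carlo estimate of E_tau[L_tau(U_tau^k(phi))]."""
    losses = []
    for _ in range(n_tasks):
        c = rng.choice(task_centres)
        support = sample_data(c, n, rng)  # the n points available for adaptation
        query = sample_data(c, n, rng)    # fresh points to score the adapted weights
        losses.append(task_loss(U(phi, support, k), query))
    return statistics.fmean(losses)
```

For example, with task centres at -1 and 1, `meta_objective(0.0, [-1.0, 1.0], 10, 4, random.Random(0))` is much smaller than the same estimate at `phi = 5.0`: parameters near the centre of the task distribution adapt better in `k` steps, which is exactly what Equation (1) rewards.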

MNIST As an example, the MNIST dataset contains 10 classes (the 10 digits). In the few-shot image generation problem, they represent 10 tasks to solve, $\tau_0$ to $\tau_9$. We choose $\tau_0$ to $\tau_8$ to be the training tasks and $\tau_9$ to be the test task. Through meta-training on $\tau_0$ to $\tau_8$, we aim to obtain a set of parameters $\phi$ that will quickly converge on the new task $\tau_9$. We choose $n$ to be 4, meaning that we aim for our meta-trained $\phi$ to converge to generating images of 9's with only 4 images sampled from $\tau_9$.
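The task construction in this MNIST example (one task per digit, digits 0 to 8 for meta-training, digit 9 held out, and 4 support images for the test task) can be sketched as follows, assuming the dataset is available as a list of integer labels. The `labels` list below is a hypothetical placeholder rather than real MNIST data, and `build_tasks` and `sample_support` are illustrative helpers, not part of the FIGR code.

```python
import random
from collections import defaultdict


def build_tasks(labels):
    """Group example indices by class: one task per digit class."""
    tasks = defaultdict(list)
    for idx, label in enumerate(labels):
        tasks[label].append(idx)
    return dict(tasks)


def sample_support(task_indices, n, rng):
    """Sample the n images (by index) the model may see for a task."""
    return rng.sample(task_indices, n)


# Hypothetical labels standing in for the MNIST training labels.
labels = [i % 10 for i in range(100)]
tasks = build_tasks(labels)
train_tasks = [tasks[c] for c in range(9)]  # digits 0 to 8: meta-training tasks
test_task = tasks[9]                        # digit 9: held-out test task
support = sample_support(test_task, 4, random.Random(0))  # n = 4 images of "9"
```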

FIGR In FIGR, $\phi$ corresponds to the parameters of both the generator network $G$ and the discriminator network $D$. Each of the $K$ updates performed by $U_{\tau}^{K}$ corresponds to one step of Stochastic Gradient Descent [5] on $G$ and $D$ using the Wasserstein loss [1] with gradient penalty [8].
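As a minimal sketch of the losses used in each inner-loop step, the code below computes the WGAN-GP critic loss and the Wasserstein generator loss for a toy linear critic `D(x) = w . x + b`. The linear critic is our simplifying assumption: its input gradient equals `w` everywhere, so the gradient penalty, which in general is evaluated at random interpolates between real and fake samples, reduces to `lam * (||w|| - 1) ** 2`. The penalty weight of 10 is the default suggested by Gulrajani et al. [8]; the value used in FIGR is not stated here.

```python
import math

LAMBDA_GP = 10.0  # default penalty weight from Gulrajani et al. [8] (assumed here)


def critic(w, b, x):
    """Toy linear critic D(x) = w . x + b, standing in for the ResNet discriminator."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def critic_loss_gp(w, b, real, fake, lam=LAMBDA_GP):
    """WGAN-GP critic loss: E[D(fake)] - E[D(real)] + lam * E[(||grad_x D(x_hat)|| - 1)^2].

    For a linear critic, grad_x D(x_hat) = w at every interpolate x_hat,
    so the penalty term reduces to lam * (||w|| - 1) ** 2.
    """
    wasserstein = (sum(critic(w, b, x) for x in fake) / len(fake)
                   - sum(critic(w, b, x) for x in real) / len(real))
    grad_norm = math.sqrt(sum(wi * wi for wi in w))
    return wasserstein + lam * (grad_norm - 1.0) ** 2


def generator_loss(w, b, fake):
    """Wasserstein generator loss: -E[D(G(z))], evaluated on generated samples."""
    return -sum(critic(w, b, x) for x in fake) / len(fake)
```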

The adapted Reptile pseudocode for meta-training the model is given in Algorithm 1. The algorithm is composed of an outer loop and an inner loop. The inner loop applies the $K$ steps of the operator $U_{\tau}^{K}$ to a copy $W$ of the parameters $\phi$, using task $\tau$. Once we have the adapted weights $W$, we proceed to the outer loop: we set the gradient of $\phi$ to be $\phi - W$ and take one step with the Adam optimizer [11].
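The outer-loop update can be sketched in plain Python as below. It treats `phi - W` as a gradient and hands it to a hand-written Adam optimizer, so the meta-parameters move towards the task-adapted weights. Parameters are flattened into lists for simplicity, and the default outer learning rate of 1e-5 mirrors Table 1; the `Adam` class and `reptile_outer_step` function are hypothetical, and a real implementation would rely on a deep learning framework's optimizer instead.

```python
import math


class Adam:
    """Minimal Adam optimizer over a flat list of parameters (Kingma and Ba [11])."""

    def __init__(self, n_params, lr=1e-5, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n_params  # first-moment estimates
        self.v = [0.0] * n_params  # second-moment estimates
        self.t = 0

    def step(self, params, grads):
        """Return the parameters after one bias-corrected Adam step."""
        self.t += 1
        out = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            out.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return out


def reptile_outer_step(phi, adapted, opt):
    """Reptile meta-update: use phi - W as the gradient of phi, then step the optimizer."""
    pseudo_grad = [p - w for p, w in zip(phi, adapted)]
    return opt.step(phi, pseudo_grad)
```

Because the pseudo-gradient points from `W` back to `phi`, a descent step moves `phi` towards `W`; in FIGR this is done separately for the discriminator and generator parameters.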

1:Initialize $\phi_d$, the discriminator parameter vector
2:Initialize $\phi_g$, the generator parameter vector
3:for iteration 1, 2, 3, … do
4:     Make a copy of $\phi_d$ resulting in $W_d$
5:     Make a copy of $\phi_g$ resulting in $W_g$
6:     Sample task $\tau$
7:     Sample $n$ images from $\tau$ resulting in $x_\tau$
8:     for $K$ iterations do
9:         Generate latent vector $z$
10:         Generate fake images $y$ with $z$ and $W_g$
11:         Perform step of SGD update on $W_d$ with
12:           Wasserstein GP loss and $x_\tau$ and $y$
13:         Generate latent vector $z$
14:         Perform step of SGD update on $W_g$ with
15:           Wasserstein loss and $z$
16:     end for
17:     Set gradient of $\phi_d$ to be $\phi_d - W_d$
18:     Perform step of Adam update on $\phi_d$
19:     Set gradient of $\phi_g$ to be $\phi_g - W_g$
20:     Perform step of Adam update on $\phi_g$
21:end for
Algorithm 1: FIGR training

Once the model is meta-trained, we use a similar process, described in Algorithm 2, to generate novel images from the sampled class.

1:Using $W_d$, a copy of the meta-trained $\phi_d$
2:Using $W_g$, a copy of the meta-trained $\phi_g$
3:Sample test task $\tau$
4:Sample $n$ images as $x_\tau$ from $\tau$
5:for $K$ iterations do
6:     Generate latent vector $z$
7:     Generate fake images $y$ with $z$ and $W_g$
8:     Perform step of SGD update on $W_d$ with
9:       Wasserstein GP loss and $x_\tau$ and $y$
10:     Generate latent vector $z$
11:     Perform step of SGD update on $W_g$ with
12:       Wasserstein loss and $z$
13:end for
14:Generate latent vector $z$
15:Generate fake images $y$ with $z$ and $W_g$
Algorithm 2: FIGR generation
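The generation procedure of Algorithm 2 can be outlined as below, under strong simplifying assumptions: the hypothetical `toy_sgd_step` merely pulls a one-number generator towards the mean of the support images instead of performing the adversarial discriminator and generator updates, and `toy_generate` emits a noisy scalar instead of an image. The sketch is only meant to show the structure: fine-tuning runs for K iterations on a copy of the meta-trained parameters, so the same meta-trained weights can be reused for every new class.

```python
import copy
import random
import statistics


def toy_sgd_step(params, support, lr):
    """Hypothetical stand-in for one fine-tuning iteration of Algorithm 2.

    A real FIGR iteration updates W_d and W_g adversarially; here a one-number
    'generator' is simply pulled towards the mean of the support images.
    """
    new = dict(params)
    new["g_mean"] -= lr * (params["g_mean"] - statistics.fmean(support))
    return new


def toy_generate(params, rng):
    """Hypothetical toy generator: each 'fake image' is one noisy scalar."""
    return rng.gauss(params["g_mean"], 0.1)


def figr_generate(meta_params, support, k=10, lr=0.1, n_samples=4, seed=0):
    """Fine-tune a copy of the meta-trained parameters for k iterations, then sample."""
    rng = random.Random(seed)
    params = copy.deepcopy(meta_params)  # meta_params stay untouched and reusable
    for _ in range(k):
        params = toy_sgd_step(params, support, lr)
    return [toy_generate(params, rng) for _ in range(n_samples)]
```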

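For every task $\tau$ there exist optimal discriminator and generator weights $W^{*}_{\tau d}$ and $W^{*}_{\tau g}$. Intuitively, Reptile initializes the weights $\phi_d$ and $\phi_g$ to the point in parameter space that minimizes the distance between $\phi_d$ and $W^{*}_{\tau d}$, and between $\phi_g$ and $W^{*}_{\tau g}$, for all $\tau$, or

$$\underset{\phi_d,\,\phi_g}{\text{minimize}}\;\; \mathbb{E}_{\tau}\left[ \tfrac{1}{2}\,\mathrm{dist}\left(\phi_d, W^{*}_{\tau d}\right)^2 + \tfrac{1}{2}\,\mathrm{dist}\left(\phi_g, W^{*}_{\tau g}\right)^2 \right] \qquad (2)$$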

Hence, for a sampled task $\tau$, a model optimized with Reptile can quickly, and with few data points, converge to the optimal point $(W^{*}_{\tau d}, W^{*}_{\tau g})$ from $(\phi_d, \phi_g)$. If the test tasks are close enough to the training tasks and the training tasks are numerous enough, $\phi_d$ and $\phi_g$ are likely to be close to a test task's $W^{*}_{\tau d}$ and $W^{*}_{\tau g}$. This allows rapid and easy generalization from few data points.
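The intuition behind Equation (2) can be checked on a toy problem. In the sketch below (our assumption, not the FIGR setup), each task has the quadratic loss `0.5 * (w - c_tau) ** 2` with identical curvature, so the point minimizing the expected squared distance to the task optima is simply their mean. Running Reptile with a plain SGD outer step (rather than Adam, for readability) and visiting tasks in a fixed cycle (to keep it deterministic) drives the initialization close to that mean.

```python
def reptile_toy(task_optima, k=10, inner_lr=0.1, outer_lr=0.01, iterations=5000):
    """Reptile on scalar tasks with loss L_tau(w) = 0.5 * (w - c_tau) ** 2 (toy assumption)."""
    phi = 0.0
    for it in range(iterations):
        c = task_optima[it % len(task_optima)]  # deterministic task "sampling"
        w = phi                                 # copy of the meta-parameters
        for _ in range(k):                      # inner loop: k SGD steps on task tau
            w -= inner_lr * (w - c)
        phi -= outer_lr * (phi - w)             # outer loop: SGD with gradient phi - W
    return phi
```

With task optima at -2, 0, 1 and 5, `reptile_toy([-2.0, 0.0, 1.0, 5.0])` ends up within a few hundredths of their mean, 1.0.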

Figure 1: Samples taken from the FIGR-8 dataset. Items from out of classes are displayed and one class (cow) is (non-extensively) detailed.

Reptile is broadly similar to joint training, and is effectively identical to it when $K = 1$. However, by taking more gradient steps, we prioritize learning features that would be hard to reach, unlike joint training. Assume a 2D parameter space, $K = 10$ and a task $\tau$, and let $W = U_{\tau}^{K}(\phi)$ be the adapted weights. Suppose the inner loop reaches a local minimum along parameter 1 after only 2 gradient steps, while it has not reached a local minimum along parameter 2 after $K$ steps. It is then probable that:

$$\lvert \phi_1 - W_1 \rvert < \lvert \phi_2 - W_2 \rvert \qquad (3)$$

This would result in a larger outer-loop update along the parameter whose minimum is not readily attainable from $\phi$, and a smaller update along the parameter for which the model can already converge quickly.
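The comparison with joint training can be illustrated on a separable two-parameter quadratic, which is our own toy assumption: parameter 1 has high curvature (fast inner-loop convergence) and parameter 2 has low curvature (slow convergence). With K = 1 the Reptile update equals one joint-training SGD step; with K = 10, the Reptile update, measured relative to the joint-training step, is amplified far more along the slow parameter than along the fast one.

```python
def reptile_vs_joint(h, c, phi, k, inner_lr=0.1):
    """Compare one Reptile update with one joint-training SGD step.

    Toy loss (assumption): L(w) = sum_i 0.5 * h_i * (w_i - c_i) ** 2, so each
    coordinate converges at a rate set by its curvature h_i.
    """
    w = list(phi)
    for _ in range(k):  # inner loop: k SGD steps from phi
        w = [wi - inner_lr * (hi * (wi - ci)) for wi, hi, ci in zip(w, h, c)]
    reptile = [p - wi for p, wi in zip(phi, w)]                       # phi - W
    joint = [inner_lr * (hi * (p - ci)) for p, hi, ci in zip(phi, h, c)]  # lr * grad
    return reptile, joint
```

For instance, `reptile_vs_joint([5.0, 0.05], [1.0, 1.0], [0.0, 0.0], k=10)` gives Reptile-to-joint ratios of roughly 2 for the fast parameter and roughly 9.8 for the slow one.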

4 Datasets

4.1 MNIST

MNIST [14] was the first dataset we chose, as its simplicity allows us to iterate quickly through model ideas. The MNIST dataset contains grayscale images of the 10 digits. We use the training set images for all experiments.

4.2 Omniglot

Omniglot [13] is arguably the de facto dataset for few-shot image generation. It contains unique characters originating from 50 alphabets, each of which has been handwritten once by 20 different individuals. Unlike MNIST, Omniglot allows us to train our model on a much larger number of classes and to test its out-of-sample performance on a wider set of classes.

4.3 FIGR-8

To test the limits of our model, we compiled images separated into conceptually different classes, a dataset which we named FIGR-8. Each class contains at least 8 images, and up to a few thousand. The icons are black-and-white representations of objects, concepts, patterns or designs that have been created by designers and artists and compiled into one dataset. A subset of the classes is pictured in Figure 1; each of those classes contains at least 8 images with a similar theme. Every image has a square format. The relative cumulative density of the number of images per class is represented in Figure 2.

Figure 2: Relative cumulative density of the number of elements in each class in the FIGR-8 dataset

We expect this dataset to be more challenging for training the meta-learning model, as it contains a wide variety of samples inside each class and a substantial number of classes. Hopefully, the large number of classes will let the model quickly understand the underlying concept even if not every sample from a class represents the class's concept in the same manner. Some icons have complex patterns and details, which pose a greater challenge than the existing datasets for one- or few-shot image generation. All in all, the FIGR-8 dataset constitutes a tough yet achievable benchmark for few-shot image generation.

5 Experiments

5.1 Model architecture

All models were trained with the Wasserstein loss [1] with gradient penalty [8]. Although a simple DCGAN [16] trained with a binary cross-entropy loss in this setup yielded positive results on MNIST [14], more complex datasets, such as Omniglot [13] and FIGR-8, were more challenging and required the Wasserstein loss with gradient penalty for the model to succeed. Both the generator and the discriminator are 18-layer residual networks [10]. The discriminator uses layer normalization [2], as prescribed in Gulrajani et al. (2017) [8]. The generator also uses layer normalization, because batch normalization relies on running statistics that are incompatible with Reptile's meta-update.
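To illustrate why layer normalization fits Reptile, here is a minimal sketch of layer normalization for a single flattened feature vector. It computes its statistics from the sample itself, so there is no running mean or variance to carry across the inner and outer loops. Flattening is a simplification on our part; in a convolutional network the statistics would typically be taken over the channel and spatial dimensions of each sample, with a learned per-feature gain and bias.

```python
import math


def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one sample over its own features (Ba et al. [2]).

    Mean and variance come from this sample only, so the output never depends
    on other samples in the batch and no running statistics are stored.
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n
    gamma = gamma if gamma is not None else [1.0] * n  # learned gain (identity here)
    beta = beta if beta is not None else [0.0] * n     # learned bias (zero here)
    return [g * (xi - mean) / math.sqrt(var + eps) + b
            for xi, g, b in zip(x, gamma, beta)]
```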

All rectified linear units are Parametric ReLU [9] (PReLU). PReLU is the authors’ preferred rectified linear activation function. However, any other rectified linear activation function should yield comparable results.

All images are resized with bilinear interpolation to or . All images are in grayscale format and normalized to have values constrained between and . No data augmentation was used. Results were sampled every meta-training steps and experiments took between and meta-training steps for results to converge. All experiments were run on a single Tesla V100 on Google Cloud Platform (GCP). Training a model for meta-training steps with on Omniglot took hours with this setup. Table 1 at the end of this paper shows the hyperparameters for all experiments.
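The preprocessing described above can be sketched in plain Python as follows. Two assumptions are ours: the resize uses the "align corners" convention (libraries differ on where they place sample points, so results can differ slightly from a given library's bilinear mode), and pixel values are assumed to start in [0, 255]. Because the target value range is not recoverable here, `lo` and `hi` are left as parameters.

```python
def bilinear_resize(img, out_h, out_w):
    """Resize a 2-D grayscale image (a list of rows) with bilinear interpolation.

    Uses the 'align corners' convention: the corner pixels of the input and
    output grids coincide.
    """
    in_h, in_w = len(img), len(img[0])
    out = []
    for i in range(out_h):
        # Source row coordinate for output row i.
        y = i * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        fy = y - y0
        row = []
        for j in range(out_w):
            # Source column coordinate for output column j.
            x = j * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            fx = x - x0
            # Interpolate horizontally on the two neighbouring rows, then vertically.
            top = img[y0][x0] * (1 - fx) + img[y0][x1] * fx
            bottom = img[y1][x0] * (1 - fx) + img[y1][x1] * fx
            row.append(top * (1 - fy) + bottom * fy)
        out.append(row)
    return out


def rescale(img, lo, hi, max_value=255.0):
    """Linearly map pixel values from [0, max_value] to [lo, hi]."""
    return [[lo + (hi - lo) * p / max_value for p in row] for row in img]
```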

5.2 Empirical Validation

In contrast with prior work, our model works on grayscale images rather than binary images. It also works without an external memory, a lengthy sequential inference process or additional training data in the form of pen stroke information. We believe that our approach, being built on top of GANs, has the greatest potential to generalize to more challenging problems.

Shown below are the results of generating unseen test classes on our three datasets. The first row of every figure that follows represents the training data (circled in red). The following three rows are images generated by the model fine-tuned on those data points for 10 gradient steps. All images present results on previously unseen test classes. If unspecified, .

MNIST The MNIST data was rescaled to 32x32 pixels. The training classes are the digits from 0 to 8. The test class is the digit 9.

Figure 3: MNIST; 50,000 updates; 10 gradient steps

In Figure 3, we can see good results on MNIST after 50,000 meta-training steps. This validates our approach on a toy problem.

Omniglot The Omniglot data was resized to and . The training classes were all characters in the dataset except for 20 randomly sampled character classes, which were held out for the test set.

Figure 4: Omniglot; 140,000 updates; 10 gradient steps
Figure 5: Omniglot; 230,000 updates; 10 gradient steps

On simpler Omniglot characters, like the one shown in Figure 4, the model converges to good results after 140,000 meta-training steps. On more complex characters, even after 230,000 meta-training steps, results are still lacking and humans can easily distinguish most generated characters from the real ones, as pictured in Figure 5.

As for the images, a batch size of was required to generate good results. In this case, after 150,000 meta-training steps, around half of the generated characters could conceivably fool a human judge, as pictured in Figure 6.

FIGR-8 The FIGR-8 data was resized to 32x32 pixels. The training classes were all classes except for 50 randomly sampled classes, which were held out for the test set. Here, $n = 8$ was used for all experiments.
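Holding out randomly sampled classes, as done for Omniglot and FIGR-8, can be sketched as below. The number of held-out classes is a parameter (Table 1 lists 20 validation classes for Omniglot and 50 for FIGR-8), and the fixed seed is our addition to make the split reproducible; `split_classes` is a hypothetical helper, not part of the released code.

```python
import random


def split_classes(class_names, n_test, seed=0):
    """Hold out n_test randomly sampled classes for testing; the rest are training tasks."""
    rng = random.Random(seed)
    test = set(rng.sample(sorted(class_names), n_test))  # sorted input -> reproducible
    train = [c for c in class_names if c not in test]
    return train, sorted(test)
```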

For the FIGR-8 dataset, arguably none of the generated images pictured in Figures 7, 8 and 9 can fool a human. However, our model is able to learn key features of the images very quickly, such as a bird-like shape or an ice cream cone.

Figure 6: Omniglot; 150,000 updates; 10 gradient steps; ;
Figure 7: FIGR-8; 80,000 updates; 10 gradient steps;
Figure 8: FIGR-8; 90,000 updates; 10 gradient steps;
Figure 9: FIGR-8; 100,000 updates; 10 gradient steps;

6 Conclusion

We have shown that Reptile can be used to effectively train Generative Adversarial Networks for few-shot image generation. Using meta-training on a dataset containing several similar classes of images, we can learn to generate images from an unseen class with as few as 4 samples on the MNIST and Omniglot datasets. This is done with no lengthy inference time, no external memory and no additional data. No hyperparameter tuning was required; the base hyperparameters were stable throughout our experiments. It is, to our knowledge, the first GAN trained for few-shot image generation. Results show that our approach is able to quickly learn and generate simple concepts as well as complex ones. Preliminary results on FIGR-8 show that a complex concept such as “bird” can be learned. To our knowledge, no other few-shot image generation model has yet managed to generate images other than handwritten characters. The small amount of data required to generate images, once the model is pretrained, opens the door to several applications that were previously out of reach because of the large amount of data required.

We have also built, and will release for open-source use, FIGR-8, a dataset containing over different classes and over images. Hopefully, this dataset will become a strong benchmark for the task of few-shot image generation.

Several future directions should be explored:

  • Generating multi-channel and/or larger images, such as with the CIFAR-100 dataset or the ImageNet dataset.

  • Modifying batch normalization layers to be able to meta-train through them.

  • Exploiting the wide variety of GAN architectures available.

  • Using FIGR on ImageNet to make a pretrained GAN model for fine-tuning and transfer learning in the same capacity that ImageNet models are used for fine-tuning computer-vision models.

The code for the FIGR implementation can be found at https://github.com/OctThe16th/FIGR and the FIGR-8 database can be found at https://github.com/marcdemers/FIGR-8 and bit.ly/FIGR-8.

References

  • [1] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 214–223, International Convention Centre, Sydney, Australia, 06–11 Aug 2017. PMLR.
  • [2] Lei Jimmy Ba, Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. CoRR, abs/1607.06450, 2016.
  • [3] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv e-prints, abs/1409.0473, September 2014.
  • [4] Sergey Bartunov and Dmitry P. Vetrov. Few-shot generative modelling with generative matching networks. In AISTATS, 2018.
  • [5] Léon Bottou. Large-scale machine learning with stochastic gradient descent. In Yves Lechevallier and Gilbert Saporta, editors, Proceedings of COMPSTAT’2010, pages 177–186, Heidelberg, 2010. Physica-Verlag HD.
  • [6] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. CoRR, abs/1703.03400, 2017.
  • [7] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Z. Ghahramani, M. Welling, C. Cortes, N. D. Lawrence, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems 27, pages 2672–2680. Curran Associates, Inc., 2014.
  • [8] Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron Courville. Improved training of Wasserstein GANs, 2017.
  • [9] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. 2015 IEEE International Conference on Computer Vision (ICCV), Dec 2015.
  • [10] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Jun 2016.
  • [11] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization, 2014.
  • [12] Diederik P. Kingma and Max Welling. Auto-encoding variational Bayes. CoRR, abs/1312.6114, 2013.
  • [13] Brenden M. Lake, Ruslan Salakhutdinov, Jason Gross, and Joshua B. Tenenbaum. One shot learning of simple visual concepts.
  • [14] Yann LeCun and Corinna Cortes. MNIST handwritten digit database. 2010.
  • [15] Alex Nichol, Joshua Achiam, and John Schulman. On first-order meta-learning algorithms, 2018.
  • [16] Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks, 2015.
  • [17] Danilo Rezende, Shakir Mohamed, Ivo Danihelka, Karol Gregor, and Daan Wierstra. One-shot generalization in deep generative models. In Maria Florina Balcan and Kilian Q. Weinberger, editors, Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pages 1521–1529, New York, New York, USA, 20–22 Jun 2016. PMLR.
  • [18] Oriol Vinyals, Charles Blundell, Tim Lillicrap, Koray Kavukcuoglu, and Daan Wierstra. Matching networks for one shot learning. In D. D. Lee, M. Sugiyama, U. V. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems 29, pages 3630–3638. Curran Associates, Inc., 2016.
Hyperparameter        MNIST     Omniglot     FIGR-8
Inner learning rate   0.0001    0.0001       0.0001
Outer learning rate   0.00001   0.00001      0.00001
Training size n       4         4 and 8      8
Inner loops K         10        10           10
Image resize          32x32     two sizes    32x32
Grayscale             True      True         True
Validation classes    1         20           50
Table 1: Hyperparameters for all experiments