Learn to code a Few-shot Learning algorithm on the Omniglot dataset
Jun 24 · 5 min read
Introduction
We humans have the ability to recognize a class given only a few examples of it. For instance, a child needs only two or three images of a rabbit to recognize this animal among other species. This capacity to learn from few examples outperforms any classical Machine Learning algorithm. Many people think humankind is being overtaken by AI, but here is the truth: to differentiate classes well, a classifier is often fed several thousand images per class… while we only need two or three!
Prototypical Networks is an algorithm introduced by Snell et al. in 2017 (in “Prototypical Networks for Few-shot Learning”) that addresses the Few-shot Learning paradigm. Let’s understand it step by step with an example. In this article, our goal is to classify images. The code provided is in PyTorch, available here.
The Omniglot dataset
In Few-shot Learning, we are given a dataset with few images per class (usually 1 to 10). In this article, we will work on the Omniglot dataset, which contains 1,623 different handwritten characters collected from 50 alphabets. This dataset can be found in this GitHub repository. I used the “images_background.zip” and the “images_evaluation.zip” files.
As suggested in the official paper, data augmentation is performed to increase the number of classes. In practice, all the images are rotated by 90°, 180° and 270°, each rotation resulting in an additional class. Once this data augmentation is performed, we have 1,623 * 4 = 6,492 classes. I split the whole dataset into a training set (images of 4,200 classes), and a testing set (images of 2,292 classes).
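The rotation trick above can be sketched in a few lines. This is a minimal illustration using NumPy (the actual pipeline would operate on the loaded Omniglot images; `augment_with_rotations` is a hypothetical helper name):

```python
import numpy as np

def augment_with_rotations(image: np.ndarray) -> list:
    """Return the image under 0°, 90°, 180° and 270° rotations.
    Each rotation is treated as an additional class, so the 1,623
    characters become 1,623 * 4 = 6,492 classes."""
    return [np.rot90(image, k) for k in range(4)]
```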
Select a sample
To create a sample, Nc classes are randomly picked among all classes. For each class we have two sets of images: the support set of size Ns and the query set of size Nq.
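A sampling step like this one can be sketched as follows. The code assumes a hypothetical layout where `dataset` is a dict mapping each class label to a list of images; the function name and signature are my own, not from the paper's code:

```python
import random

def sample_episode(dataset, n_classes, n_support, n_query):
    """Randomly pick n_classes classes, then split n_support + n_query
    images of each class into a support set and a query set."""
    classes = random.sample(sorted(dataset), n_classes)
    support, query = {}, {}
    for c in classes:
        images = random.sample(dataset[c], n_support + n_query)
        support[c] = images[:n_support]
        query[c] = images[n_support:]
    return support, query
```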
Embed the images
“Our approach is based on the idea that there exists an embedding in which points cluster around a single prototype representation for each class.” claim the authors of the original paper.
In other words, there exists a Mathematical representation of the images, in which images of the same class gather in groups called clusters. The main advantage of working in that embedding space is that two images that look the same will be close to each other, and two images that are completely different will be far away.
In our case, with the Omniglot dataset, the embedding block takes 28x28x3 images as input and returns 64-dimensional vectors. The image2vector function is composed of 4 modules. Each module consists of a convolutional layer, batch normalization, a ReLU activation function and a 2x2 max pooling layer.
Compute the class prototypes
In this step we compute a prototype for each cluster. Once the support images are embedded, their vectors are averaged to form a class prototype, a kind of “delegate” for that class:

v(k) = (1 / Ns) · Σᵢ f_φ(xᵢ)

where v(k) is the prototype of class k, f_φ is the embedding function and the xᵢ are the support images of class k.
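The averaging step is a one-liner once the support embeddings are stacked per class. A minimal sketch, assuming the embeddings arrive as a (Nc, Ns, d) tensor:

```python
import torch

def compute_prototypes(support_embeddings: torch.Tensor) -> torch.Tensor:
    """support_embeddings: (Nc, Ns, d) tensor, i.e. Ns embedded support
    images for each of Nc classes. Each class prototype is the mean
    of that class's support embeddings."""
    return support_embeddings.mean(dim=1)
```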
Compute distances between queries and prototypes
This step classifies the query images. To do so, we compute the distance between each query embedding and the prototypes. The choice of metric is crucial here, and the inventors of Prototypical Networks must be credited for their choice of distance: the Euclidean distance.
Once distances are computed, a softmax is performed over them to get probabilities of belonging to each class.
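These two steps can be sketched together. Note this is an illustration, not the authors' code: I use `torch.cdist` for plain Euclidean distances, whereas the original paper works with the squared Euclidean distance:

```python
import torch
import torch.nn.functional as F

def query_probabilities(query_embeddings, prototypes):
    """query_embeddings: (Nq_total, d); prototypes: (Nc, d).
    A softmax over negative Euclidean distances turns distances
    into probabilities of belonging to each class."""
    distances = torch.cdist(query_embeddings, prototypes)  # (Nq_total, Nc)
    return F.softmax(-distances, dim=1)
```

A query lands with the highest probability on the class whose prototype is nearest to it.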
Compute the loss and backpropagate
The Prototypical Networks learning phase proceeds by minimizing the negative log-probability of the true class, also called log-softmax loss. The main advantage of using a logarithm is that it drastically increases the loss when the model fails to predict the right class.
The backpropagation is performed via Stochastic Gradient Descent.
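The loss for one episode can be sketched as below (again with plain Euclidean distances via `torch.cdist`, as an assumption for illustration; `targets` holds the index of the true class for each query):

```python
import torch
import torch.nn.functional as F

def episode_loss(query_embeddings, prototypes, targets):
    """Negative log-probability of the true class: a log-softmax over
    negative Euclidean distances, averaged over the query images."""
    distances = torch.cdist(query_embeddings, prototypes)
    log_p = F.log_softmax(-distances, dim=1)
    return F.nll_loss(log_p, targets)
```

Calling `.backward()` on the returned loss backpropagates through the distances and the embedding network.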
Launch training
The whole sequence described above forms an episode. And the training phase contains several episodes. I tried to reproduce the results of the original paper. Here are the training settings:
- Nc: 60 classes
- Ns: 1 or 5 support points / class
- Nq: 5 query points / class
- 5 epochs
- 2000 episodes / epoch
- Learning Rate initially at 0.001 and divided by 2 at each epoch
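The learning-rate schedule above can be expressed with a `StepLR` scheduler. This is a skeleton, not the full training loop (the `torch.nn.Linear` model is a stand-in for the embedding network):

```python
import torch

model = torch.nn.Linear(64, 64)  # stand-in for the embedding network
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Divide the learning rate by 2 after each epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

for epoch in range(5):
    # ... run 2,000 episodes (sample, embed, prototypes, loss, step) ...
    scheduler.step()
```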
The training took 30 min to run.
Results
Once the ProtoNet is trained, we can test it on new data. We select samples from the testing set in the same way. The support set is used to compute the prototypes, and each point of the query set is then labelled according to the shortest distance to a prototype.
For testing I tried 5-way and 20-way scenarios, with the same numbers of support and query points as during the training phase. The tests were performed over 1,000 episodes.
The results are presented in the table below. “5-way 1-shot” means Nc = 5 and Ns = 1.
I obtained results similar to those of the original paper, slightly better in some cases. This may be due to the sampling strategy, which is not specified in the paper; I used random sampling at each episode.