Few-shot Learning with Prototypical Networks

Learn to code a Few-shot Learning algorithm on the Omniglot dataset

Image credit: https://unsplash.com/photos/kZO9xqmO_TA

Introduction

We humans have the ability to recognize a class given only a few examples of that class. For instance, a child needs only two or three images of a rabbit to recognize this animal among other species. This ability to learn from a few examples far exceeds that of any classical Machine Learning algorithm. Many people think humankind is being overtaken by AI, but here is the truth: to differentiate classes well, a classifier is often fed several thousand images per class… while we need only two or three!

Prototypical Networks is an algorithm introduced by Snell et al. in 2017 (in “Prototypical Networks for Few-shot Learning”) that addresses the Few-shot Learning paradigm. Let’s understand it step by step with an example. In this article, our goal is to classify images. The code provided is in PyTorch, available here.

The Omniglot dataset

In Few-shot Learning, we are given a dataset with few images per class (usually 1 to 10). In this article, we will work on the Omniglot dataset, which contains 1,623 different handwritten characters collected from 50 alphabets. This dataset can be found in this GitHub repository. I used the “images_background.zip” and “images_evaluation.zip” files.

Examples of characters found in the Omniglot dataset

As suggested in the official paper, data augmentation is performed to increase the number of classes. In practice, all the images are rotated by 90°, 180° and 270°, each rotation resulting in an additional class. Once this data augmentation is performed, we have 1,623 * 4 = 6,492 classes. I split the whole dataset into a training set (images of 4,200 classes), and a testing set (images of 2,292 classes).
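The rotation-based augmentation and the class split can be sketched in PyTorch as follows. This is a minimal sketch, assuming each character class is already loaded as a tensor of shape (n_images, channels, height, width); the helper names `augment_with_rotations` and `split_classes`, the class-naming scheme and the purely random split are illustrative assumptions, not the author's code.

```python
import random

import torch


def augment_with_rotations(class_images):
    """Turn every character class into 4 classes: 0, 90, 180 and 270 degrees.

    class_images: dict mapping a class name to a tensor of shape
    (n_images, channels, height, width). Names are illustrative only.
    """
    augmented = {}
    for name, images in class_images.items():
        for k in range(4):  # k quarter-turns
            # Rotate every image in the plane of its last two dims (H, W)
            augmented[f"{name}_rot{90 * k}"] = torch.rot90(images, k, [2, 3])
    return augmented


def split_classes(class_names, n_train, seed=0):
    """Randomly split class names into disjoint training / testing lists."""
    names = sorted(class_names)
    random.Random(seed).shuffle(names)
    return names[:n_train], names[n_train:]
```

Applied to the 1,623 Omniglot characters, `augment_with_rotations` yields 6,492 classes, and `split_classes(names, n_train=4200)` returns 4,200 training and 2,292 testing class names.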

Select a sample

To create a sample (an episode), Nc classes are picked at random. For each class, we draw two disjoint sets of images: the support set of size Ns and the query set of size Nq.

Illustration of a sample of Nc classes, each containing a support set and a query set
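Sampling one episode only needs the list of images available for each class. The sketch below is plain Python; `sample_episode` is a hypothetical helper, and it assumes every class holds at least Ns + Nq images (Omniglot provides 20 drawings per character).

```python
import random


def sample_episode(images_per_class, n_way, n_support, n_query, rng=random):
    """Sample one episode (hypothetical helper).

    images_per_class: dict mapping a class label to a list of image ids/paths.
    Returns (classes, support, query); support[k] and query[k] hold the
    items of classes[k], and the two sets never share an image.
    """
    classes = rng.sample(sorted(images_per_class), n_way)
    support, query = [], []
    for c in classes:
        # Draw Ns + Nq distinct images, then split them into support and query
        picked = rng.sample(images_per_class[c], n_support + n_query)
        support.append(picked[:n_support])
        query.append(picked[n_support:])
    return classes, support, query
```

Passing a seeded `random.Random(seed)` as `rng` makes the episodes reproducible.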

Embed the images

“Our approach is based on the idea that there exists an embedding in which points cluster around a single prototype representation for each class,” claim the authors of the original paper.

In other words, there exists a mathematical representation of the images in which images of the same class gather into groups called clusters. The main advantage of working in that embedding space is that two images that look alike will be close to each other, while two images that are completely different will be far apart.

In our case, with the Omniglot dataset, the embedding block takes (28x28x3) images as inputs and returns 64-dimensional vectors. The image2vector function is composed of 4 modules. Each module consists of a convolutional layer, a batch normalization layer, a ReLU activation function and a 2x2 max pooling layer.

The 4 modules of the image2vector function
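A possible PyTorch version of the image2vector function is shown below. It is a sketch: the 3x3 kernels, 64 filters per convolution and padding of 1 are assumptions (they reproduce the 64-dimensional output described above), and the class name `Image2Vector` is hypothetical.

```python
import torch
from torch import nn


def conv_block(in_channels, out_channels):
    # One module: 3x3 convolution -> batch norm -> ReLU -> 2x2 max pooling
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class Image2Vector(nn.Module):
    """Four stacked conv modules mapping a 28x28 image to a 64-d vector."""

    def __init__(self, in_channels=3, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_channels, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
        )

    def forward(self, x):
        # Spatial size shrinks 28 -> 14 -> 7 -> 3 -> 1, leaving (batch, hidden, 1, 1)
        return self.encoder(x).flatten(start_dim=1)
```

With 28x28 inputs, the four 2x2 poolings shrink the feature maps to 14, 7, 3 and finally 1 pixel per side, so flattening leaves exactly 64 values per image.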

Compute the class prototypes

In this step we compute a prototype for each cluster. Once the support images are embedded, vectors are averaged to form a class prototype, a kind of “delegate” for that class.

$$ v^{(k)} = \frac{1}{N_s} \sum_{i=1}^{N_s} f_{\phi}\left(x_i^{(k)}\right) $$

where v(k) is the prototype of class k, f_phi is the embedding function, and the x_i are the Ns support images of class k.

One prototype is computed per class
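Computing the prototypes is a single average over the support embeddings. A minimal PyTorch sketch, assuming the embedded support points arrive as one (Nc * Ns, D) tensor ordered class by class; `compute_prototypes` is a hypothetical name.

```python
import torch


def compute_prototypes(support_embeddings, n_way, n_support):
    """Average the embedded support points of each class.

    support_embeddings: (Nc * Ns, D) tensor of f_phi(x_i), class by class.
    Returns a (Nc, D) tensor holding one prototype v(k) per class.
    """
    d = support_embeddings.size(-1)
    # Group the rows per class, then average over the Ns support points
    return support_embeddings.reshape(n_way, n_support, d).mean(dim=1)
```

In the 1-shot setting (Ns = 1), each prototype is simply the embedding of the single support image.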

Compute distances between queries and prototypes

This step classifies the query images. To do so, we compute the distance between each query embedding and every prototype. The choice of metric is crucial here, and the inventors of Prototypical Networks deserve credit for their choice of distance: the squared Euclidean distance.

Once the distances are computed, a softmax over their negatives gives, for each query, the probability of belonging to each class: the closer a prototype, the higher the probability.
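The distance and softmax step can be written with tensor broadcasting. In this sketch, the squared Euclidean distance (the form used by Snell et al.) is assumed, and `class_probabilities` is a hypothetical helper name.

```python
import torch


def class_probabilities(query_embeddings, prototypes):
    """Softmax over negative squared Euclidean distances.

    query_embeddings: (n_queries, D); prototypes: (Nc, D).
    Returns a (n_queries, Nc) tensor whose rows sum to 1.
    """
    # Broadcast to (n_queries, Nc, D), then sum squared differences over D
    dists = ((query_embeddings[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=2)
    # The minus sign makes the nearest prototype the most probable class
    return torch.softmax(-dists, dim=1)
```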

Compute the loss and backpropagate

Prototypical Networks learn by minimizing the negative log-probability of the true class, also called the log-softmax loss. The logarithm makes the loss grow sharply when the model assigns a low probability to the right class.

Gradients are computed by backpropagation, and the weights are updated with Stochastic Gradient Descent.
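In PyTorch, the negative log-probability can be obtained by passing the negative distances to `F.cross_entropy`, which applies log-softmax internally. The sketch below assumes queries are labelled 0 to Nc - 1 by class; `prototypical_loss` is a hypothetical name, and the toy linear embedding with random data only illustrates one backpropagation pass and one SGD update.

```python
import torch
import torch.nn.functional as F


def prototypical_loss(query_embeddings, prototypes, query_labels):
    """Mean negative log-probability of the true class over all queries.

    query_labels: (n_queries,) long tensor with values in [0, Nc).
    """
    dists = ((query_embeddings[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=2)
    # cross_entropy = log-softmax + negative log-likelihood, applied to -dists
    return F.cross_entropy(-dists, query_labels)


# Toy example: one backward pass and one SGD step on a linear embedding
torch.manual_seed(0)
embed = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(embed.parameters(), lr=0.1)
support = torch.randn(3, 1, 4)               # Nc = 3 classes, Ns = 1, 4 features
query = torch.randn(6, 4)                    # Nq = 2 queries per class
labels = torch.tensor([0, 0, 1, 1, 2, 2])
prototypes = embed(support).mean(dim=1)      # (3, 2)
loss = prototypical_loss(embed(query), prototypes, labels)
optimizer.zero_grad()
loss.backward()                              # backpropagation
optimizer.step()                             # SGD update
```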

Launch training

The whole sequence described above forms an episode, and the training phase consists of many episodes. I tried to reproduce the results of the original paper with the following training settings:

  • Nc: 60 classes
  • Ns: 1 or 5 support points / class
  • Nq: 5 query points / class
  • 5 epochs
  • 2000 episodes / epoch
  • Learning rate initially set to 0.001 and halved at each epoch

The training took 30 min to run.
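The training schedule above might be wired together as follows. This is a sketch under several assumptions: `sample_episode` is a hypothetical callable returning support and query tensors of shapes (Nc, Ns, C, H, W) and (Nc, Nq, C, H, W); the optimizer is Adam (the article only says stochastic gradient descent, so `torch.optim.SGD` could be swapped in); and the per-epoch halving of the learning rate is done with `StepLR`.

```python
import torch
import torch.nn.functional as F


def episode_loss(model, support, query):
    """support: (Nc, Ns, C, H, W); query: (Nc, Nq, C, H, W). Returns a scalar loss."""
    n_way, n_support = support.shape[:2]
    n_query = query.shape[1]
    # Embed support and query images in one batch, then split them again
    emb = model(torch.cat([support.flatten(0, 1), query.flatten(0, 1)]))
    prototypes = emb[: n_way * n_support].reshape(n_way, n_support, -1).mean(dim=1)
    query_emb = emb[n_way * n_support:]
    dists = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=2)
    # Queries are ordered class by class: labels 0,..,0, 1,..,1, ...
    labels = torch.arange(n_way).repeat_interleave(n_query)
    return F.cross_entropy(-dists, labels)


def train(model, sample_episode, epochs=5, episodes_per_epoch=2000, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Multiply the learning rate by 0.5 after every epoch
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    model.train()
    for epoch in range(epochs):
        total = 0.0
        for _ in range(episodes_per_epoch):
            support, query = sample_episode()
            loss = episode_loss(model, support, query)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item()
        scheduler.step()
        print(f"epoch {epoch + 1}: mean loss {total / episodes_per_epoch:.4f}")
    return model
```

With the settings listed above, the defaults correspond to `train(model, sample_episode)`, where the sampler draws 60 classes with 1 or 5 support and 5 query images each.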

Results

Once the ProtoNet is trained, we can test it on new data. We select samples from the testing set in the same way. The support set is used to compute the prototypes, and each point of the query set is then assigned the label of the nearest prototype.

For testing, I tried 5-way and 20-way scenarios, with the same numbers of support and query points as during the training phase. The tests were performed over 1,000 episodes.
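Testing reuses the episode structure without any gradient update. A minimal sketch, assuming a hypothetical `sample_episode` callable that returns support and query tensors of shapes (Nc, Ns, C, H, W) and (Nc, Nq, C, H, W) drawn from the testing classes; each query gets the label of its nearest prototype under squared Euclidean distance, and accuracy is averaged over episodes.

```python
import torch


@torch.no_grad()
def evaluate(model, sample_episode, n_episodes=1000):
    """Mean query accuracy over n_episodes test episodes."""
    model.eval()
    accuracies = []
    for _ in range(n_episodes):
        support, query = sample_episode()
        n_way, n_support = support.shape[:2]
        n_query = query.shape[1]
        prototypes = model(support.flatten(0, 1)).reshape(n_way, n_support, -1).mean(dim=1)
        query_emb = model(query.flatten(0, 1))
        dists = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=2)
        # Each query is labelled with its nearest prototype
        predictions = dists.argmin(dim=1)
        labels = torch.arange(n_way).repeat_interleave(n_query)
        accuracies.append((predictions == labels).float().mean().item())
    return sum(accuracies) / len(accuracies)
```

For the 5-way and 20-way tests, the sampler would draw Nc = 5 or Nc = 20 testing classes per episode.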

The results are presented in the table below. “5-way 1-shot” means Nc = 5 and Ns = 1.

Table: obtained vs. paper results

I obtained results similar to those of the original paper, and slightly better in some cases. This may be due to the sampling strategy, which the paper does not specify; I used random sampling at each episode.


以上所述就是小编给大家介绍的《Few-shot Learning with Prototypical Networks》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

创新者的解答

创新者的解答

【美】克莱顿•克里斯坦森、【加】迈克尔·雷纳 / 中信出版社 / 2013-10-10 / 49.00

《创新者的解答》讲述为了追求创新成长机会,美国电信巨子AT&T在短短10年间,总共耗费了500亿美元。企业为了保持成功记录,会面对成长的压力以达成持续获利的目标。但是如果追求成长的方向出现偏误,后果往往比没有成长更糟。因此,如何创新,并选对正确方向,是每个企业最大的难题。 因此,如何创新,并导向何种方向,便在于创新结果的可预测性─而此可预测性则来自于正确的理论依据。在《创新者的解答》中,两位......一起来看看 《创新者的解答》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具