Building a ResNet in Keras

Category: IT Technology · Published: 4 years ago

Using Keras Functional API to construct a Residual Neural Network

What is a Residual Neural Network?

In principle, neural networks should get better results as they gain more layers. A deeper network can learn anything a shallower version of itself can, plus (possibly) more. If, for a given dataset, there is nothing more a network can learn by adding layers, it can simply learn the identity mapping for those additional layers. In this way it preserves the information from the previous layers and cannot do worse than the shallower network. A network should be able to learn at least the identity mapping if it doesn’t find something better.

But in practice this is not what happens. Deeper networks are harder to optimize. Each extra layer makes training more difficult: it becomes harder for the optimization algorithm to find the right parameters. As we add layers, the network’s results improve up to some point; beyond it, adding further layers makes the accuracy drop.

Residual networks attempt to solve this issue by adding so-called skip connections. A skip connection, as depicted in the article’s figure, carries a block’s input x around its layers and adds it to their output F(x), so the block computes F(x) + x. As noted earlier, deeper networks should be able to learn at least identity mappings, and this is what skip connections provide: they add an identity mapping from one point in the network to a later point and let the network learn only the extra F(x). If there is nothing more for the network to learn, it simply learns F(x) = 0. It turns out that it is easier for the network to learn a mapping close to 0 than to learn the identity mapping.

A block with a skip connection is called a residual block, and a Residual Neural Network (ResNet) is just a sequence of such blocks stacked one after another.
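To make the F(x) + x idea concrete, here is a small NumPy sketch. It is only an illustration: the function name residual_block_forward, the fully connected form of F, and the array sizes are assumptions chosen for this example and do not come from the article. It shows that when the last weights of F are zero, the block reduces to the identity mapping.

```python
import numpy as np


def residual_block_forward(x, w1, w2):
    """Toy residual block: returns F(x) + x with F(x) = relu(x @ w1) @ w2."""
    hidden = np.maximum(x @ w1, 0.0)  # ReLU after the first linear map
    f_x = hidden @ w2                 # the residual F(x)
    return f_x + x                    # skip connection adds the input back


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))      # a batch of 4 vectors with 8 features
w1 = rng.normal(size=(8, 8))
w2_zero = np.zeros((8, 8))       # with these weights F(x) is exactly 0

out = residual_block_forward(x, w1, w2_zero)
print(np.allclose(out, x))  # True: the block behaves as the identity
```

Driving the weights of F toward zero is enough for the block to pass its input through unchanged, which is the sense in which the identity is easy to learn here.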

An interesting fact is that our brains have structures similar to residual networks. For example, cortical layer VI neurons get input from layer I, skipping the intermediary layers.

A short introduction to Keras Functional API

If you are reading this, you are probably already familiar with the Sequential class, which lets you construct a neural network by simply stacking layers one after another, like this:

from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequential([
    Dense(32, input_shape=(784,)),
    Activation('relu'),
    Dense(10),
    Activation('softmax'),
])

But this way of building neural networks is not sufficient for our needs. With the Sequential class we can’t add skip connections. Keras also has the Model class, which can be used along with the functional API for creating layers to build more complex network architectures.

When constructed, the class keras.layers.Input returns a tensor object. A layer object in Keras can also be used like a function, calling it with a tensor object as a parameter. The returned object is a tensor that can then be passed as input to another layer, and so on.

As an example:

from keras.layers import Input, Dense
from keras.models import Model

inputs = Input(shape=(784,))
output_1 = Dense(64, activation='relu')(inputs)
output_2 = Dense(64, activation='relu')(output_1)
predictions = Dense(10, activation='softmax')(output_2)

model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# data and labels are placeholders for your own training arrays,
# with shapes (n, 784) and (n, 10) (one-hot labels)
model.fit(data, labels)

But the above code still constructs a network that is sequential, so there is no real use for this fancy functional syntax so far. Its real use comes with the so-called Merge layers, which combine several input tensors into one. A few examples of these layers are Add, Subtract, Multiply, and Average. The one that we will need for building residual blocks is Add.

An example that uses Add:

from keras.layers import Input, Dense, Add
from keras.models import Model

input1 = Input(shape=(16,))
x1 = Dense(8, activation='relu')(input1)
input2 = Input(shape=(32,))
x2 = Dense(8, activation='relu')(input2)

added = Add()([x1, x2])

out = Dense(4)(added)
model = Model(inputs=[input1, input2], outputs=out)

This is by no means a comprehensive guide to the Keras functional API. If you want to learn more, please refer to the docs.

Let’s implement a ResNet

Next, we will implement a ResNet along with its plain (without skip connections) counterpart, for comparison.

The ResNet that we will build here has the following structure:

  • Input with shape (32, 32, 3)
  • 1 Conv2D layer, with 64 filters
  • 2, 5, 5, 2 residual blocks with 64, 128, 256, and 512 filters
  • AveragePooling2D layer with pool size = 4
  • Flatten layer
  • Dense layer with 10 output nodes

It has a total of 30 conv+dense layers. All the kernel sizes are 3×3. We use ReLU activation and BatchNormalization after conv layers.

The plain version is the same except for the skip connections.
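The layer count and the tensor sizes implied by this structure can be checked with a few lines of plain Python. This is a back-of-the-envelope sketch under two assumptions not stated in the list above: each residual block contributes two 3×3 conv layers (the 1×1 shortcut convolutions are not counted), and the height and width are halved once at the start of each stage after the first.

```python
blocks_per_stage = [2, 5, 5, 2]
filters_per_stage = [64, 128, 256, 512]

# One initial conv, two 3x3 convs per residual block, one dense layer.
conv_layers = 1 + 2 * sum(blocks_per_stage)
dense_layers = 1
print(conv_layers + dense_layers)  # 30

# Feature-map shape at the end of each stage for a 32x32 input.
size = 32
shapes = []
for stage, filters in enumerate(filters_per_stage):
    if stage > 0:
        size = (size + 1) // 2  # stride 2 with "same" padding rounds up
    shapes.append((size, size, filters))
print(shapes)  # [(32, 32, 64), (16, 16, 128), (8, 8, 256), (4, 4, 512)]

# Average pooling with pool size 4 reduces 4x4 to 1x1 before Flatten.
pooled = size // 4
flattened_features = pooled * pooled * filters_per_stage[-1]
print(flattened_features)  # 512
```

Under these assumptions the Dense layer receives a vector of 512 features per image.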

First we create a helper function that takes a tensor as input and applies ReLU and then batch normalization to it:

from tensorflow import Tensor
from keras.layers import ReLU, BatchNormalization

def relu_bn(inputs: Tensor) -> Tensor:
    relu = ReLU()(inputs)
    bn = BatchNormalization()(relu)
    return bn

Then we create a function for constructing a residual block. It takes a tensor x as input and passes it through 2 conv layers; let’s call the output of these 2 conv layers y. It then adds the input x to y, applies ReLU and batch normalization, and returns the resulting tensor. When downsample == True, the first conv layer uses strides=2 to halve the height and width of the output, and a conv layer with kernel_size=1 and strides=2 is applied to the input x to give it the same shape as y. This is needed because the Add layer requires its input tensors to have the same shape.

from keras.layers import Add, Conv2D

def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    # Main path: two conv layers, the first one optionally with stride 2
    y = Conv2D(kernel_size=kernel_size,
               strides=(1 if not downsample else 2),
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)

    # Shortcut path: a 1x1 conv reshapes x to match y when downsampling
    if downsample:
        x = Conv2D(kernel_size=1,
                   strides=2,
                   filters=filters,
                   padding="same")(x)
    out = Add()([x, y])
    out = relu_bn(out)
    return out
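The shape requirement of the Add layer can be seen in isolation with a short sketch. The input shape (32, 32, 64) and the filter count 128 are assumptions picked for illustration; the layers are the standard Keras Conv2D and Add.

```python
from keras.layers import Add, Conv2D, Input

x = Input(shape=(32, 32, 64))

# Main path of a downsampling block: the stride-2 conv halves height and width.
y = Conv2D(filters=128, kernel_size=3, strides=2, padding="same")(x)
y = Conv2D(filters=128, kernel_size=3, strides=1, padding="same")(y)

# Shortcut path: a 1x1 conv with the same stride and the same number of filters.
shortcut = Conv2D(filters=128, kernel_size=1, strides=2, padding="same")(x)

print(tuple(x.shape))         # (None, 32, 32, 64)
print(tuple(y.shape))         # (None, 16, 16, 128)
print(tuple(shortcut.shape))  # (None, 16, 16, 128)

# The two paths now have identical shapes, so they can be added.
out = Add()([shortcut, y])
```

Without the 1×1 convolution the shortcut would still have shape (32, 32, 64) and could not be added to y.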

The create_res_net() function puts everything together.

Here is the full code for this:
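The listing is not reproduced in this copy of the article, so what follows is a minimal sketch reconstructed from the structure listed earlier, not the author's exact code. Several details are assumptions: downsampling happens in the first block of every stage except the first, the final Dense layer uses softmax, and the model is compiled with the adam optimizer and sparse categorical cross-entropy.

```python
from keras.layers import (Add, AveragePooling2D, BatchNormalization, Conv2D,
                          Dense, Flatten, Input, ReLU)
from keras.models import Model


def relu_bn(inputs):
    """ReLU followed by batch normalization."""
    return BatchNormalization()(ReLU()(inputs))


def residual_block(x, downsample, filters, kernel_size=3):
    """Two conv layers plus a skip connection, as described in the text."""
    y = Conv2D(filters, kernel_size, strides=2 if downsample else 1,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(filters, kernel_size, strides=1, padding="same")(y)
    if downsample:
        # 1x1 conv so that the shortcut matches the shape of y
        x = Conv2D(filters, 1, strides=2, padding="same")(x)
    return relu_bn(Add()([x, y]))


def create_res_net():
    inputs = Input(shape=(32, 32, 3))
    t = Conv2D(64, 3, strides=1, padding="same")(inputs)
    t = relu_bn(t)

    blocks_per_stage = [2, 5, 5, 2]
    filters_per_stage = [64, 128, 256, 512]
    for stage, (num_blocks, filters) in enumerate(
            zip(blocks_per_stage, filters_per_stage)):
        for block in range(num_blocks):
            # Assumption: downsample at the start of stages 2, 3 and 4.
            downsample = (block == 0 and stage > 0)
            t = residual_block(t, downsample=downsample, filters=filters)

    t = AveragePooling2D(pool_size=4)(t)
    t = Flatten()(t)
    # Assumption: softmax output for the 10 classes.
    outputs = Dense(10, activation="softmax")(t)

    model = Model(inputs=inputs, outputs=outputs)
    # Assumption: optimizer and loss are illustrative choices.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


model = create_res_net()
print(model.output_shape)  # (None, 10)
```

Calling model.summary() on the result is a quick way to compare the layer sequence against the list above.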

The plain network is constructed in a similar way, but it doesn’t have skip connections and we don’t use the residual_block() helper function; everything is done inside create_plain_net().

The code for the plain network:
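This listing is also missing from this copy of the article. The sketch below is a reconstruction that follows the same layer list with the Add layers left out; it is not the author's code. The placement of the stride-2 convolutions, the softmax output, and the compile settings are assumptions.

```python
from keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense,
                          Flatten, Input, ReLU)
from keras.models import Model


def create_plain_net():
    inputs = Input(shape=(32, 32, 3))
    t = Conv2D(64, 3, strides=1, padding="same")(inputs)
    t = BatchNormalization()(ReLU()(t))

    blocks_per_stage = [2, 5, 5, 2]
    filters_per_stage = [64, 128, 256, 512]
    for stage, (num_blocks, filters) in enumerate(
            zip(blocks_per_stage, filters_per_stage)):
        for block in range(num_blocks):
            # Assumption: stride 2 at the start of stages 2, 3 and 4.
            strides = 2 if (block == 0 and stage > 0) else 1
            # Same two conv layers as a residual block, without the shortcut.
            t = Conv2D(filters, 3, strides=strides, padding="same")(t)
            t = BatchNormalization()(ReLU()(t))
            t = Conv2D(filters, 3, strides=1, padding="same")(t)
            t = BatchNormalization()(ReLU()(t))

    t = AveragePooling2D(pool_size=4)(t)
    t = Flatten()(t)
    # Assumption: softmax output for the 10 classes.
    outputs = Dense(10, activation="softmax")(t)

    model = Model(inputs=inputs, outputs=outputs)
    # Assumption: optimizer and loss are illustrative choices.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


plain_model = create_plain_net()
print(plain_model.output_shape)  # (None, 10)
```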

Training on CIFAR-10 and seeing the results

CIFAR-10 is a dataset of 32×32 RGB images in 10 categories. It contains 50,000 training images and 10,000 test images.

Below is a sample of 10 random images from each class:

We will train both ResNet and PlainNet on this dataset for 20 epochs, and then compare the results.

Training took about 55 minutes for each of ResNet and PlainNet on a machine with one NVIDIA Tesla K80 GPU. There was no significant difference in training time between the two.
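A training run along these lines could be sketched as follows. The helper names load_cifar10 and train_and_evaluate are hypothetical, and the pixel scaling to [0, 1], the batch size of 128, and the use of the test set as validation data are assumptions; the article only states the dataset and the number of epochs. The model passed in is expected to be already compiled.

```python
import numpy as np
from keras.datasets import cifar10


def load_cifar10():
    """Download CIFAR-10 and scale pixel values to [0, 1] (an assumption)."""
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    return (x_train, y_train), (x_test, y_test)


def train_and_evaluate(model, train_data, val_data, epochs=20, batch_size=128):
    """Fit a compiled Keras model and return the per-epoch metric history."""
    x_train, y_train = train_data
    x_val, y_val = val_data
    history = model.fit(x_train, y_train,
                        epochs=epochs,
                        batch_size=batch_size,
                        validation_data=(x_val, y_val),
                        verbose=0)
    # history.history maps metric names (e.g. "val_loss") to lists of values.
    return history.history
```

For example, with train_data, val_data = load_cifar10(), calling train_and_evaluate(create_res_net(), train_data, val_data) and the same for create_plain_net() gives two histories that can be compared epoch by epoch.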

The results that we got are shown below.

So, by using a ResNet on this dataset we gained 1.59% in validation accuracy. The difference should be larger for deeper networks. Feel free to experiment and see what results you get.

References

  1. Deep Residual Learning for Image Recognition
  2. Residual neural network — Wikipedia
  3. Guide to the Functional API — Keras documentation
  4. Model (functional API) — Keras documentation
  5. Merge Layers — Keras documentation
  6. CIFAR-10 and CIFAR-100 datasets
