Implementing Batch Normalization in Python



Why and How You Implement Batch Normalization in a Neural Network

Jan 30 · 4 min read

I'm currently taking Convolutional Neural Networks for Visual Recognition, offered online by Stanford University, and have just started working on the second assignment of the course. In one part of this assignment, we are asked to implement batch normalization in a fully connected neural network.

The implementation of the forward pass was relatively simple, but backpropagation, which is much more challenging, took me quite some time to complete. After a few hours of work and struggle, I finally got through this challenge. Here I would love to share some of my notes and thoughts on batch normalization.

So … what is batch normalization?

Batch normalization deals with the problem of poorly initialized neural networks. It can be interpreted as doing preprocessing at every layer of the network. It forces the activations in a network to take on a unit Gaussian distribution at the beginning of training. This ensures that all neurons have about the same output distribution in the network and improves the rate of convergence.

To see why the distribution of the activations in a network matters, you can refer to pp. 46-62 of the lecture slides offered by the course.

Let's say we have a batch of activations x at a layer. The zero-mean, unit-variance version of x is

$$\hat{x} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}$$

This is actually a differentiable operation, which is why we can apply batch normalization during training.

In the implementation, we insert the batch normalization layer right after a fully connected layer or a convolutional layer, and before the nonlinearity, as in the sketch below.
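As a minimal sketch of that ordering (the function and variable names here are my own illustrative assumptions, not the course API):

```python
import numpy as np

def affine_bn_relu_forward(x, W, b, gamma, beta, eps=1e-5):
    """Sketch of the layer ordering: affine -> batch norm -> ReLU (training mode)."""
    a = x.dot(W) + b                             # fully connected layer
    mu, var = a.mean(axis=0), a.var(axis=0)      # per-feature batch statistics
    a_hat = (a - mu) / np.sqrt(var + eps)        # normalize before the nonlinearity
    return np.maximum(0, gamma * a_hat + beta)   # scale, shift, then ReLU
```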

Forward pass of batch normalization

$$
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i && \text{(mini-batch mean)} \\
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 && \text{(mini-batch variance)} \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} && \text{(normalize)} \\
y_i &= \gamma \hat{x}_i + \beta && \text{(scale and shift)}
\end{aligned}
$$

the algorithm of the batch normalizing transform, from the original paper

Let's look at the gist from the original research paper.

As I said earlier, the whole concept of batch normalization is pretty easy to understand. After computing the mean and the variance of a batch of activations x, we can normalize x with the operation in the third line of the gist. Also note that we introduce learnable scale and shift parameters γ and β in case the zero-mean, unit-variance constraint is too restrictive for our network.

So the code for the training-time forward pass looks something like the sketch below (a minimal version; the variable names and the (N, D) input shape are my own assumptions):
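```python
import numpy as np

def batchnorm_forward_train(x, gamma, beta, eps=1e-5):
    """Training-time batch norm: normalize with the current batch statistics.

    x: (N, D) batch of activations; gamma, beta: (D,) learnable parameters.
    """
    mu = x.mean(axis=0)                      # per-feature mean over the batch
    var = x.var(axis=0)                      # per-feature variance over the batch
    x_hat = (x - mu) / np.sqrt(var + eps)    # zero-mean, unit-variance activations
    out = gamma * x_hat + beta               # learnable scale and shift
    cache = (x, x_hat, mu, var, gamma, eps)  # saved for backpropagation
    return out, cache
```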

One thing to pay attention to is that the estimates of the mean and variance depend on the mini-batches we feed into the network, and we can't do this at test time. So the mean μ and variance σ² used for normalization at test time are actually running averages of the values we computed during training.

And this is also why batch normalization has a regularizing effect. We add some randomness during training (each example is normalized with statistics that depend on the other examples in its mini-batch) and average out this randomness at test time to reduce generalization error (just like the effect of dropout).

So here is a sketch of my complete implementation of the forward pass from assignment 2 of the course (the bn_param dictionary carrying the mode, eps, momentum, and the running statistics follows the assignment's convention as I recall it):
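```python
import numpy as np

def batchnorm_forward(x, gamma, beta, bn_param):
    """Forward pass for batch normalization.

    Handles both train and test mode; the bn_param dict carries the mode,
    hyperparameters, and running statistics between calls.
    """
    mode = bn_param['mode']                   # 'train' or 'test'
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    if mode == 'train':
        mu = x.mean(axis=0)                      # batch mean
        var = x.var(axis=0)                      # batch variance
        x_hat = (x - mu) / np.sqrt(var + eps)    # normalize
        out = gamma * x_hat + beta               # scale and shift
        cache = (x, x_hat, mu, var, gamma, eps)  # saved for the backward pass
        # Keep exponential moving averages of the batch statistics,
        # so we can normalize with them at test time.
        running_mean = momentum * running_mean + (1 - momentum) * mu
        running_var = momentum * running_var + (1 - momentum) * var
    elif mode == 'test':
        # Use the stored running statistics instead of the batch statistics.
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_hat + beta
        cache = None
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var
    return out, cache
```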

Backpropagation

Now we want to derive a way to compute the gradients of batch normalization. What makes it challenging is the fact that μ itself is a function of x, and σ² is a function of both μ and x. Thus we need to be extremely careful and clear when applying the chain rule to this normalization function.
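Concretely, because μ and σ² both depend on x, the chain rule picks up three paths from each input x_i to the loss L:

$$
\frac{\partial L}{\partial x_i} =
\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i}
+ \frac{\partial L}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial x_i}
+ \frac{\partial L}{\partial \mu}\frac{\partial \mu}{\partial x_i}
$$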

One of the things I found very helpful when taking the course is the concept of a computational graph. It breaks a complex function down into several small, simple operations and helps you perform backpropagation in a neat, organized way (by deriving the local gradient of each simple operation and multiplying them together to get the result).

Kratzert's post explains every step of computing the gradients of batch normalization using a computational graph in detail. Check it out to understand more.

In Python, we can write the staged computation like the sketch below (it follows the node-by-node structure of Kratzert's graph; the variable names are my own, and cache matches the forward sketch above):
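```python
import numpy as np

def batchnorm_backward(dout, cache):
    """Staged backward pass: one local gradient per node of the graph."""
    x, x_hat, mu, var, gamma, eps = cache
    N, D = dout.shape
    xmu = x - mu
    ivar = 1.0 / np.sqrt(var + eps)

    # out = gamma * x_hat + beta
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dout * x_hat, axis=0)
    dx_hat = dout * gamma

    # x_hat = xmu * ivar
    divar = np.sum(dx_hat * xmu, axis=0)
    dxmu1 = dx_hat * ivar

    # ivar = 1 / sqrt(var + eps)
    dsqrtvar = -(ivar ** 2) * divar
    dvar = 0.5 * ivar * dsqrtvar

    # var = mean(xmu ** 2) over the batch
    dsq = np.ones((N, D)) / N * dvar
    dxmu2 = 2 * xmu * dsq

    # xmu = x - mu
    dx1 = dxmu1 + dxmu2
    dmu = -np.sum(dx1, axis=0)

    # mu = mean(x) over the batch
    dx2 = np.ones((N, D)) / N * dmu

    dx = dx1 + dx2
    return dx, dgamma, dbeta
```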

One downside of this staged computation is that it takes much longer to arrive at the final gradients, since we compute a lot of intermediate values that may cancel out when multiplied together. To make everything faster, we need to differentiate the function by hand to get a simpler result.

When I was writing this article, I found a post on Kevin's blog that walks through every step of deriving the gradients by the chain rule. It explains the details very clearly, so please refer to it if you are interested in the derivation.
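For reference, the derivation collapses (up to notation) to this compact expression for the input gradient, where N is the batch size and the sums run over the examples in the batch:

$$
\frac{\partial L}{\partial x_i} = \frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}}
\left( N\,\frac{\partial L}{\partial y_i}
- \sum_{j=1}^{N}\frac{\partial L}{\partial y_j}
- \hat{x}_i \sum_{j=1}^{N}\frac{\partial L}{\partial y_j}\,\hat{x}_j \right)
$$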

And finally, here's a sketch of my simplified implementation (same cache layout as in the forward sketch):
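```python
import numpy as np

def batchnorm_backward_alt(dout, cache):
    """Simplified backward pass using the hand-derived formula."""
    x, x_hat, mu, var, gamma, eps = cache
    N = dout.shape[0]

    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dout * x_hat, axis=0)

    # dx = gamma / (N * sqrt(var + eps)) *
    #      (N * dout - sum(dout) - x_hat * sum(dout * x_hat))
    dx_hat = dout * gamma
    dx = (N * dx_hat - np.sum(dx_hat, axis=0)
          - x_hat * np.sum(dx_hat * x_hat, axis=0)) / (N * np.sqrt(var + eps))
    return dx, dgamma, dbeta
```

A quick sanity check is to compare this against the staged version on small random inputs; the two should agree to within floating-point error.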

Summary

In this article, we learned how batch normalization improves convergence and why batch normalization serves as a kind of regularization. We also implemented the forward pass and backpropagation for batch normalization in Python.

Although you probably don't need to worry about the implementation, since everything is already there in the popular deep learning frameworks, I always believe that doing things on our own gives us a better understanding. I hope you have gained something from reading this article!

