Proposing a new effect of learning rate decay — Network Stability
Uncovering learning rate as a form of regularisation in stochastic gradient descent
Abstract
Modern literature suggests the learning rate is the most important hyper-parameter to tune for a deep neural network. [1] With too low a learning rate, gradient descent can be painfully slow along a long, narrow valley or near a saddle point. With too high a learning rate, gradient descent risks overshooting the minima. [2] Adaptive learning rate algorithms have been developed that use momentum and accumulated gradients to be more robust in these situations for non-convex optimisation problems. [3]
A common theme is that decaying the learning rate after a certain number of epochs can help models converge to better minima by allowing weights to settle into sharper, more exact minima. The idea is that at a given learning rate you may continually miss the actual minimum by stepping back and forth across it. By decaying the learning rate, we allow our weights to settle into these sharp minima.
As the deep learning community continually improves and builds new ways of adjusting the learning rate, like cyclical learning rate schedules, I believe it’s important we understand more deeply what effects the learning rate has on our optimisation problem.
In this study, I would like to propose an overlooked effect of learning rate decay: network stability. I will experiment with learning rate decay in different settings, show how network stability arises from it, and show how network stability can benefit subnetworks during the training process.
I believe learning rate decay has two effects on the optimisation problem:
- Allow weights to settle into deeper minima
- Provide network stability of forward propagated activations and back propagated loss signals during the training process
First, I would like to establish what I mean by network stability through some mathematical notation.
Then, I will proceed to show how both of these effects take place through experiments on the ResNet-18 architecture on the CIFAR-10 dataset.
Theory
What is network stability? More importantly, how does it affect the loss landscape?
Let us view our loss landscape probabilistically.
For a single example x, we can calculate a loss L(x; w) given the current weights w of the network.
An example x is picked from a distribution D over the space of all images.
We can then view the loss landscape probabilistically, as the distribution of losses L(x; w) induced by drawing x ~ D.
The loss calculated for a single image x, for a given set of weights w, produces a gradient for each weight through back-propagation, namely g(x; w) = ∇w L(x; w).
There is a corresponding probability P(g | w) that we observe this gradient under the probabilistic loss landscape above.
The process of performing stochastic gradient descent across a given dataset, with batches of samples, can be seen as navigating this probabilistic loss field through point estimates.
Over a single batch of images, the weights are kept constant, so each update at time step t moves through the conditional landscape P(g | w_t).
By averaging over a large enough batch of size n, we hope that the batch-mean gradient approximates the expected gradient E_{x~D}[g(x; w_t)], so that the gradient step we take is an expected step for the given set of weights.
Each step t is then taken over the conditional P(g | w_t), via the update w_{t+1} = w_t − η·ĝ_t, where ĝ_t is the batch-averaged gradient and η is the learning rate.
By decreasing the learning rate η, we decrease the change in w at each step.
Subsequent iterations of batch gradient descent then operate over probabilistic loss fields that are very similar to the previous iteration's. This enforces some level of stability in our network during training, since for small enough changes in w, P(g | w_{t+1}) ≈ P(g | w_t).
By decreasing the learning rate, we force stochastic gradient descent to operate over a tighter conditional probability over the weights, rather than jumping between weight regimes and their corresponding conditional probability fields. This is my notion of network stability.
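To keep the notation in one place, here is a compact summary of the quantities above; the exact symbols are my own shorthand for the relationships described in the text.

```latex
% Per-example loss and gradient, with x drawn from the data distribution D
L(x; w), \quad x \sim D, \qquad g(x; w) = \nabla_w L(x; w)

% Probability of observing a particular gradient given the current weights
P\!\left(g \mid w\right)

% A batch of n examples gives a point estimate of the expected gradient
\hat{g}_t = \frac{1}{n} \sum_{i=1}^{n} g(x_i; w_t) \;\approx\; \mathbb{E}_{x \sim D}\!\left[g(x; w_t)\right]

% SGD step with learning rate \eta; a small \eta keeps successive conditionals close
w_{t+1} = w_t - \eta\,\hat{g}_t, \qquad \text{small } \eta \;\Rightarrow\; P(g \mid w_{t+1}) \approx P(g \mid w_t)
```

Decaying η therefore directly limits how far successive conditionals P(g | w_t) can drift apart between updates.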
The classical view of learning rate decay as enabling convergence to sharp minima is an oversimplification of a stochastic process, and treats the loss landscape as a constant.
In the probabilistic view, I propose to view learning rate as a piece of the stochastic puzzle. We should treat the process of stochastic gradient descent… well, stochastically.
My hypothesis about the effects of stability in this probabilistic process is two-fold, concerning deep layers and early layers:
- When deep layers experience instability of forward-propagated signals, they are forced to generalise across continually varying early subnetworks. When we introduce stability in the early layers, deep layers can rely on stable forward-propagated signals to develop stable hierarchical features.
- When early layers experience instability of back propagated loss signals, they don’t get a clear static picture of the loss. Instead, with each iteration of gradient descent, the image of the loss presented to them via deep layers changes. When we introduce stability of deeper layers, early layers can converge to a solution more effectively through SGD.
Experiments
For our experiments, we'll consider the ResNet-18 architecture on the CIFAR-10 dataset.
Usually, learning rate decay is applied after a certain number of epochs, which combines the effects of settling into sharp minima and providing network stability.
In a series of four experiments, I would like to separate these effects individually, show how they may be combined, and ultimately show that both are part of learning rate decay.
Experiment A: Sharp Minima without Network Stability
In our first experiment, we’ll study the effect of learning rate decay on individual blocks.
The ResNet-18 architecture comprises four filter blocks, with 64, 128, 256, and 512 filters respectively.
In this experiment, we'll train an Outpost model, where the learning rate is decayed network-wide, followed by models that decay the learning rate only within a single filter block. This experiment should reveal the extent to which an individual block settles into sharper minima, and its ability to do so without added stability in the rest of the network, since the rest of the network keeps the higher learning rate.
Experiment B: Network Stability without Sharp Minima
In the second experiment, we'll study the effect of freezing blocks to enforce stability while maintaining a constant learning rate for the non-frozen block. By maintaining a constant learning rate, we don't introduce the effect of settling into sharper minima. We hope to see faster convergence to a lower loss when freezing blocks, as we do with network-wide LR decay.
Experiment C: Combining Stability with Sharp Minima
In this experiment, we’ll combine the effects of the above two experiments by employing learning rate decay for a single block and then freezing the other blocks entirely. This should allow for the decayed block to settle into sharper minima, and then further reap benefits of stability by freezing the other blocks.
Experiment D: No Compounding Effects
In this experiment, we will decay the learning rate across the entire network, then attempt to introduce further stability effects by freezing all blocks except one. If no further gains are found on top of the network-wide LR decay, we have effectively shown that LR decay already provides the stability gains otherwise realised by freezing.
Experiment A: Sharp Minima
We will train the following 5 models:
1. Outpost model with step-wise LR decay
LR starts at 0.1, decays to 0.01 at epoch 60, then decays to 0.001 at epoch 120 (see the sketch after this list).
2. Filter block 64 LR Decay
3. Filter block 128 LR Decay
4. Filter block 256 LR Decay
5. Filter block 512 LR Decay
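As a rough idea of how this setup could look in code, here is a minimal PyTorch sketch using one parameter group per filter block, so that each block can follow its own schedule. It assumes torchvision's resnet18, whose layer1 through layer4 correspond to the 64/128/256/512 filter blocks; the stem handling, optimiser settings, and total epoch count are my assumptions and are not taken from the original write-up.

```python
import torch
import torchvision

# Assumption: torchvision's resnet18, where layer1..layer4 are the
# 64/128/256/512 filter blocks referred to in the text.
model = torchvision.models.resnet18(num_classes=10)

blocks = {
    "stem_and_head": [model.conv1, model.bn1, model.fc],
    "block64":  [model.layer1],
    "block128": [model.layer2],
    "block256": [model.layer3],
    "block512": [model.layer4],
}

# One parameter group per block so each block's LR can be decayed independently.
param_groups = [
    {"params": [p for m in mods for p in m.parameters()], "name": name, "lr": 0.1}
    for name, mods in blocks.items()
]
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=5e-4)

def decay_block(optimizer, name, factor=0.1):
    """Multiply the learning rate of a single named parameter group by `factor`."""
    for group in optimizer.param_groups:
        if group["name"] == name:
            group["lr"] *= factor

for epoch in range(160):          # total epoch count is an assumption
    # ... one epoch of standard CIFAR-10 training ...
    if epoch in (60, 120):
        # Outpost model: decay every group. Per-block models would instead call,
        # for example, decay_block(optimizer, "block256") and leave the rest at 0.1.
        for group in optimizer.param_groups:
            group["lr"] *= 0.1
```

For the per-block models in this experiment, only the corresponding group's learning rate is decayed at epochs 60 and 120, while every other group stays at 0.1.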
Our hypothesis is that each block should be able to settle into some lower minima. We should expect to see training loss decrease as we decay the learning rate, and top-1 validation accuracy increase.
The results are as follows:
Zooming into the Top 1 Train Accuracy and Top 1 Validation Accuracy graphs:
We observe the following:
- On our Outpost model, the LR decay at epochs 60 and 120 allows the model to converge to a lower loss on the training set and a higher accuracy on the validation set.
- On our filter block decay models, we see that the points of LR decay help the decayed filter block settle into some deeper minima.
- The Outpost model converges to a stable validation accuracy, with less variance across iterations of SGD than its filter-decay counterparts.
From this experiment, we see that LR decay on each individual block can help settle that block's weights into sharper minima. The Outpost model shows properties of network stability arising from decaying the learning rate across the entire network. The high variance of validation accuracy in the filter block models shows that keeping the learning rate at 0.1 for the other blocks contributes to epoch-wise instability of the learned network.
Experiment B: Network Stability
For this experiment, we maintain a learning rate of 0.1 for any blocks that are training, without decay. This way, we avoid confounding our results with the effect of settling into sharper minima.
To simulate the effects of network stability, we will freeze all filter blocks except one in each experiment. This should help enforce the property of network stability described above, namely that P(g | w_{t+1}) ≈ P(g | w_t) for the block that continues training (a sketch of the freezing setup follows the model list below).
We will train the following 5 models:
- Baseline model with no LR decay
- Filter block 64 with LR 0.1, other blocks are frozen at epoch 60
- Filter block 128 with LR 0.1, other blocks are frozen at epoch 60
- Filter block 256 with LR 0.1, other blocks are frozen at epoch 60
- Filter block 512 with LR 0.1, other blocks are frozen at epoch 60
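A minimal sketch of the freezing step, reusing the hypothetical block naming from the earlier sketch. The treatment of the stem, the classifier head, and the BatchNorm statistics of frozen blocks is my own assumption; the original write-up does not specify it.

```python
def freeze_all_but(model, keep):
    """Stop gradient updates for every filter block except `keep`."""
    blocks = {
        "block64": model.layer1, "block128": model.layer2,
        "block256": model.layer3, "block512": model.layer4,
    }
    for name, block in blocks.items():
        if name == keep:
            continue
        for p in block.parameters():
            p.requires_grad_(False)
        # Assumption: also fix the running BatchNorm statistics of frozen blocks.
        # A later model.train() call flips them back to train mode, so this would
        # need to be re-applied at the start of every subsequent epoch.
        block.eval()

# Experiment B, model "Filter block 128": constant LR of 0.1 throughout,
# then freeze every other block at epoch 60.
# if epoch == 60:
#     freeze_all_but(model, keep="block128")
```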
By comparing the effect of freezing at different depths of the network, we can establish that the stability effect is a network-wide effect: it should not benefit only deeper blocks or only earlier blocks.
Our hypothesis is that each block which is continually trained should reap the benefits of network stability and converge to a better minimum in the loss landscape restricted to its own weights w_j, …, w_k, where j to k are the indices of the weights in the given filter block and all other weights are held fixed.
Our expectation is that each experiment should achieve better results than the baseline. Moreover, the resultant networks should have a more stable solution with a lower variance of training/validation loss and accuracy across epochs of stochastic gradient descent.
The results are as follows:
Zooming in on the Top 1 Validation Accuracy graph:
By simply freezing the rest of the network at epoch 60, we see a clear boost in validation accuracy.
Without confounding our results with the effects of sharper minima, we establish here that network stability is an important factor in convergence.
Moreover, our validation accuracy on these models is higher than the accuracy of Experiment A models. This suggests that stability may play a bigger role in helping networks converge than the sharp minima effect.
Experiment C: Combining Sharp Minima & Network Stability
In this experiment, we would like to confirm that, in the absence of network-wide LR decay, we are able to combine the effects of sharp minima and network stability. We will do this by decaying the learning rate of only a given block at epoch 60, then freezing the rest of the network at epoch 90 (a sketch follows the model list below). We will train the following 6 models:
- Baseline model with no LR decay
- Outpost model with LR decay
- Filter block 64 with LR decay, other blocks are frozen at epoch 90
- Filter block 128 with LR decay, other blocks are frozen at epoch 90
- Filter block 256 with LR decay, other blocks are frozen at epoch 90
- Filter block 512 with LR decay, other blocks are frozen at epoch 90
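Sticking with the hypothetical decay_block and freeze_all_but helpers from the earlier sketches, the two-stage schedule for one of these models could look roughly like this:

```python
# Experiment C, model "Filter block 256":
#   epoch 60: decay only block256's learning rate (0.1 -> 0.01)
#   epoch 90: freeze every other block, so block256 keeps training
#             against a stable surrounding network
for epoch in range(160):
    # ... one epoch of CIFAR-10 training ...
    if epoch == 60:
        decay_block(optimizer, "block256")
    if epoch == 90:
        freeze_all_but(model, keep="block256")
```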
The results are as follows:
Zooming into the Top 1 Train Accuracy and Top 1 Validation Accuracy graphs:
By combining LR decay on a single block with freezing the rest of the network, we replicate characteristics of the Outpost model, which has network-wide LR decay. As expected, the network-wide LR decay model still performs better than our filter block models. However, our combination approach is able to mimic the effects of LR decay quite well, in terms of stability and of training/validation accuracy and loss.
When block freezing is applied at epoch 90 (3.5 × 10⁴ iterations), we enable the deep filter block models, 256 and 512, to converge to 100% training accuracy. The models where earlier blocks continue to be trained, 64 and 128, do not reach the same level of training accuracy, yet they beat the other models on validation top-1 accuracy. This gives us some evidence that, as training progresses, overfitting occurs more in deeper layers.
If we take a look at the validation loss, this also shows us an interesting relationship between network depth and overfitting.
After decaying the LR for the Outpost model at epoch 60, validation loss is at its lowest. As training progresses beyond epoch 60, the loss on the training set decreases, but, contrary to what we would expect, the validation loss of the network-wide decay model increases even while its validation top-1 accuracy also increases.
We use cross-entropy loss to train our models, which means that for the loss to increase, we must observe one of the two following cases (a toy example follows this list):
- Less confidence of correct predictions
- More confidence of incorrect predictions
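A toy numeric illustration of the second case (the logits are made up): the predicted class is unchanged on every example, so top-1 accuracy stays the same, yet the cross-entropy loss rises because the one incorrect prediction becomes more confident.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 1])                # two examples, true classes 0 and 1

# Early in training: the second example is wrong, but not confidently so.
logits_early = torch.tensor([[4.0, 0.0, 0.0],
                             [0.5, 0.0, 0.0]])
# Later in training: the second example is still wrong, now with high confidence.
logits_late = torch.tensor([[4.0, 0.0, 0.0],
                            [6.0, 0.0, 0.0]])

for name, logits in [("early", logits_early), ("late", logits_late)]:
    acc = (logits.argmax(dim=1) == labels).float().mean().item()
    loss = F.cross_entropy(logits, labels).item()
    print(f"{name}: top-1 accuracy = {acc:.2f}, cross-entropy = {loss:.3f}")

# Top-1 accuracy is 0.50 in both cases, but the mean loss increases several-fold,
# driven entirely by the more confident incorrect prediction.
```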
Since training loss continually decreases while training top 1 accuracy is already at 100%, this suggests that the model is minimising the training loss by outputting more confident predictions.
We would think that these confident predictions are overfit and do not generalise well to the validation set. Indeed, while the validation top-1 accuracy increases over training, we notice that the validation top-5 accuracy decreases. This means that as training progresses with network-wide LR decay, when the model is wrong it has more trouble finding the true label among its top 5 predicted labels. This can be attributed to the boosted confidence of other labels, and so the validation loss increases.
Looking at the block-freezing models, in the filter block models where deeper layers continue to be trained, such as 256 or 512, we observe the same overfitting effect as in the Outpost model. As training progresses, these models increase their validation loss and converge to a validation loss similar to that of the network-wide LR decay model.
However, in the filter block models where earlier layers are trained and deeper layers are frozen, such as 64 or 128, validation loss does not increase over training in the same way. The top-5 validation accuracy also does not suffer.
When we freeze early layers, the deep layers that keep training are able to overfit more easily to the stable signals from the early subnetworks. This shows how a high learning rate is actually a form of regularisation: it adds noise to the stochastic process. By setting the rate at which weights shift, we define the diversity of the conditional probability cloud our loss moves through. From our experiments, we see that early layers that keep training while deep layers are frozen have a much tougher time overfitting to the intricacies of the training set, and they improve the most on validation loss.
When we freeze deep layers, early layers improve most on validation loss and accuracy immediately. This shows how important it is for early layers to have a stable image of the back-propagated loss to converge to better features.
Both of these observations support the two hypotheses I detailed in the theory section about the relationship between early layers, deep layers, and stability.
Experiment D: No Compounding Effects
For our last experiment, I'd like to show that once we have applied LR decay to the entire network, additional stability gains are no longer realisable by freezing blocks. This will show that LR decay itself creates these stability effects, and that further freezing has no additional benefit.
We will train the following 5 models:
- Outpost model with LR decay
- Filter block 64 with LR decay across network, other blocks are frozen at epoch 140
- Filter block 128 with LR decay across network, other blocks are frozen at epoch 140
- Filter block 256 with LR decay across network, other blocks are frozen at epoch 140
- Filter block 512 with LR decay across network, other blocks are frozen at epoch 140
At epoch 140, we apply block freezing, after the network-wide LR decay has already taken place. If network stability is an effect of LR decay, the gains from block freezing should no longer be realisable.
The results are as follows:
As expected, we see no extra gains at epoch 140 across our models. Thus, we effectively show that the effects of network stability are indeed introduced by learning rate decay.
Conclusion
We find that there exist two effects of learning rate decay:
- Reaching sharper minima in the optimisation landscape
- Stability of forward propagated and backward propagated signals
Instability arises from constant change. Gradient descent updates all weights for each batch, effectively adding noise to forward and back propagated signals. By freezing layers and observing improved accuracy, we show the merit of network stability as a positive effect of learning rate decay.
In this view, a higher learning rate is a form of regularisation. Decreasing the learning rate is akin to lowering the temperature in a simulated annealing optimiser: it reshapes the probabilistic loss landscape. In the same way that Tikhonov regularisation imposes a prior that favours smaller weights, the learning rate we choose acts as a prior that affects which minima are discoverable in the stochastic optimisation process.
Citations
[1] Smith, Leslie N. “Cyclical learning rates for training neural networks.” 2017 IEEE Winter Conference on Applications of Computer Vision (WACV) . IEEE, 2017.
[2] Dauphin, Y. N., et al. "RMSProp and equilibrated adaptive learning rates for non-convex optimization." arXiv preprint arXiv:1502.04390 (2015).
[3] Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning . MIT press, 2016.
[4] Li, Yong, Jiabei Zeng, Jie Zhang, Anbo Dai, Meina Kan, Shiguang Shan, and Xilin Chen. "KinNet: Fine-to-Coarse Deep Metric Learning for Kinship Verification." (2017): 13–20. doi:10.1145/3134421.3134425.
Written by
Deep learning researcher. Builder of software. Follow me on twitter @nofreeshivam :) full bio @ www.shivam.sh