Writing TensorFlow 2 Custom Loops
A step-by-step guide from Keras to TensorFlow 2
Apr 30 · 9 min read
I have been a Keras user for quite a while, and I never really found a good reason to move elsewhere. TensorFlow 1 was clumsy, and PyTorch, although sexy, didn’t quite get me. I tend to learn the most when doing prototypes, and Keras is the king here. Fast forward to September 2019, TensorFlow 2 was released. Soon after, I switched all import keras calls to import tensorflow.keras.
If this article was just about migrating from Keras to TensorFlow 2, it could end here. Really. However, among all the things that came with TensorFlow 2, custom training loops are hands down the best new feature for Keras users. In this article, I explain why they matter and how to implement them.
Disclaimer:I have added all code as pictures, as it displays better on smartphones. All code is available as a GitHub gist. Write to me if you know or prefer some other alternative to posting code.
TensorFlow 2
For the uninitiated, the thing with TensorFlow 2 is that it threw away most TensorFlow 1 idioms and several of its APIs in favor of Keras, now the one and only API for defining neural networks. Moreover, it embraced eager execution.
Together, these two changes tackle much of the early TensorFlow criticism:
- There are no longer several APIs.
- There is no need to manage variables manually.
- Sessions are gone.
- You can debug the execution flow.
- Dynamic models are now possible.
At the same time, they managed to pull that off while retaining the deployment capabilities and speed that made TensorFlow 1 famous.
For Keras users, the new API means four things:
- We are no longer within an abstraction layer.
- Authoring new techniques is considerably easier.
- Writing custom training loops is now practical.
- Execution is considerably faster.
Among all things, custom loops are the reason why TensorFlow 2 is such a big deal for Keras users. Custom loops provide ultimate control over training while making it about 30% faster.
To be honest, a better name for TensorFlow 2 would be Keras 3.
I know some Google developer is probably dying on the inside as I write this, but it is the truth. TensorFlow 1 was the backend for Keras, and it mostly remained that way with the upgrade. Most new things are way more Keras-y than TensorFlow-ish.
The Training Loop
In essence, all Keras programs adhere to the following structure:
It all begins with imports, then dataset loading and preparation, model loading/creation, and, finally, the compile and fit calls. Unique to Keras, the compile method associates the model with a loss function and an optimizer, and the fit function performs the so-called “training loop.”
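Roughly, that structure looks like the sketch below (MNIST and a small dense network are assumed stand-ins for whatever the original code used):

```python
import tensorflow as tf

# Dataset loading and preparation (MNIST as a stand-in)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.reshape(-1, 784).astype("float32") / 255.0
X_test = X_test.reshape(-1, 784).astype("float32") / 255.0

# Model creation
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation="relu", input_shape=(784,)),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# Compile and fit
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(X_train, y_train, batch_size=128, epochs=5,
          validation_data=(X_test, y_test))
```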
The training loop is the code that feeds the entire training set, batch-by-batch, to the algorithm, computing the loss, its gradients, and applying the optimizer. Then, the validation set is fed to calculate the validation loss and validation metrics. Finally, this whole process is repeated for several “epochs.”
During the training loop, other functionalities are also performed by the fit method, such as manipulating several worker threads, checkpointing the model, and logging results to disk. Custom code can also be inserted at specific events via callbacks.
Keras users are used to not having to write any training loop at all. The fit function does it all perfectly and allows the right amount of customization for most use cases.
The problem with a fixed training loop is that when a different use case arises, you are out of luck. The original Keras provided little in the way of custom training loops.
Examples can easily be found in GAN tutorials. To train adversarial networks, you have to interleave their training. For that, the train_on_batch method was the best approach to alternate between the generator and discriminator, leaving you to batch and shuffle the dataset by hand, write your own progress bar code, and wrap it all into a for-loop for each epoch.
In 2017, the Wasserstein GAN was proposed, which requires weight clipping (or, in its later variant, a gradient penalty) to be added to the training, a simple requirement. WGAN implementations in Keras are, to keep it civilized, bloated.
These are not isolated examples. Any multi-network setup, gradient trick, or out-of-the-box solution likely requires you to write a lot of code. This is why Keras is so unpopular among researchers (and also why PyTorch is so popular).
The Custom Loop
What TensorFlow 2 brought to the table for Keras users is the power to open up the train_on_batch call, exposing the loss, gradient, and optimizer calls. However, to use it, you have to let go of the compile and fit functionalities.
On the bright side, Keras is no longer an abstraction over TensorFlow. It is part of it now. This means that all the weird stuff we had to do to create custom logic in Keras is no longer needed. Everything is compatible. We no longer have to import keras.backend as K.
First, you have to get the loss and optimizer objects yourself. Then, you define the train_on_batch call. This looks a lot like PyTorch code to me. You simply call the method passing X to get ŷ, then you compare it against y, getting the loss value.
This is all happening within the context of a gradient tape, which is just a way to track which operations should be differentiated. Using the tape, we can get the gradients of the loss with respect to each trainable variable. Then, the gradient-variable pairs are fed to the optimizer, which will update the network.
The last line is just an example of how to extract a batch of 128 samples to debug our new method.
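Put together, a sketch of what this might look like (the names loss_fn and optimizer are my own; the last line runs the method on a 128-sample batch as described):

```python
# Loss and optimizer objects, created by hand this time
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

def train_on_batch(X, y):
    with tf.GradientTape() as tape:
        y_pred = model(X, training=True)   # forward pass: X in, ŷ out
        loss = loss_fn(y, y_pred)          # compare ŷ against y
    # Gradients of the loss with respect to each trainable variable
    grads = tape.gradient(loss, model.trainable_variables)
    # Feed the gradient-variable pairs to the optimizer
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Debug the new method on a single batch of 128 samples
print(train_on_batch(X_train[:128], y_train[:128]))
```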
Its twin, validate_on_batch, is just a simpler version of this. We just have to get rid of the tape and gradient logic.
In essence, it is the same code we find within the gradient tape, but without the tape (and the gradient/optimizer logic).
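Under the same assumptions, a sketch:

```python
def validate_on_batch(X, y):
    y_pred = model(X, training=False)  # no tape, no gradients, no optimizer
    return loss_fn(y, y_pred)
```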
A note on the “training=True” and “training=False” parameters: these change how some layers behave. For instance, dropout layers drop some connections during training, but not during testing. This does not directly affect the optimization or any other training-related task.
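A toy sketch (not part of the original code) that makes this visible with a dropout layer:

```python
dropout = tf.keras.layers.Dropout(0.5)
x = tf.ones((1, 4))
print(dropout(x, training=True))   # roughly half the values are zeroed out
print(dropout(x, training=False))  # the input passes through unchanged
```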
A minimal example of a full training loop can be seen below:
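In sketch form (the batch size of 128, the epoch count, and the manual index shuffling are assumptions):

```python
import numpy as np

epochs = 10
batch_size = 128

for epoch in range(epochs):
    # Clumsy batching: shuffle the indices and slice the arrays by hand
    idx = np.random.permutation(len(X_train))
    for i in range(0, len(X_train), batch_size):
        batch = idx[i:i + batch_size]
        train_on_batch(X_train[batch], y_train[batch])

    # One print per epoch, with the validation loss on the whole test set
    val_loss = validate_on_batch(X_test, y_test)
    print("Epoch", epoch, "validation loss:", float(val_loss))
```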
At this point, we have a minimal working example. However, it lacks several essential features:
- The current batching logic is clumsy and error-prone.
- There is no progress indicator besides one print per epoch.
- There is no model checkpoint logic.
This brings us to the following point:
Doing your own loop is nice, but it also requires you to re-code some of the features you took for granted.
Thankfully, it is not that hard to improve it :)
Improving the Loop
The first thing to note is the tf.data package. It contains the tf.data.Dataset class, which encapsulates several dataset tasks.
The dataset API uses a “fluent style.” This means that all calls to the Dataset object return another Dataset object. This makes chaining calls possible. Here is an example of using it for our problem:
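A sketch of those calls (the shuffle buffer size of 1024 is an assumption):

```python
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)) \
    .shuffle(buffer_size=1024).batch(128)

test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)) \
    .batch(128)
```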
These calls create dataset objects for the training and testing data, prepare them for shuffling, and batch the instances. In our loop, the clumsy batching code becomes the following:
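In sketch form, reusing the names from before:

```python
for epoch in range(epochs):
    for batch, (X, y) in enumerate(train_dataset):
        train_on_batch(X, y)
```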
Much neater. The new code only has to enumerate all batches from the dataset. This also handles shuffling for us, which is always a good idea.
To make our code look a bit more alive, we can add a couple of print statements. We can also make our validation logic a bit shorter:
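A sketch of how that might look (the dot frequency and the exact formatting are assumptions):

```python
for epoch in range(epochs):
    print("\nEpoch %d" % epoch, end="")
    for batch, (X, y) in enumerate(train_dataset):
        train_on_batch(X, y)
        if batch % 50 == 0:
            print(".", end="", flush=True)  # dots to keep the output moving

    # Shorter validation: collect the loss of each test batch, then average
    val_losses = [validate_on_batch(X, y) for X, y in test_dataset]
    print(" validation loss: %.4f" % np.mean(val_losses))
```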
Now we print the current epoch and batch, along with some dots to keep it moving. For the validation, we build a list with the loss of each batch, followed by a print of their average. This is how it looks in the console:
For the model checkpoint, we have to track the best validation loss so far, so that we only save models that improve our performance. This requires us to store the best loss and compare each new loss against it. Code-wise, this becomes the following:
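A sketch (the file name best_model.h5 is made up):

```python
best_loss = np.inf

for epoch in range(epochs):
    for X, y in train_dataset:
        train_on_batch(X, y)

    val_loss = np.mean([validate_on_batch(X, y) for X, y in test_dataset])

    # Only keep checkpoints that improve the validation loss
    if val_loss < best_loss:
        best_loss = val_loss
        model.save("best_model.h5")
```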
This completes our goal of improving dataset handling, seeing more frequent screen updates, and saving model checkpoints as it trains.
The final touch is to make it performant.
The tf.function
So far, we have been running in the so-called “eager mode.” This means that if you write “a + b”, the result of the sum is immediately computed for you. While this eases debugging, it is not the most performant approach.
Deep learning happens on the GPU. Each “sum” and “multiply” command has a cost: the CPU has to call the GPU, tell it which variables need to be operated on, wait for the operation to complete, and fetch back the results. This is slow. A faster approach is to give the GPU a massive list of things to do and wait only once.
This alternative approach is referred to as “deferred mode” or “computational graph.” The idea is to let TensorFlow turn your network into a set of mathematical steps that operate over data. This command list is then sent to the GPU and processed as a whole, which is much faster.
Creating a command list is also crucial for optimization. Models such as InceptionNet have multiple paths, which can be computed in parallel. Simple operations can be fused, such as a multiply followed by an add, and so on.
To allow TensorFlow to build this graph for you, you only need to annotate the train_on_batch and validate_on_batch calls with the @tf.function annotation. Simple as that:
The first time each function is called, TensorFlow will parse its code and build the associated graph. This will take a bit longer than usual but will make all subsequent calls considerably faster.
An excellent way to see this in action is to put a print statement inside. The print will only be executed once or twice, during the computational graph construction. Then, it won’t print anymore, as the Python function itself is no longer being called; only the compiled graph runs.
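A toy example of that effect (not part of the original code):

```python
@tf.function
def traced_sum(a, b):
    print("building the graph...")  # runs only while the function is traced
    return a + b

traced_sum(tf.constant(1), tf.constant(2))  # prints once, during tracing
traced_sum(tf.constant(3), tf.constant(4))  # prints nothing: the graph is reused
```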
Numerically, using an RTX 2070 GPU, the original Keras fit function takes 18 seconds, the custom loop takes 40, and the optimized loop takes 20. This simple annotation made it twice as fast as the eager mode. Compared to the Keras fit, it is 2 seconds slower, showing how well optimized the original fit is. For larger problems and networks, the optimized custom loop surpasses the original fit. In practice, I have seen up to 30% faster epochs using custom loops.
The downside of using the @tf.function annotation is that its error messages are terrible. The rule of thumb is to develop without it first, then add it just for the validation, and then to the training. This way, you can pin down bugs more easily.
Next Steps
Users who made the switch to custom loops are continually implementing improvements to their workflow. I, for instance, have implemented callbacks and metrics in my custom code, as well as a nice progress bar. Over the internet, you can find some packages, such as TQDM or even the Keras progress bar itself.
Implementing the custom loop looks like extra effort, but it quickly pays off, and you only have to do it once. It allows you to have full control of when validation happens, which metrics are computed, complex training schedules, and so on. Model-wise, it is much easier to tamper with the training process. You can add gradient penalties, train several models, or create virtual batches with ease.
In fact, that’s what PyTorch users have been doing for years.
I leave it to you to figure out how to implement these. If you have any questions, feel free to ask in the comments section.
And before I go, there is one great new feature coming up soon. TensorFlow 2.2 will allow you to feed your own train_on_batch and validate_on_batch functions to the original .fit API. This means we will have the best of both worlds: the fit call will be more modular while we retain the possibility of implementing it all from scratch.
Thanks for reading :)
The code for this article can be found here.