From PyTorch to PyTorch Lightning — A gentle introduction
Feb 27 · 10 min read
This post answers the most frequent question about Lightning: why do you need it if you're already using PyTorch?
PyTorch makes it extremely easy to build complex AI models. But once the research gets complicated and things like multi-GPU training, 16-bit precision, and TPU training get mixed in, users are likely to introduce bugs.
PyTorch Lightning solves exactly this problem. Lightning structures your PyTorch code so it can abstract the details of training. This makes AI research scalable and fast to iterate on.
Who is PyTorch Lightning For?
PyTorch Lightning was created for professional researchers and PhD students working on AI research.
Lightning was born out of my Ph.D. AI research at NYU CILVR and Facebook AI Research. As a result, the framework is designed to be extremely extensible while making state-of-the-art AI research techniques (like TPU training) trivial.
Now the core contributors are all pushing the state of the art in AI using Lightning and continue to add new cool features.
At the same time, the simple interface gives professional production teams and newcomers access to the latest state-of-the-art techniques developed by the PyTorch and PyTorch Lightning community.
Lightning has over 96 contributors and a core team of 8 research scientists, PhD students, and professional deep learning engineers, and it is rigorously tested.
Outline
This tutorial will walk you through building a simple MNIST classifier, showing PyTorch and PyTorch Lightning code side by side. While Lightning can build any arbitrarily complicated system, we use MNIST to illustrate how to refactor PyTorch code into PyTorch Lightning.
The Typical AI Research Project
In a research project, we normally want to identify the following key components:
- the model(s)
- the data
- the loss
- the optimizer(s)
The Model
Let’s design a 3-layer fully-connected neural network that takes as input an image that is 28x28 and outputs a probability distribution over 10 possible labels.
First, let’s define the model in PyTorch
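A minimal sketch of such a model (the class name and hidden-layer sizes here are illustrative choices, not requirements):

```python
from torch import nn
from torch.nn import functional as F

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # 28x28 images are flattened to 784 inputs; 10 outputs, one per digit
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)        # flatten the image
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)    # log-probabilities over the 10 digits
```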
This model defines the computational graph to take as input an MNIST image and convert it to a probability distribution over 10 classes for digits 0–9.
To convert this model to PyTorch Lightning we simply replace the nn.Module with the pl.LightningModule
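Here's the same model as a LightningModule; the only change is the base class (a sketch mirroring the module above):

```python
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F

class LitMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)
```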
The new PyTorch Lightning class is EXACTLY the same as the PyTorch version, except that the LightningModule provides a structure for the research code.
Lightning provides structure to PyTorch code
See? The code is EXACTLY the same for both!
This means you can use a LightningModule exactly as you would use a PyTorch module, for example for prediction.
Or use it as a pretrained model
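For example (a sketch; the checkpoint path is a placeholder for something you've saved yourself):

```python
import torch

# use it like any nn.Module for prediction
model = LitMNIST()
x = torch.rand(1, 1, 28, 28)      # a fake MNIST-sized image
log_probs = model(x)              # log-probabilities over the 10 digits

# or load it as a pretrained model from a checkpoint
PATH = 'path/to/checkpoint.ckpt'  # placeholder path to a saved checkpoint
pretrained = LitMNIST.load_from_checkpoint(PATH)
pretrained.eval()
```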
The Data
For this tutorial we’re using MNIST.
Let’s generate three splits of MNIST, a training, validation and test split.
This, again, is the same code in PyTorch as it is in Lightning.
The dataset is added to the DataLoader, which handles the loading, shuffling, and batching of the dataset.
In short, data preparation has 4 steps:
- Download images
- Image transforms (these are highly subjective).
- Generate training, validation and test dataset splits.
- Wrap each dataset split in a DataLoader
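In plain PyTorch, those four steps might look like this (batch size and normalization constants are the usual MNIST defaults, but they're choices, not requirements):

```python
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# 1. download images  2. apply transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),   # standard MNIST mean/std
])
mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform)

# 3. generate training, validation and test splits
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

# 4. wrap each split in a DataLoader
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=64)
test_loader = DataLoader(mnist_test, batch_size=64)
```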
Again, the code is exactly the same except that we’ve organized the PyTorch code into 4 functions:
prepare_data
This function handles downloads and any data processing, and it makes sure that when you use multiple GPUs you don't download multiple copies of the dataset or apply duplicate manipulations to the data.
This is because each GPU will execute the same PyTorch code, thereby causing duplication. Lightning makes sure these critical parts are called from ONLY one GPU.
train_dataloader, val_dataloader, test_dataloader
Each of these is responsible for returning the appropriate data split. Lightning structures it this way so that it is VERY clear HOW the data are being manipulated. If you ever read random GitHub code written in PyTorch, it's nearly impossible to see how the data are manipulated.
Lightning even allows multiple dataloaders for testing or validating.
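Inside the LightningModule, those four functions might look like this (a sketch; the split is done in prepare_data here to mirror the description above, though newer Lightning versions move dataset setup into a setup hook or a LightningDataModule):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class LitMNIST(pl.LightningModule):
    # ... __init__ and forward as before ...

    def prepare_data(self):
        # downloads (and, in this sketch, the splits) happen once, on a single process
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform)
        self.mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
```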
The Optimizer
Now we choose how we’re going to do the optimization. We’ll use Adam instead of SGD because it is a good default in most DL research.
Again, this is exactly the same in both, except in Lightning it is organized into the configure_optimizers function.
Lightning is extremely extensible. For instance, if you wanted to use multiple optimizers (e.g., for a GAN), you could just return both here.
You’ll also notice that in Lightning we pass in self.parameters() and not a model because the LightningModule IS the model.
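A sketch of that function:

```python
import pytorch_lightning as pl
from torch.optim import Adam

class LitMNIST(pl.LightningModule):
    # ... model definition and data hooks as before ...

    def configure_optimizers(self):
        # the LightningModule IS the model, so we optimize self.parameters()
        return Adam(self.parameters(), lr=1e-3)
```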
The Loss
For n-way classification we want to compute the cross-entropy loss. Cross-entropy is the same as applying the negative log-likelihood loss to log_softmax outputs, which is what we'll use here, since our model already ends in a log_softmax.
Again… code is exactly the same!
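For example, with a fake batch of images and labels standing in for one batch from the DataLoaders above:

```python
import torch
from torch.nn import functional as F

model = LitMNIST()
x = torch.rand(64, 1, 28, 28)      # a batch of fake images
y = torch.randint(0, 10, (64,))    # a batch of fake labels

logits = model(x)                  # forward already applies log_softmax
loss = F.nll_loss(logits, y)       # equivalent to cross-entropy on the raw scores
```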
Training and Validation Loop
We've assembled all the key ingredients needed for training:
- The model (3-layer NN)
- The dataset (MNIST)
- An optimizer
- A loss
Now we implement a full training routine which does the following:
- Iterates for many epochs (an epoch is a full pass through the dataset D)
- Each epoch iterates over the dataset in small chunks called batches b
- We perform a forward pass
- Compute the loss
- Perform a backward pass to calculate all the gradients for each weight
- Apply the gradients to each weight
In both PyTorch and Lightning the pseudocode looks like this
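A rough sketch (model, optimizer, train_loader, and num_epochs are assumed to be defined as above):

```python
from torch.nn import functional as F

for epoch in range(num_epochs):
    for batch in train_loader:
        x, y = batch
        logits = model(x)             # forward pass
        loss = F.nll_loss(logits, y)  # compute the loss
        loss.backward()               # backward pass: compute gradients for each weight
        optimizer.step()              # apply the gradients to each weight
        optimizer.zero_grad()         # reset the gradients for the next batch
```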
This is where Lightning differs, though. In PyTorch, you write the for loop yourself, which means you have to remember to call the correct things in the right order — this leaves a lot of room for bugs.
Even if your model is simple, it won’t be once you start doing more advanced things like using multiple GPUs, gradient clipping, early stopping, checkpointing, TPU training, 16-bit precision, etc… Your code complexity will quickly explode.
Here are the training and validation loops for both PyTorch and Lightning.
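In Lightning, the per-batch logic lives in training_step and validation_step (a sketch; the exact return and logging conventions have changed a bit across Lightning versions):

```python
import pytorch_lightning as pl
from torch.nn import functional as F

class LitMNIST(pl.LightningModule):
    # ... model, data hooks, and configure_optimizers as before ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'val_loss': loss}
```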
This is the beauty of Lightning. It abstracts the boilerplate but leaves everything else unchanged. This means you are STILL writing PyTorch, except your code has been structured nicely.
This increases readability which helps with reproducibility!
The Lightning Trainer
The trainer is how we abstract the boilerplate code.
Again, this is possible because ALL you had to do was organize your PyTorch code into a LightningModule.
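A sketch of how the pieces fit together:

```python
from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(max_epochs=5)   # the Trainer owns the training/validation loop
trainer.fit(model)
```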
Full Training Loop for PyTorch
The full MNIST example written in PyTorch is as follows:
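Condensed, it might look something like this (reusing the MNISTClassifier sketched earlier; hyperparameters are illustrative):

```python
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=64)

# model + optimizer
model = MNISTClassifier()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# hand-written training + validation loops
for epoch in range(5):
    model.train()
    for x, y in train_loader:
        logits = model(x)
        loss = F.nll_loss(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            val_loss += F.nll_loss(model(x), y, reduction='sum').item()
    print(f'epoch {epoch}: val_loss={val_loss / len(val_loader.dataset):.4f}')
```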
Full Training Loop in Lightning
The Lightning version is EXACTLY the same except:
- The core ingredients have been organized by the LightningModule
- The training/validation loop code has been abstracted by the Trainer
Highlights
Let's call out a few key points:
- Without Lightning, the PyTorch code can live in arbitrary places. With Lightning, it is structured.
- It is the exact same code for both, except that it's structured in Lightning (worth saying twice, lol).
- As the project grows in complexity, your code won't, because Lightning abstracts most of it away.
- You retain the flexibility of PyTorch because you have full control over the key points in training. For instance, you could have an arbitrarily complex training_step, such as a seq2seq model.
- In Lightning you get a bunch of freebies: a sick progress bar, a beautiful weights summary, TensorBoard logs (yup! you had to do nothing to get these), and free checkpointing and early stopping.
All for free!
Additional Features
But Lightning is best known for its out-of-the-box goodies, such as TPU training, etc…
In Lightning, you can train your model on CPUs, GPUs, multiple GPUs, or TPUs without changing a single line of your PyTorch code.
You can also do 16-bit precision training
Or log using any of 5 other alternatives to TensorBoard.
We even have a built-in profiler that can tell you where the bottlenecks are in your training.
Setting this flag gives you a summary of where time was spent during training, or a more advanced report if you want one.
We can also train on multiple GPUs at once without you doing any work (you still have to submit a SLURM job)
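Conceptually, hardware and precision are just Trainer arguments (a sketch; the exact flag names have shifted across Lightning versions, so check the docs for the version you're on):

```python
from pytorch_lightning import Trainer

trainer = Trainer(gpus=1)                  # single GPU
trainer = Trainer(gpus=8, num_nodes=4)     # multi-GPU, multi-node (e.g. under SLURM)
trainer = Trainer(tpu_cores=8)             # TPU training
trainer = Trainer(gpus=1, precision=16)    # 16-bit precision
trainer = Trainer(profiler=True)           # built-in profiler output
```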
And there are about 40 other features it supports which you can read about in the documentation.
Extensibility With Hooks
You’re probably wondering how it’s possible for Lightning to do this for you and yet somehow make it so that you have full control over everything?
Unlike Keras or other high-level frameworks, Lightning does not hide any of the necessary details. And if you do find the need to modify any aspect of training on your own, you have two main options.
The first is extensibility by overriding hooks. Here’s a non-exhaustive list:
- forward pass
- backward pass
- applying optimizers
- anything you would need to configure
These overrides happen in the LightningModule.
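For example (a sketch; hook names and signatures vary a bit between Lightning versions):

```python
import pytorch_lightning as pl

class LitMNIST(pl.LightningModule):
    # ... everything from before ...

    def on_epoch_start(self):
        # runs at the start of every epoch
        print(f'starting epoch {self.current_epoch}')

    def backward(self, *args, **kwargs):
        # take full control of the backward pass if you need to
        # (the exact signature of this hook depends on the Lightning version)
        super().backward(*args, **kwargs)
```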
Extensibility with Callbacks
A callback is a piece of code that you'd like to be executed at various parts of training. In Lightning, callbacks are reserved for non-essential code, such as logging or anything else not related to the research code. This keeps the research code super clean and organized.
Let's say you wanted to print something or save something at various parts of training. Here's what the callback would look like.
Now you pass it into the trainer, and its hooks will be called at the corresponding points in training:
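A sketch (MyPrintCallback is a made-up name for illustration):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class MyPrintCallback(Callback):
    # non-essential bookkeeping lives here, away from the research code
    def on_train_start(self, trainer, pl_module):
        print('training is starting')

    def on_train_end(self, trainer, pl_module):
        print('training is done')

model = LitMNIST()
trainer = Trainer(callbacks=[MyPrintCallback()])
trainer.fit(model)
```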
This paradigm keeps your research code organized into three different buckets:
- Research code (LightningModule) (this is the science).
- Engineering code (Trainer)
- Non-research related code (Callbacks)
How to start
Hopefully this guide showed you exactly how to get started. The easiest way is to run the Colab notebook with the MNIST example here.
Or install Lightning:
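The package is published on PyPI as pytorch-lightning:

```
pip install pytorch-lightning
```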
Or check out the GitHub page.