Learn AI Today: 02 — Introduction to Classification Problems using PyTorch


Classifying flowers with Neural Networks, visualizing decision boundaries and understanding overfitting.

Jul 27 · 9 min read


Adapted from photo by Kelly Sikkema on Unsplash.

This is the second story in the Learn AI Today series I’m creating! These stories, or at least the first few, are based on a series of Jupyter notebooks I created while studying and learning PyTorch and Deep Learning. I hope you find them as useful as I did!

If you have not already, make sure to check the previous story!

What you will learn in this story:

  • The Importance of Validation
  • How to Train Models for Classification Problems
  • How to Visualize the Decision Boundaries Dynamically
  • How to Avoid Overfitting

1. Iris Flower Dataset

Let’s get started by introducing the dataset. I will be using the very famous Iris flower dataset that contains 4 different measurements (sepal length, sepal width, petal length, petal width) of the following 3 species of flowers.


Image adapted from Wikipedia.

The goal is to accurately identify the species using the 4 measurements for each flower. Note that nowadays it’s relatively easy to use a model (a Convolutional Neural Network) that learns directly from the images, but I will leave that topic for the next lesson. The Iris flower dataset can be easily downloaded from the sklearn datasets module, as shown in the code below.
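A minimal sketch of what that loading step can look like (the variable names are mine, not necessarily the notebook’s):

```python
# Load the Iris dataset from scikit-learn.
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target      # X: (150, 4) measurements, y: labels 0, 1, 2
print(iris.feature_names)          # sepal length/width, petal length/width (cm)
print(iris.target_names)           # ['setosa' 'versicolor' 'virginica']
```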

To get a quick visualization of the data, let’s plot the scatter plots of each pair of features and the histograms for each feature. To achieve this representation I used the pandas.plotting.scatter_matrix function (as always, you can find the link to the full code at the end).
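One way to produce such a plot (the figure size and colouring options here are my choices):

```python
# Pairwise scatter plots and per-feature histograms with pandas.
import pandas as pd
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix

df = pd.DataFrame(iris.data, columns=iris.feature_names)
scatter_matrix(df, c=iris.target, figsize=(10, 10), diagonal='hist')
plt.show()
```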


Visualization of Iris flower data (blue dots — iris setosa; green-bluish dots — iris versicolor; yellow dots — iris virginica). Image by the author.

As you can see in the scatter plots above, for most of the samples it should be easy to discriminate the species. For example, the Iris setosa dots are, in most plots, very well separated from the other two species. For such an easy example you could create a rule-based algorithm simply by drawing a few lines. However, precisely because of its simplicity, it is also a good example to introduce classification with Neural Networks!

2. Validation Set

Before jumping straight to the models and training, it’s very important to create a validation set. I skipped this step in the first lesson to avoid introducing too many concepts all at once.

The idea of a validation set is simple. Instead of training a model with all your data, you set aside a fraction of the data (usually 20–30%) that you will use to evaluate whether the trained model generalizes well to unseen data. This is very important to make sure your model can be safely put into production and will evaluate new data accurately.

To split the data I make use of the train_test_split function, choosing test_size=0.5 (see the sketch below). It’s always a good idea to set the random_state to make sure that the same split is used when you re-run the code.
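A possible version of that split, followed by the conversion to PyTorch tensors (the random_state value is arbitrary):

```python
import torch
from sklearn.model_selection import train_test_split

# Hold out 50% of the samples for testing/validation; fix the seed for reproducibility.
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.5, random_state=42)

# Convert everything to tensors (float inputs, integer class labels).
x_train = torch.tensor(x_train, dtype=torch.float32)
x_test = torch.tensor(x_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)
```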

Notice that it is common, and good practice, to have not 2 but 3 data splits: train, validation and test. In that case, you use the validation set to check the progress when you try several models, ideas and hyper-parameters (e.g. the learning rate), and you only use the test set at the end, when you are happy with the results. This is what happens in Kaggle competitions, where there is usually a hidden test set.

3. Training a Model for Classification

The model I’m going to use for this example is exactly the same one I used in the previous lesson for the regression problems!

So what’s the difference? The difference is in the loss function. For multi-class classification problems, the usual choice is Cross Entropy Loss (nn.CrossEntropyLoss in PyTorch). For binary classification problems, you usually use Binary Cross Entropy Loss (nn.BCEWithLogitsLoss). As a result, the code for defining the model, criterion and optimizer is very similar to what I used in the previous lesson for regression!

An additional difference to consider is the last activation function. For regression problems, the output of the model is a number that can be any real value. For binary classification, you need to use a Sigmoid activation function that maps the output to the 0–1 range. For multi-class classification you need a Softmax activation function (unless you want to allow for multiple choices at once, in which case use a Sigmoid activation). The Softmax output can be interpreted as the probability assigned to each class.

Don’t worry too much for now about the activation functions. I’m just mentioning them here so that you are aware of their existence. For now, the good thing to know is that nn.CrossEntropyLoss includes the Softmax activation for you and nn.BCEWithLogitsLoss includes the Sigmoid for you. That way you don’t need to add any activation function at the end of the model. PyTorch takes care of that for you!
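The setup could therefore look something like this (the hidden layer size is an assumption on my part; the exact architecture from the previous lesson may differ):

```python
import torch.nn as nn
import torch.optim as optim

# A small fully-connected network: 4 input features -> 3 class scores (logits).
model = nn.Sequential(
    nn.Linear(4, 32),
    nn.ReLU(),
    nn.Linear(32, 3),   # no Softmax here: nn.CrossEntropyLoss applies it internally
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```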

Before training the model, I also changed the fit function from the previous lesson (you can check it in the Kaggle notebook with the complete code at the end) to allow for train and test/validation data. (When working with only two datasets, the terms validation and test are often used interchangeably.)
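A rough sketch of such a fit function (the version in the notebook may differ; full-batch training is used here, which is fine for a dataset this small):

```python
import torch

def fit(model, criterion, optimizer, x_train, y_train, x_test, y_test, epochs=1000):
    train_losses, test_losses = [], []
    for epoch in range(epochs):
        # Training step on the full training set.
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(x_train), y_train)
        loss.backward()
        optimizer.step()

        # Evaluate on the test/validation set without tracking gradients.
        model.eval()
        with torch.no_grad():
            test_loss = criterion(model(x_test), y_test)

        train_losses.append(loss.item())
        test_losses.append(test_loss.item())
    return train_losses, test_losses
```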

Plotting the train and test losses during the 1000 epochs of training, you can see that something weird is going on.
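For example, reusing the fit sketch above:

```python
import matplotlib.pyplot as plt

train_losses, test_losses = fit(model, criterion, optimizer,
                                x_train, y_train, x_test, y_test, epochs=1000)
plt.plot(train_losses, label='train')
plt.plot(test_losses, label='test')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
```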


Train and test losses. Image by the author.

While the train loss keeps improving over time, the test loss initially improves steadily but then starts to increase. And perhaps most importantly, you can see the same effect by plotting the accuracy of the model over the 1000 epochs.
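One straightforward way to compute that accuracy (a sketch, not necessarily the notebook’s exact code):

```python
import torch

def accuracy(model, x, y):
    model.eval()
    with torch.no_grad():
        preds = model(x).argmax(dim=1)       # pick the highest-scoring class
    return (preds == y).float().mean().item()
```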


Train and test accuracy. Image by the author.

What you are seeing is what we call overfitting — at some point the model starts ‘remembering’ the train data and the generalization performance (evaluated with the test set) decreases. And that’s why you need a validation/test set. In this case you can see that the results would possibly be better if you stopped the training at around 200 epochs. This method is called early stopping and is an easy way to reduce overfitting. However, there are other ways, such as using weight decay, that I will talk about in a minute. Let’s first do some fun visualizations to understand better what is going on.

4. Visualizing Decision Boundaries and Reducing Overfitting

To visualize the decision boundaries and better understand overfitting, I retrained the model using only 2 (instead of 4) features in order to easily plot the result in a 2D graphic. In the animation below you can see the decision boundaries for train (left) and test (right) evolving during the 1000 training epochs.
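A plot like the ones in the animation can be built by classifying every point of a dense grid covering the feature space; here is a rough sketch (the grid resolution and padding are my choices):

```python
import numpy as np
import torch
import matplotlib.pyplot as plt

def plot_decision_boundary(model, x, y, steps=200):
    # Cover the 2D feature space with a dense grid of points.
    x0_min, x0_max = float(x[:, 0].min()) - 0.5, float(x[:, 0].max()) + 0.5
    x1_min, x1_max = float(x[:, 1].min()) - 0.5, float(x[:, 1].max()) + 0.5
    xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, steps),
                         np.linspace(x1_min, x1_max, steps))
    grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)

    # Predict a class for every grid point and colour the regions accordingly.
    model.eval()
    with torch.no_grad():
        preds = model(grid).argmax(dim=1).reshape(xx.shape)
    plt.contourf(xx, yy, preds.numpy(), alpha=0.3)
    plt.scatter(np.asarray(x[:, 0]), np.asarray(x[:, 1]), c=np.asarray(y))
    plt.show()
```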


Visualizing model decision boundaries during 1000 epochs of training. Animation by the author.

Look how initially the model quickly learns 2 straight-line boundaries dividing the regions. Then, the red-yellow boundary starts to curve upwards, adjusting better to the train set. However, looking at the test set, this leads to a slight decrease in performance as more yellow dots end up in the middle region and green-bluish dots in the top region. And this is exactly what overfitting looks like. For more complex problems you can imagine this in multiple dimensions — it’s easier to overfit when you have a lot of dimensions.

If you have trouble imagining anything in more than 3 dimensions just follow this advice from Geoffrey Hinton: “To deal with hyper-planes in a 14-dimensional space, visualize a 3-D space and say ‘fourteen’ to yourself very loudly. Everyone does it.”

To reduce overfitting in this example you can use two simple ‘tricks’:

  • Early Stopping: Train the model for fewer epochs
  • Weight Decay: Force the model weights to be small by reducing them by a small amount each iteration

4.1 Early Stopping

The idea of Early Stopping is very simple: as I mentioned before, if the model has the best validation accuracy at around epoch 200, then training for only 200 epochs will give you a model that generalizes better — according to the validation accuracy. The problem is that you may then be indirectly overfitting to the validation set, particularly if you tune a lot of hyperparameters based on the validation score. This is why it’s often important to have an additional set of data to evaluate the model after finishing all the experiments.
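In this story early stopping just means training for a fixed, smaller number of epochs, but a common more general variant keeps the best weights seen so far and stops once the validation loss hasn’t improved for a while. A sketch of that idea (the patience value is arbitrary, and this is not the notebook’s code):

```python
import copy
import torch

def fit_with_early_stopping(model, criterion, optimizer, x_train, y_train,
                            x_test, y_test, epochs=1000, patience=50):
    best_loss, best_state, wait = float('inf'), None, 0
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(x_train), y_train)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            test_loss = criterion(model(x_test), y_test).item()

        if test_loss < best_loss:
            # New best validation loss: remember these weights.
            best_loss, wait = test_loss, 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            wait += 1
            if wait >= patience:    # no improvement for `patience` epochs: stop
                break

    model.load_state_dict(best_state)  # restore the best weights seen
    return best_loss
```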

4.2 Weight Decay

The idea of Weight Decay is also simple. When fitting a Neural Network, in general, there isn’t a single optimum solution but multiple similar possible solutions. Weight Decay, by forcing the weights to stay small, will push the optimization process towards a simpler solution.
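Concretely, PyTorch implements the weight_decay argument as an L2 penalty: each parameter’s gradient gets weight_decay · w added to it before the update, so with plain SGD a step becomes w ← w − lr · (∂L/∂w + weight_decay · w). Adam adds the same penalty term before its adaptive scaling (the decoupled AdamW variant handles it slightly differently).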

Let’s add a weight_decay=0.01 to our model and visualize the results after training for 1000 epochs as before. In PyTorch, you just need to add this parameter to the optimizer as optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01). The resulting animation is the following.


Visualizing model decision boundaries during 1000 epochs of training with weight decay of 0.01. Animation by the author.

As you can see, now the red-yellow boundary does not curve upwards as before, since doing so would require a strong increase in the magnitude of the weights while the accuracy wouldn’t change much.

Training the model with the 4 input features and weight decay results in the following plot for the train and test losses. Notice that now the test loss does not start increasing as before!


Train and test losses for a model trained with weight decay. Image by the author.

It is also important to mention that when you are starting a new project, overfitting is something you should aim for. Start with a model that can overfit the data so that you know that your model has enough ‘flexibility’ to learn the patterns in your data. Then increase regularization, like weight decay, to avoid overfitting!

Homework

I can show you a thousand examples but you will learn more if you make one or two experiments by yourself! The complete code for these experiments is available in this notebook.

  • As in the previous lesson, try to play with the learning rate, number of epochs, weight decay and the size of the model.
  • Make experiments and check whether the results are what you expected; if not, look at the visualizations and try to understand why.

And as always, if you create interesting notebooks with nice animations as a result of your experiments, go ahead and share them on GitHub, Kaggle or write a Medium story!

Final remarks

This ends the second story in the Learn AI Today series!

Feel free to give me some feedback in the comments. What did you find most useful or what could be explained better? Let me know!

You can read more about my journey on the following stories!

Thanks for reading! Have a great day!

