LSTM Text Classification Using PyTorch
A step-by-step guide teaching you how to build a bidirectional LSTM in PyTorch!
Jun 30 ·5min read
Intro
Welcome to this tutorial! It will teach you how to build a bidirectional LSTM for text classification in just a few minutes. If you have already read my previous article on BERT Text Classification, you will find the code here similar, with some modifications to support an LSTM. This article also explains how I preprocessed the dataset used in both articles, the REAL and FAKE News Dataset from Kaggle.
First of all, what is an LSTM and why do we use it? LSTM stands for Long Short-Term Memory, a type of network that belongs to the larger family of Recurrent Neural Networks (RNNs). Its main advantage over the vanilla RNN is that it handles long-term dependencies better, thanks to an architecture built around three gates: the input gate, the output gate, and the forget gate. Together, the three gates decide what information the LSTM cell remembers and what it forgets over an arbitrary number of time steps.
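To make the role of the three gates concrete, here is a minimal NumPy sketch of a single LSTM time step in its standard textbook form. It is not code from this article: the function name `lstm_step`, the weight names `W`, `U`, `b`, and the order in which the gates are stacked are all assumptions made for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One LSTM time step (illustrative layout, not a library API).

    x: input vector (D,), h_prev / c_prev: previous hidden and cell state (H,)
    W: (4H, D), U: (4H, H), b: (4H,) hold the stacked gate parameters.
    """
    H = h_prev.shape[0]
    z = W @ x + U @ h_prev + b
    i = sigmoid(z[0:H])            # input gate: how much new information to write
    f = sigmoid(z[H:2 * H])        # forget gate: how much of the old cell state to keep
    g = np.tanh(z[2 * H:3 * H])    # candidate values for the cell state
    o = sigmoid(z[3 * H:4 * H])    # output gate: how much of the cell state to expose
    c = f * c_prev + i * g         # updated cell state
    h = o * np.tanh(c)             # updated hidden state
    return h, c
```

The forget gate scales the previous cell state and the input gate scales the new candidate values, which is how the cell can carry information across many time steps.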
Now that we have a better understanding of LSTMs, let’s focus on how to implement one for text classification. The tutorial is divided into the following steps:
- Preprocess Dataset
- Importing Libraries
- Load Dataset
- Build Model
- Training
- Evaluation
Before we dive right into the tutorial, here is where you can access the code in this article:
Step 1: Preprocess Dataset
The raw dataset looks like the following:
The dataset contains an arbitrary index, title, text, and the corresponding label.
For preprocessing, we import Pandas and Sklearn and define variables for the file paths and the train, validation, and test ratios, as well as the trim_string function, which cuts each sample to its first first_n_words words. Trimming the samples is not necessary, but it speeds up training for heavier models, and the first words are normally enough to predict the outcome.
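As a rough illustration of the trimming step, trim_string could look like the sketch below. The default of 200 words is an assumed value for the example, not necessarily the one used in this article, and passing first_n_words as a parameter is also a choice made here.

```python
def trim_string(text, first_n_words=200):
    """Keep only the first `first_n_words` whitespace-separated words.

    The default of 200 is an assumption for this sketch.
    """
    words = text.split()
    return " ".join(words[:first_n_words])
```

Note that splitting and re-joining also collapses repeated whitespace, which is usually harmless for this kind of text.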
Next, we convert REAL to 0 and FAKE to 1, concatenate title and text to form a new column titletext (we use both the title and the text to decide the outcome), drop rows with empty text, trim each sample to its first first_n_words words, and split the dataset according to train_test_ratio and train_valid_ratio. We save the resulting dataframes as .csv files: train.csv, valid.csv, and test.csv.
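The steps above can be sketched with Pandas and Sklearn as follows. Several things here are assumptions: the column names title, text, and label follow the dataset description, the ratio values are placeholders, and the sketch interprets train_test_ratio as the share kept for training plus validation and train_valid_ratio as the share of that portion used for training. The article's own code may define them differently.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

def trim_string(text, first_n_words=200):
    return " ".join(str(text).split()[:first_n_words])

def preprocess(df_raw, first_n_words=200, train_test_ratio=0.9,
               train_valid_ratio=0.8, seed=1):
    """Return (train, valid, test) dataframes. Ratio values are assumed."""
    df = df_raw.copy()
    # REAL -> 0, FAKE -> 1
    df["label"] = df["label"].map({"REAL": 0, "FAKE": 1})
    # Combine title and text into one field
    df["titletext"] = df["title"].fillna("") + ". " + df["text"].fillna("")
    # Drop rows whose text is empty
    df = df[df["text"].fillna("").str.strip() != ""].copy()
    # Trim each sample to its first `first_n_words` words
    for column in ("text", "titletext"):
        df[column] = df[column].apply(lambda s: trim_string(s, first_n_words))
    # First split off the test set, then split the rest into train and valid
    df_train_valid, df_test = train_test_split(
        df, train_size=train_test_ratio, random_state=seed)
    df_train, df_valid = train_test_split(
        df_train_valid, train_size=train_valid_ratio, random_state=seed)
    return df_train, df_valid, df_test
```

Each returned dataframe can then be written out with `to_csv(..., index=False)` to produce train.csv, valid.csv, and test.csv.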
Step 2: Importing Libraries
We import PyTorch for model construction, torchtext for loading data, matplotlib for plotting, and sklearn for evaluation.
Step 3: Load Dataset
First, we use torchtext to create a label field for the label in our dataset and a text field for title, text, and titletext. We then build a TabularDataset by pointing it to the folder containing train.csv, valid.csv, and test.csv. We create the train, valid, and test iterators that load the data and, finally, build the vocabulary from the training set (counting only tokens with a minimum frequency of 3).
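To show what a minimum frequency of 3 means in practice, here is a small pure-Python sketch of vocabulary building. It is a stand-in for what torchtext does internally, not the torchtext API: the function names `build_vocab` and `encode`, the special tokens, and the use of pre-tokenized input are assumptions for illustration.

```python
from collections import Counter

def build_vocab(tokenized_texts, min_freq=3, specials=("<unk>", "<pad>")):
    """Map each token seen at least `min_freq` times to an integer ID."""
    counts = Counter(tok for tokens in tokenized_texts for tok in tokens)
    itos = list(specials)
    # Most frequent tokens first, ties broken alphabetically
    for tok, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq:
            itos.append(tok)
    return {tok: idx for idx, tok in enumerate(itos)}

def encode(tokens, stoi):
    """Replace tokens by their IDs; rare or unseen tokens become <unk>."""
    unk = stoi["<unk>"]
    return [stoi.get(tok, unk) for tok in tokens]
```

Tokens below the threshold all share the `<unk>` ID, which keeps the embedding table small and avoids learning vectors for words seen only once or twice.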
Step 4: Build Model
We construct an LSTM class that inherits from nn.Module. Inside it, we define an embedding layer, followed by a bi-LSTM layer, and ending with a fully connected linear layer. In the forward function, we pass the text IDs through the embedding layer to get the embeddings, pass them through the LSTM in a way that accommodates variable-length sequences, combine what it learns from both directions, pass the result through the fully connected linear layer, and finally apply a sigmoid to get the probability that the sequence is FAKE (label 1).
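A minimal sketch of such a model is shown below. The class name, the layer sizes (300-dimensional embeddings, 128 hidden units), and the dropout rate are assumptions for illustration. The sketch handles variable-length sequences with pack_padded_sequence and reads the final hidden state of each direction, which is one common way to do it and may differ from the article's own code.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=128, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(2 * hidden_dim, 1)

    def forward(self, text_ids, lengths):
        # text_ids: (batch, max_len) token IDs, lengths: (batch,) true lengths
        embedded = self.embedding(text_ids)
        # Packing lets the LSTM skip the padded positions
        packed = pack_padded_sequence(embedded, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n[0] is the last forward state, h_n[1] the last backward state
        features = torch.cat((h_n[0], h_n[1]), dim=1)
        logits = self.fc(self.drop(features)).squeeze(1)
        return torch.sigmoid(logits)  # probability of FAKE (label 1)
```

Calling the model with a batch of padded IDs and their lengths returns one probability per sequence.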
Step 5: Training
Before training, we build save and load functions for checkpoints and metrics. For checkpoints, the model parameters and optimizer are saved; for metrics, the train loss, valid loss, and global steps are saved so diagrams can be easily reconstructed later.
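The save and load helpers could be sketched as follows, using torch.save and torch.load. The function names and dictionary keys are assumptions for illustration and are not necessarily those used in the article's code.

```python
import torch

def save_checkpoint(path, model, optimizer, valid_loss):
    """Store model parameters, optimizer state, and the validation loss."""
    torch.save({"model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "valid_loss": valid_loss}, path)

def load_checkpoint(path, model, optimizer, device="cpu"):
    """Restore model and optimizer in place; return the stored validation loss."""
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    return state["valid_loss"]

def save_metrics(path, train_loss_list, valid_loss_list, global_steps_list):
    """Store the loss curves so plots can be rebuilt later."""
    torch.save({"train_loss_list": train_loss_list,
                "valid_loss_list": valid_loss_list,
                "global_steps_list": global_steps_list}, path)

def load_metrics(path, device="cpu"):
    state = torch.load(path, map_location=device)
    return (state["train_loss_list"], state["valid_loss_list"],
            state["global_steps_list"])
```

Saving the optimizer state alongside the parameters makes it possible to resume training from the checkpoint, not just run inference.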
We train the LSTM for 10 epochs and save the checkpoint and metrics whenever the validation loss reaches a new best (lowest) value. Here is the output during training:
The whole training process was fast on Google Colab. It took less than two minutes to train!
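The idea of keeping only the model with the lowest validation loss can be sketched with a simplified loop. This is not the article's training code: it validates once per epoch, keeps the best parameters in memory instead of writing a file, and uses a tiny stand-in model on random data so the example runs on its own. The binary cross-entropy loss is an assumption that matches a sigmoid output.

```python
import copy
import torch
import torch.nn as nn

def train(model, optimizer, train_batches, valid_batches, num_epochs=10):
    criterion = nn.BCELoss()  # assumed loss for a sigmoid output
    best_valid_loss, best_state = float("inf"), None
    history = {"train_loss": [], "valid_loss": []}
    for _ in range(num_epochs):
        model.train()
        running = 0.0
        for inputs, labels in train_batches:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running += loss.item()
        model.eval()
        with torch.no_grad():
            valid_loss = sum(criterion(model(x), y).item()
                             for x, y in valid_batches) / len(valid_batches)
        history["train_loss"].append(running / len(train_batches))
        history["valid_loss"].append(valid_loss)
        if valid_loss < best_valid_loss:  # new best: keep these parameters
            best_valid_loss = valid_loss
            best_state = copy.deepcopy(model.state_dict())
    return best_state, best_valid_loss, history

# Stand-in data and model, for illustration only
torch.manual_seed(0)
x = torch.randn(16, 4)
y = (x[:, 0] > 0).float()
toy_model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid(), nn.Flatten(0))
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.01)
best_state, best_loss, history = train(
    toy_model, toy_optimizer, [(x[:8], y[:8])], [(x[8:], y[8:])])
```

The recorded history lists are what would later be plotted as the training and validation loss curves.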
Once training is finished, we can load the saved metrics and plot the training loss and validation loss over time.
Step 6: Evaluation
Finally, for evaluation, we load the best model saved earlier and evaluate it on our test dataset. We use a default threshold of 0.5 to decide when to classify a sample as FAKE. If the model output is greater than 0.5, we classify that news as FAKE; otherwise, REAL. We output the classification report indicating the precision, recall, and F1-score for each class, as well as the overall accuracy. We also output the confusion matrix.
We can see that with a one-layer bi-LSTM, we can achieve an accuracy of 77.53% on the fake news detection task.
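The thresholding and reporting steps can be illustrated with sklearn on a handful of made-up probabilities. The values below are toy numbers chosen for the example and have nothing to do with the results reported in this article. The helper name `predict_labels` is hypothetical.

```python
from sklearn.metrics import classification_report, confusion_matrix

def predict_labels(probabilities, threshold=0.5):
    """Classify as FAKE (1) when the probability exceeds the threshold."""
    return [1 if p > threshold else 0 for p in probabilities]

# Toy values for illustration only
y_true = [1, 0, 1, 1, 0, 0]
y_prob = [0.91, 0.20, 0.45, 0.77, 0.62, 0.05]
y_pred = predict_labels(y_prob)

print(classification_report(y_true, y_pred, labels=[1, 0],
                            target_names=["FAKE", "REAL"], digits=4))

# Rows are true classes, columns are predicted classes, in the order FAKE, REAL
cm = confusion_matrix(y_true, y_pred, labels=[1, 0])
print(cm)
```

Raising or lowering the threshold trades precision against recall for the FAKE class, so 0.5 is a default rather than a requirement.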
Conclusion
This tutorial gave a step-by-step explanation of how to implement your own LSTM model for text classification using PyTorch. We found that a bi-LSTM achieves acceptable accuracy for fake news detection but still has room to improve. If you want more competitive performance, check out my previous article on BERT Text Classification!
If you want to learn more about modern NLP and deep learning, make sure to follow me for updates on upcoming articles :)