LSTM Text Classification Using Pytorch

Category: IT Technology · Published: 4 years ago


A step-by-step guide teaching you how to build a bidirectional LSTM in Pytorch!


Photo by Christopher Gower on Unsplash

Intro

Welcome to this tutorial! This tutorial will teach you how to build a bidirectional LSTM for text classification in just a few minutes. It contains code similar to my previous article on BERT Text Classification, with some modifications to support LSTM, so check that article out if you haven't already. This article also explains how I preprocessed the dataset used in both articles, the REAL and FAKE News Dataset from Kaggle.

First of all, what is an LSTM and why do we use it? LSTM stands for Long Short-Term Memory Network [1], which belongs to a larger category of neural networks called Recurrent Neural Networks (RNNs). Its main advantage over the vanilla RNN is that it handles long-term dependencies better, thanks to an architecture with three gates: the input gate, the output gate, and the forget gate. Together, the three gates decide what information the LSTM cell remembers and what it forgets over arbitrarily long time spans.


LSTM Cell
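To make the gate description concrete, here is a minimal sketch that computes one LSTM time step by hand from the weights of PyTorch's nn.LSTMCell and checks the result against the cell's own output. It assumes PyTorch's documented gate ordering (input, forget, cell candidate, output) inside weight_ih and weight_hh; the helper name lstm_cell_step and the tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn


def lstm_cell_step(x, h_prev, c_prev, cell: nn.LSTMCell):
    """Compute one LSTM step by hand using the weights stored in `cell`."""
    # All four gate pre-activations at once; PyTorch stacks them as (i, f, g, o).
    gates = x @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i = torch.sigmoid(i)      # input gate: how much new information to write
    f = torch.sigmoid(f)      # forget gate: how much of the old cell state to keep
    g = torch.tanh(g)         # candidate values for the cell state
    o = torch.sigmoid(o)      # output gate: how much of the cell state to expose
    c = f * c_prev + i * g    # updated cell state
    h = o * torch.tanh(c)     # updated hidden state
    return h, c


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=5, hidden_size=3)
x = torch.randn(2, 5)         # a batch of two input vectors
h0 = torch.zeros(2, 3)
c0 = torch.zeros(2, 3)

h_manual, c_manual = lstm_cell_step(x, h0, c0, cell)
h_ref, c_ref = cell(x, (h0, c0))
print(torch.allclose(h_manual, h_ref, atol=1e-5), torch.allclose(c_manual, c_ref, atol=1e-5))
```

Both checks should print True. The forget gate f scales the previous cell state while the input gate i scales the new candidate, which is what lets information persist across many time steps.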

Now that we have a better understanding of LSTM, let's focus on how to implement it for text classification. The tutorial is divided into the following steps:

  1. Preprocess Dataset
  2. Importing Libraries
  3. Load Dataset
  4. Build Model
  5. Training
  6. Evaluation

Before we dive right into the tutorial, here is where you can access the code in this article:

Step 1: Preprocess Dataset

The raw dataset looks like the following:


Dataset Overview

The dataset contains an arbitrary index, title, text, and the corresponding label.

For preprocessing, we import Pandas and Sklearn and define variables for the data paths and the train/test and train/validation split ratios, as well as the trim_string function, which cuts each sample down to its first first_n_words words. Trimming the samples is not strictly necessary, but it speeds up training for heavier models, and the beginning of a news article is normally enough to predict the outcome.
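A minimal sketch of these settings and the trimming helper follows. The paths, the ratio values, and FIRST_N_WORDS = 200 are placeholder assumptions rather than values confirmed in this article; only the idea of keeping the first N whitespace-separated words comes from the text.

```python
# Placeholder settings (assumptions; adjust to your own paths and split sizes).
RAW_DATA_PATH = "news.csv"      # hypothetical location of the Kaggle CSV
DESTINATION_FOLDER = "data"     # hypothetical folder for train/valid/test CSVs
TRAIN_TEST_RATIO = 0.90         # share of rows kept for train + valid (the rest is test)
TRAIN_VALID_RATIO = 0.80        # share of train + valid rows kept for training
FIRST_N_WORDS = 200             # hypothetical trimming length


def trim_string(text: str, first_n_words: int = FIRST_N_WORDS) -> str:
    """Return only the first `first_n_words` whitespace-separated words of `text`."""
    # maxsplit stops splitting early, so long articles are never fully split into words
    words = text.split(maxsplit=first_n_words)
    # the last element may hold the unsplit remainder, so keep only the first N pieces
    return " ".join(words[:first_n_words])


print(trim_string("Breaking news: markets rally again today", first_n_words=3))
# Breaking news: markets
```

Note that the helper also collapses runs of whitespace (including newlines) into single spaces.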

Next, we convert REAL to 0 and FAKE to 1, concatenate title and text to form a new column titletext (we use both the title and the text to decide the outcome), drop rows with empty text, trim each sample to its first first_n_words words, and split the dataset according to train_test_ratio and train_valid_ratio. We save the resulting dataframes as .csv files: train.csv, valid.csv, and test.csv.
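The steps above can be sketched with pandas and scikit-learn as follows. This is an illustrative version, not the author's script: the ratios and trimming length are placeholders, the trimming uses pandas' vectorized string methods, and stratify= is an assumption added to keep the REAL/FAKE balance similar across the three splits. A tiny synthetic DataFrame stands in for the Kaggle file so the example runs on its own.

```python
import os
import tempfile

import pandas as pd
from sklearn.model_selection import train_test_split

FIRST_N_WORDS = 200        # hypothetical trimming length
TRAIN_TEST_RATIO = 0.90    # hypothetical: train + valid share of all rows
TRAIN_VALID_RATIO = 0.80   # hypothetical: train share of the train + valid rows


def preprocess(df_raw: pd.DataFrame, out_dir: str):
    df = df_raw.fillna({"title": "", "text": ""})
    df = df[df["text"].str.strip() != ""].copy()           # drop rows with empty text
    df["label"] = (df["label"] == "FAKE").astype(int)      # REAL -> 0, FAKE -> 1
    df["titletext"] = df["title"] + ". " + df["text"]
    for col in ["text", "titletext"]:                      # keep the first FIRST_N_WORDS words
        df[col] = df[col].str.split().str[:FIRST_N_WORDS].str.join(" ")
    df = df[["label", "title", "text", "titletext"]]

    # Stratifying on the label keeps the class balance similar in every split.
    train_valid, test = train_test_split(
        df, train_size=TRAIN_TEST_RATIO, stratify=df["label"], random_state=1)
    train, valid = train_test_split(
        train_valid, train_size=TRAIN_VALID_RATIO, stratify=train_valid["label"], random_state=1)

    os.makedirs(out_dir, exist_ok=True)
    for name, split in [("train", train), ("valid", valid), ("test", test)]:
        split.to_csv(os.path.join(out_dir, f"{name}.csv"), index=False)
    return train, valid, test


# Synthetic stand-in for the raw dataset: 20 usable rows plus one with blank text.
rows = [{"title": f"Title {i}", "text": f"story number {i} " * 5,
         "label": "FAKE" if i % 2 else "REAL"} for i in range(20)]
rows.append({"title": "Blank", "text": "   ", "label": "REAL"})   # dropped: empty text
out_dir = tempfile.mkdtemp()
train, valid, test = preprocess(pd.DataFrame(rows), out_dir)
print(len(train), len(valid), len(test))   # 14 4 2
print(sorted(os.listdir(out_dir)))         # ['test.csv', 'train.csv', 'valid.csv']
```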

Step 2: Importing Libraries

We import PyTorch for model construction, torchtext for loading data, matplotlib for plotting, and sklearn for evaluation.
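A sketch of a corresponding import block is below; it is an assumption about a reasonable setup, not the article's exact list. The torchtext Field, TabularDataset, and BucketIterator classes used in Step 3 are left out on purpose: they lived in torchtext.data up to version 0.8, moved to torchtext.legacy.data in 0.9, and were removed in later releases (0.12 onward), so recent installs cannot import them.

```python
import matplotlib.pyplot as plt                 # plotting the loss curves
import torch
import torch.nn as nn                           # layers: Embedding, LSTM, Linear
import torch.optim as optim                     # optimizers such as Adam
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence  # variable-length batches
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Use a GPU (e.g. on Google Colab) when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
```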

Step 3: Load Dataset

First, we use torchtext to create a label field for the label in our dataset and a text field for the title, text, and titletext columns. We then build TabularDataset objects by pointing torchtext to the path containing the train.csv, valid.csv, and test.csv files. Next, we create the train, valid, and test iterators that load the data in batches, and finally, we build the vocabulary from the training set, keeping only tokens that appear at least 3 times.
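Because the legacy torchtext API described above is not shipped with current torchtext releases, here is a sketch of the same steps in plain PyTorch, swapped in as an alternative rather than reproducing the article's code: build a vocabulary from the training texts only (min_freq=3, as above), map tokens to IDs with an <unk> fallback, and pad each batch while recording the true lengths the LSTM will need. The whitespace tokenizer, the special-token names, and all helper names are hypothetical. Unlike BucketIterator, this loader does not group samples of similar length, and it assumes no sample is empty, which the preprocessing step ensures.

```python
from collections import Counter

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD, UNK = "<pad>", "<unk>"


def tokenize(text):
    # Hypothetical whitespace tokenizer; a spaCy tokenizer could be used instead.
    return text.lower().split()


def build_vocab(texts, min_freq=3):
    """Map tokens seen at least `min_freq` times in the training texts to integer IDs."""
    counts = Counter(tok for text in texts for tok in tokenize(text))
    kept = sorted(tok for tok, n in counts.items() if n >= min_freq)
    return {tok: idx for idx, tok in enumerate([PAD, UNK] + kept)}


def numericalize(text, vocab):
    # Tokens outside the vocabulary map to the <unk> ID.
    return torch.tensor([vocab.get(tok, vocab[UNK]) for tok in tokenize(text)], dtype=torch.long)


def make_collate(vocab):
    def collate(batch):
        # batch is a list of (label, text) pairs
        labels = torch.tensor([float(label) for label, _ in batch])
        seqs = [numericalize(text, vocab) for _, text in batch]
        lengths = torch.tensor([len(seq) for seq in seqs])
        padded = pad_sequence(seqs, batch_first=True, padding_value=vocab[PAD])
        return labels, padded, lengths
    return collate


# Toy (label, titletext) rows standing in for train.csv.
train_rows = [(0, "the cat sat"), (1, "the dog ran far"), (0, "the cat ran")]
vocab = build_vocab([text for _, text in train_rows], min_freq=3)
train_loader = DataLoader(train_rows, batch_size=3, shuffle=False, collate_fn=make_collate(vocab))
labels, padded, lengths = next(iter(train_loader))
print(vocab)             # {'<pad>': 0, '<unk>': 1, 'the': 2}
print(padded.tolist())   # [[2, 1, 1, 0], [2, 1, 1, 1], [2, 1, 1, 0]]
print(lengths.tolist())  # [3, 4, 3]
```

In real training you would create one DataLoader per split (with shuffle=True for the training set) and reuse the vocabulary built from train.csv for valid.csv and test.csv.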

Step 4: Build Model

We construct an LSTM class that inherits from nn.Module. Inside it, we define an Embedding layer, followed by a bi-LSTM layer, and ending with a fully connected linear layer. In the forward function, we pass the text IDs through the embedding layer to get the embeddings, run them through the LSTM in a way that accommodates variable-length sequences, combine the outputs of both directions, pass the result through the fully connected linear layer, and finally apply a sigmoid to get the probability that each sequence belongs to FAKE (label 1).
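Here is a minimal sketch of such a model. It is not the author's exact class: the class name, the embedding size of 300, and the hidden size of 128 are assumptions, and rather than indexing into the padded output it reads the final hidden states h_n that nn.LSTM returns for a packed batch. For a one-layer bi-LSTM, h_n[-2] holds the forward direction's state at each sequence's last real token and h_n[-1] holds the backward direction's state after reading back to the first token, so padding never leaks into the features.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=128, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_dim, 1)

    def forward(self, text, lengths):
        emb = self.embedding(text)                                   # (batch, seq_len, embed_dim)
        # Packing tells the LSTM each sequence's true length; lengths must live on the CPU.
        packed = pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)                              # h_n: (2, batch, hidden_dim)
        features = torch.cat((h_n[-2], h_n[-1]), dim=1)             # forward + backward states
        logits = self.fc(features).squeeze(1)                        # (batch,)
        return torch.sigmoid(logits)                                 # P(FAKE) for each sample


torch.manual_seed(0)
model = BiLSTMClassifier(vocab_size=10)
model.eval()
text = torch.tensor([[2, 3, 4, 0], [5, 6, 0, 0]])   # padded batch, 0 = <pad>
lengths = torch.tensor([3, 2])
probs = model(text, lengths)
print(probs.shape)   # torch.Size([2])
```

Because the batch is packed, the second sample should receive the same probability whether it is padded to length 4 or passed through the model on its own.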

Step 5: Training

Before training, we build save and load functions for checkpoints and metrics. For checkpoints, we save the model parameters and the optimizer state; for metrics, we save the training loss, validation loss, and global steps so the loss curves can easily be redrawn later.
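A minimal sketch of these helpers using torch.save and torch.load follows. The function names and dictionary keys are illustrative choices; storing the optimizer's state_dict alongside the model's allows training to resume where it stopped.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, valid_loss):
    """Store model weights, optimizer state, and the validation loss they achieved."""
    torch.save({"model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "valid_loss": valid_loss}, path)


def load_checkpoint(path, model, optimizer=None, device="cpu"):
    """Restore weights (and optionally optimizer state) in place; return the stored loss."""
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer_state_dict"])
    return state["valid_loss"]


def save_metrics(path, train_losses, valid_losses, global_steps):
    """Store the loss curves so the training/validation plot can be redrawn later."""
    torch.save({"train_loss_list": train_losses,
                "valid_loss_list": valid_losses,
                "global_steps_list": global_steps}, path)


def load_metrics(path, device="cpu"):
    state = torch.load(path, map_location=device)
    return state["train_loss_list"], state["valid_loss_list"], state["global_steps_list"]


tmp = tempfile.mkdtemp()
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
save_checkpoint(os.path.join(tmp, "model.pt"), model, optimizer, valid_loss=0.42)
save_metrics(os.path.join(tmp, "metrics.pt"), [0.7, 0.6], [0.72, 0.65], [100, 200])

restored = nn.Linear(4, 1)
restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
best_loss = load_checkpoint(os.path.join(tmp, "model.pt"), restored, restored_opt)
train_losses, valid_losses, steps = load_metrics(os.path.join(tmp, "metrics.pt"))
print(best_loss, steps)   # 0.42 [100, 200]
```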

We train the LSTM for 10 epochs and save the checkpoint and metrics whenever the model reaches a new best (lowest) validation loss. Here is the output during training:


The whole training process was fast on Google Colab. It took less than two minutes to train!
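A minimal sketch of such a training loop is below. It assumes the model returns probabilities (so nn.BCELoss applies) and that each batch is a (labels, text, lengths) tuple, and it evaluates on the validation set every eval_every steps (once per epoch by default). The evaluation interval, the tiny MeanPoolClassifier stand-in, and the toy batches are hypothetical and only make the snippet run quickly; in the tutorial's setting you would pass the bi-LSTM, the real loaders, and num_epochs=10.

```python
import copy

import torch
import torch.nn as nn


def train(model, optimizer, train_loader, valid_loader, num_epochs=10,
          eval_every=None, save_path=None):
    """Train with BCELoss and keep the weights with the lowest validation loss."""
    criterion = nn.BCELoss()
    eval_every = eval_every or len(train_loader)      # default: evaluate once per epoch
    best_valid_loss, best_state = float("inf"), None
    history = {"train_loss": [], "valid_loss": [], "global_step": []}
    running_loss, global_step = 0.0, 0

    for _ in range(num_epochs):
        for labels, text, lengths in train_loader:
            model.train()
            optimizer.zero_grad()
            loss = criterion(model(text, lengths), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            global_step += 1

            if global_step % eval_every == 0:
                model.eval()
                with torch.no_grad():
                    valid_loss = sum(criterion(model(t, n), y).item()
                                     for y, t, n in valid_loader) / len(valid_loader)
                history["train_loss"].append(running_loss / eval_every)
                history["valid_loss"].append(valid_loss)
                history["global_step"].append(global_step)
                running_loss = 0.0
                if valid_loss < best_valid_loss:          # new best: checkpoint it
                    best_valid_loss = valid_loss
                    best_state = copy.deepcopy(model.state_dict())
                    if save_path is not None:
                        torch.save({"model_state_dict": best_state,
                                    "optimizer_state_dict": optimizer.state_dict(),
                                    "valid_loss": valid_loss}, save_path)
    return best_state, best_valid_loss, history


class MeanPoolClassifier(nn.Module):
    """Hypothetical tiny stand-in with the same (text, lengths) -> probability interface."""
    def __init__(self, vocab_size, dim=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.fc = nn.Linear(dim, 1)

    def forward(self, text, lengths):
        summed = self.embedding(text).sum(dim=1)      # padding embeddings are zero vectors
        mean = summed / lengths.unsqueeze(1).float()
        return torch.sigmoid(self.fc(mean)).squeeze(1)


torch.manual_seed(0)
batch = (torch.tensor([1.0, 0.0]), torch.tensor([[3, 4, 0], [5, 6, 7]]), torch.tensor([2, 3]))
batches = [batch] * 4                                 # toy "loader" with four identical batches
model = MeanPoolClassifier(vocab_size=8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
best_state, best_loss, history = train(model, optimizer, batches, batches[:1], num_epochs=3)
print(history["global_step"], best_loss)
```

The returned history holds the same three lists that the metrics file stores, so it can be saved and plotted directly.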

Once training is finished, we can load the saved metrics and plot the training and validation losses over the course of training.
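A minimal plotting sketch with matplotlib is below; the function name, axis labels, and output file name are assumptions. The Agg backend writes the figure to a file without needing a display; inside a notebook you can drop that line and call plt.show() instead. The numbers in the usage line are dummy placeholders to exercise the function, not results from this tutorial.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")              # off-screen rendering; not needed inside a notebook
import matplotlib.pyplot as plt


def plot_losses(global_steps, train_losses, valid_losses, out_path="loss_curve.png"):
    """Draw training and validation loss against global steps and save the figure."""
    fig, ax = plt.subplots()
    ax.plot(global_steps, train_losses, label="Train")
    ax.plot(global_steps, valid_losses, label="Valid")
    ax.set_xlabel("Global steps")
    ax.set_ylabel("Loss")
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)
    return out_path


# Dummy values only; real values come from the saved metrics file.
path = plot_losses([10, 20, 30], [0.69, 0.55, 0.48], [0.70, 0.60, 0.58],
                   out_path=os.path.join(tempfile.mkdtemp(), "loss.png"))
print(path)
```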


Step 6: Evaluation

Finally, for evaluation, we load the best model saved earlier and evaluate it on our test dataset. We use a default threshold of 0.5: if the model output is greater than 0.5, we classify the news article as FAKE; otherwise, as REAL. We output the classification report with the precision, recall, and F1-score for each class, as well as the overall accuracy, and we also output the confusion matrix.
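The evaluation step can be sketched as follows with scikit-learn's classification_report and confusion_matrix. The helper name is hypothetical and the class order [1, 0] (FAKE first) is a presentation choice; the stub model and hand-picked probabilities exist only to show the thresholding and are not outputs of the trained bi-LSTM.

```python
import torch
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix


def evaluate(model, test_loader, threshold=0.5):
    """Predict FAKE (1) when the probability exceeds `threshold`, otherwise REAL (0)."""
    y_true, y_pred = [], []
    model.eval()
    with torch.no_grad():
        for labels, text, lengths in test_loader:
            probs = model(text, lengths)
            y_pred.extend((probs > threshold).int().tolist())
            y_true.extend(labels.int().tolist())
    report = classification_report(y_true, y_pred, labels=[1, 0],
                                   target_names=["FAKE", "REAL"], digits=4)
    cm = confusion_matrix(y_true, y_pred, labels=[1, 0])   # rows: true class, columns: predicted
    return report, cm


class StubModel(nn.Module):
    """Test stub: treats the first input column as an already-computed probability."""
    def forward(self, text, lengths):
        return text[:, 0]


batch = (torch.tensor([1.0, 0.0, 1.0, 0.0]),           # true labels
         torch.tensor([[0.9], [0.2], [0.4], [0.1]]),    # "probabilities" read by the stub
         torch.ones(4, dtype=torch.long))               # lengths (unused by the stub)
report, cm = evaluate(StubModel(), [batch])
print(report)
print(cm.tolist())   # [[1, 1], [0, 2]]
```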


We can see that with a one-layer bi-LSTM, we can achieve an accuracy of 77.53% on the fake news detection task.

Conclusion

This tutorial gave a step-by-step explanation of implementing your own LSTM model for text classification using PyTorch. We found that a bi-LSTM achieves acceptable accuracy for fake news detection but still has room for improvement. If you want more competitive performance, check out my previous article on BERT Text Classification!

If you want to learn more about modern NLP and deep learning, make sure to follow me for updates on upcoming articles :)

References

[1] S. Hochreiter, J. Schmidhuber, Long Short-Term Memory (1997), Neural Computation

