LSTM Text Classification Using Pytorch



A step-by-step guide teaching you how to build a bidirectional LSTM in Pytorch!

Photo by Christopher Gower on Unsplash

Intro

Welcome to this tutorial! This tutorial will teach you how to build a bidirectional LSTM for text classification in just a few minutes. If you haven’t already checked out my previous article on BERT Text Classification, this tutorial contains similar code, but with some modifications to support the LSTM. This article also explains how I preprocessed the dataset used in both articles, which is the REAL and FAKE News Dataset from Kaggle.

First of all, what is an LSTM and why do we use it? LSTM stands for Long Short-Term Memory, a type of neural network belonging to the larger category of Recurrent Neural Networks (RNNs). Its main advantage over the vanilla RNN is that it handles long-term dependencies better, thanks to a more sophisticated architecture that includes three different gates: the input gate, the output gate, and the forget gate. The three gates operate together to decide what information the LSTM cell remembers and what it forgets over arbitrary time intervals.

LSTM Cell

Now that we have a bit more understanding of LSTMs, let’s focus on how to implement one for text classification. The tutorial is divided into the following steps:

  1. Preprocess Dataset
  2. Importing Libraries
  3. Load Dataset
  4. Build Model
  5. Training
  6. Evaluation


Step 1: Preprocess Dataset

The raw dataset looks like the following:

Dataset Overview

The dataset contains an arbitrary index, title, text, and the corresponding label.

For preprocessing, we import Pandas and Sklearn and define some variables: the data path, the train/validation/test split ratios, and the trim_string function, which cuts each sample down to its first first_n_words words. Trimming the samples in a dataset is not strictly necessary, but it enables faster training for heavier models, and the first few hundred words are normally enough to predict the outcome.

Next, we convert REAL to 0 and FAKE to 1, concatenate title and text to form a new column titletext (we use both the title and the text to decide the outcome), drop rows with empty text, trim each sample to its first first_n_words words, and split the dataset according to train_test_ratio and train_valid_ratio. We save the resulting dataframes into .csv files, getting train.csv, valid.csv, and test.csv.
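Here is a minimal sketch of that preprocessing. The raw filename (news.csv), the output folder, the split ratios (0.10 / 0.80), and first_n_words = 200 are illustrative assumptions, not values prescribed above:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

raw_data_path = 'news.csv'   # assumed filename of the raw Kaggle CSV
train_test_ratio = 0.10      # fraction held out for the test set (illustrative)
train_valid_ratio = 0.80     # fraction of the remainder used for training (illustrative)
first_n_words = 200          # illustrative trim length

def trim_string(x):
    # Keep only the first `first_n_words` words of a sample
    return ' '.join(x.split(maxsplit=first_n_words)[:first_n_words])

df_raw = pd.read_csv(raw_data_path)

# REAL -> 0, FAKE -> 1; concatenate title and text into titletext
df_raw['label'] = (df_raw['label'] == 'FAKE').astype('int')
df_raw['titletext'] = df_raw['title'] + '. ' + df_raw['text']
df_raw = df_raw.reindex(columns=['label', 'title', 'text', 'titletext'])

# Drop rows with (near-)empty text, then trim every text column
df_raw.drop(df_raw[df_raw.text.str.len() < 5].index, inplace=True)
for col in ['title', 'text', 'titletext']:
    df_raw[col] = df_raw[col].apply(trim_string)

# First carve out the test set, then split the rest into train/valid
df_full_train, df_test = train_test_split(df_raw, test_size=train_test_ratio,
                                          random_state=1)
df_train, df_valid = train_test_split(df_full_train, train_size=train_valid_ratio,
                                      random_state=1)

df_train.to_csv('data/train.csv', index=False)  # output folder is an assumption
df_valid.to_csv('data/valid.csv', index=False)
df_test.to_csv('data/test.csv', index=False)
```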

Step 2: Importing Libraries

We import Pytorch for model construction, torchtext for loading data, matplotlib for plotting, and sklearn for evaluation.
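Concretely, the imports might look like the following. Note that Field, TabularDataset, and BucketIterator belong to the legacy torchtext API (torchtext <= 0.8; on torchtext 0.9–0.11 they live under torchtext.legacy.data instead):

```python
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Data loading (legacy torchtext API)
from torchtext.data import Field, TabularDataset, BucketIterator

# Evaluation
from sklearn.metrics import classification_report, confusion_matrix

# Train on GPU when available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```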

Step 3: Load Dataset

First, we use torchtext to create a label field for the label in our dataset and a text field for the title, text, and titletext columns. We then build a TabularDataset by pointing it to the path containing the train.csv, valid.csv, and test.csv dataset files. We create the train, valid, and test iterators that load the data, and finally build the vocabulary using the train iterator (counting only the tokens with a minimum frequency of 3).
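A sketch of that loading code, using the imports above. The data folder, the batch size of 32, and the field settings are assumptions; the min_freq=3 vocabulary cutoff follows the description:

```python
# Fields describe how each CSV column is processed
label_field = Field(sequential=False, use_vocab=False, batch_first=True,
                    dtype=torch.float)
text_field = Field(lower=True, include_lengths=True, batch_first=True)
fields = [('label', label_field), ('title', text_field),
          ('text', text_field), ('titletext', text_field)]

# TabularDataset maps train.csv / valid.csv / test.csv onto those fields
train_ds, valid_ds, test_ds = TabularDataset.splits(
    path='data', train='train.csv', validation='valid.csv',
    test='test.csv', format='CSV', fields=fields, skip_header=True)

# Bucket batches by text length so each batch needs minimal padding
train_iter, valid_iter, test_iter = BucketIterator.splits(
    (train_ds, valid_ds, test_ds), batch_size=32, device=device,
    sort_key=lambda x: len(x.text), sort=True, sort_within_batch=True)

# Build the vocabulary from the training split only,
# keeping tokens that appear at least 3 times
text_field.build_vocab(train_ds, min_freq=3)
```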

Step 4: Build Model

We construct the LSTM class that inherits from nn.Module. Inside, we build an Embedding layer, followed by a bi-LSTM layer, ending with a fully connected linear layer. In the forward function, we pass the token IDs through the embedding layer to get the embeddings, run them through the LSTM as packed sequences (so it accommodates variable-length inputs and learns from both directions), concatenate the final hidden states of the two directions, pass the result through the fully connected linear layer, and finally apply a sigmoid to get the probability of the sequence belonging to FAKE (being 1).
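A sketch of such a model is below. The embedding size (300), hidden dimension (128), and dropout rate (0.5) are illustrative hyperparameters, not values prescribed above:

```python
class LSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, dimension=128):
        super().__init__()
        self.dimension = dimension
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=dimension,
                            num_layers=1, batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(p=0.5)
        self.fc = nn.Linear(2 * dimension, 1)

    def forward(self, text, text_len):
        text_len = text_len.cpu()  # pack_padded_sequence wants lengths on the CPU
        text_emb = self.embedding(text)

        # Pack the padded batch so the LSTM skips the padding positions
        packed_input = pack_padded_sequence(text_emb, text_len,
                                            batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # Last state of the forward direction, first state of the backward direction
        out_forward = output[range(len(output)), text_len - 1, :self.dimension]
        out_reverse = output[:, 0, self.dimension:]
        out_reduced = torch.cat((out_forward, out_reverse), 1)

        text_fea = self.drop(out_reduced)
        text_out = torch.squeeze(self.fc(text_fea), 1)
        return torch.sigmoid(text_out)  # probability of FAKE (label 1)
```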

Step 5: Training

Before training, we build save and load functions for checkpoints and metrics. For checkpoints, the model parameters and optimizer state are saved; for metrics, the train loss, valid loss, and global steps are saved so that diagrams can easily be reconstructed later.
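Minimal sketches of these helpers; the file format and dictionary keys are my own choices:

```python
def save_checkpoint(save_path, model, optimizer, valid_loss):
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'valid_loss': valid_loss}, save_path)

def load_checkpoint(load_path, model, optimizer):
    state = torch.load(load_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['valid_loss']

def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    torch.save({'train_loss_list': train_loss_list,
                'valid_loss_list': valid_loss_list,
                'global_steps_list': global_steps_list}, save_path)

def load_metrics(load_path):
    state = torch.load(load_path, map_location=device)
    return (state['train_loss_list'], state['valid_loss_list'],
            state['global_steps_list'])
```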

We train the LSTM for 10 epochs and save the checkpoint and metrics whenever the validation loss reaches a new best (lowest) value.
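A compact sketch of such a training loop, assuming the iterators and helpers defined earlier; the Adam optimizer, learning rate of 0.001, and twice-per-epoch validation interval are assumptions:

```python
def train(model, optimizer, criterion=nn.BCELoss(), num_epochs=10,
          best_valid_loss=float('inf')):
    train_loss_list, valid_loss_list, global_steps_list = [], [], []
    running_loss, global_step = 0.0, 0
    eval_every = len(train_iter) // 2  # validate twice per epoch

    model.train()
    for epoch in range(num_epochs):
        for batch in train_iter:
            titletext, titletext_len = batch.titletext
            output = model(titletext, titletext_len)

            loss = criterion(output, batch.label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            global_step += 1

            if global_step % eval_every == 0:
                # Evaluate on the validation set
                model.eval()
                valid_running_loss = 0.0
                with torch.no_grad():
                    for valid_batch in valid_iter:
                        v_titletext, v_titletext_len = valid_batch.titletext
                        v_output = model(v_titletext, v_titletext_len)
                        valid_running_loss += criterion(v_output, valid_batch.label).item()

                avg_train_loss = running_loss / eval_every
                avg_valid_loss = valid_running_loss / len(valid_iter)
                train_loss_list.append(avg_train_loss)
                valid_loss_list.append(avg_valid_loss)
                global_steps_list.append(global_step)
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{global_step}], '
                      f'Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}')

                # Keep only the checkpoint with the lowest validation loss
                if avg_valid_loss < best_valid_loss:
                    best_valid_loss = avg_valid_loss
                    save_checkpoint('model.pt', model, optimizer, best_valid_loss)
                    save_metrics('metrics.pt', train_loss_list, valid_loss_list,
                                 global_steps_list)

                running_loss = 0.0
                model.train()

model = LSTM(vocab_size=len(text_field.vocab)).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, optimizer)
```

Here is the output during training: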

[Training log excerpt]

The whole training process was fast on Google Colab. It took less than two minutes to train!

Once training is finished, we can load the metrics previously saved and output a diagram showing the training loss and validation loss over time.
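For example, with the load_metrics helper and matplotlib:

```python
train_loss_list, valid_loss_list, global_steps_list = load_metrics('metrics.pt')
plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()
```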

[Training and validation loss curves]

Step 6: Evaluation

Finally, for evaluation, we pick the best model previously saved and evaluate it against our test dataset. We use a default threshold of 0.5 to decide when to classify a sample as FAKE: if the model output is greater than 0.5, we classify that news item as FAKE; otherwise, as REAL. We output the classification report indicating the precision, recall, and F1-score for each class, as well as the overall accuracy. We also output the confusion matrix.
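A sketch of this evaluation step, reusing the LSTM class and checkpoint helpers from the earlier steps:

```python
def evaluate(model, test_iter, threshold=0.5):
    y_pred, y_true = [], []
    model.eval()
    with torch.no_grad():
        for batch in test_iter:
            titletext, titletext_len = batch.titletext
            output = model(titletext, titletext_len)
            # Probability above the threshold means FAKE (label 1)
            y_pred.extend((output > threshold).int().tolist())
            y_true.extend(batch.label.int().tolist())

    print('Classification Report:')
    print(classification_report(y_true, y_pred, labels=[1, 0],
                                target_names=['FAKE', 'REAL'], digits=4))
    print('Confusion Matrix:')
    print(confusion_matrix(y_true, y_pred, labels=[1, 0]))

best_model = LSTM(vocab_size=len(text_field.vocab)).to(device)
best_optimizer = optim.Adam(best_model.parameters(), lr=0.001)
load_checkpoint('model.pt', best_model, best_optimizer)
evaluate(best_model, test_iter)
```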

[Classification report and confusion matrix]

We can see that with a one-layer bi-LSTM, we can achieve an accuracy of 77.53% on the fake news detection task.

Conclusion

This tutorial gives a step-by-step explanation of implementing your own LSTM model for text classification using Pytorch. We found that the bi-LSTM achieves an acceptable accuracy for fake news detection but still has room to improve. If you want a more competitive performance, check out my previous article on BERT Text Classification!

If you want to learn more about modern NLP and deep learning, make sure to follow me for updates on upcoming articles :)


