BERT Text Classification Using Pytorch

Photo by Clément H on Unsplash

Intro

Text classification is one of the most common tasks in NLP. It is applied in a wide variety of applications, including sentiment analysis, spam filtering, news categorization, etc. Here, we show you how you can detect fake news (classifying an article as REAL or FAKE) using state-of-the-art models, in a tutorial that can be extended to virtually any text classification task.

The Transformer is the basic building block of most current state-of-the-art NLP architectures. Its primary advantage is its multi-head attention mechanism, which allows for an increase in performance and significantly more parallelization than previous competing models such as recurrent neural networks. In this tutorial, we will use pre-trained BERT, one of the most popular transformer models, and fine-tune it on fake news detection.

The main source code of this article is available in this Google Colab Notebook.

The preprocessing code is also available in this Google Colab Notebook.

Getting Started

Hugging Face's Transformers is the most well-known library for working with state-of-the-art transformer models in Python. It offers clear documentation and tutorials on implementing dozens of different transformers for a wide variety of tasks. We will be using PyTorch, so make sure PyTorch is installed. After ensuring the relevant libraries are installed, you can install the transformers library by:

pip install transformers

For the dataset, we will be using the REAL and FAKE News Dataset from Kaggle.

Step 1: Importing Libraries

The most important thing to note here is that we import BertTokenizer and BertForSequenceClassification from the transformers library, which we will use to construct the tokenizer and model later on.
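
The exact import list lives in the linked notebook; a minimal sketch of what this step needs looks roughly like this (it assumes PyTorch, the Transformers library, and an older TorchText release where Field and TabularDataset are still available):

```python
import torch
import torch.nn as nn
from torch.optim import Adam

# TorchText < 0.9; in later releases these classes moved to torchtext.legacy.data
from torchtext.data import Field, TabularDataset, BucketIterator, Iterator

from transformers import BertTokenizer, BertForSequenceClassification

# Run on GPU when available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```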

Step 2: Preprocess and Prepare Dataset

Starting from the original dataset, we add an additional titletext column, which is the concatenation of the title and the text, because we want to decide whether an article is fake using both its title and its body.
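
A possible way to build that column with pandas is sketched below; the file names, split sizes, and the REAL/FAKE label encoding are illustrative assumptions rather than the exact preprocessing from the linked notebook:

```python
import pandas as pd

# Hypothetical preprocessing: combine title and text, encode the label numerically
df = pd.read_csv('news.csv')                        # assumed name of the Kaggle CSV
df['titletext'] = df['title'] + '. ' + df['text']   # concatenation of title and text
df['label'] = (df['label'] == 'FAKE').astype(int)   # assumed encoding: REAL -> 0, FAKE -> 1

# Shuffle and split into train/validation/test (split sizes are arbitrary)
df = df.sample(frac=1, random_state=42)
train_df, valid_df, test_df = df[:4000], df[4000:5000], df[5000:]
train_df[['label', 'titletext']].to_csv('data/train.csv', index=False)
valid_df[['label', 'titletext']].to_csv('data/valid.csv', index=False)
test_df[['label', 'titletext']].to_csv('data/test.csv', index=False)
```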

For the tokenizer, we use the “bert-base-uncased” version of BertTokenizer. Using TorchText, we first create the text field and the label field: the text field will contain the news articles, and the label field holds the true target. We limit each article to the first 128 tokens for BERT input. Then, we create a TabularDataset from our CSV files using the two fields to produce the train, validation, and test sets, and finally we create iterators to prepare them in batches.

Note: In order to use the BERT tokenizer with TorchText, we have to set use_vocab=False and tokenize=tokenizer.encode. This lets TorchText know that we will not be building our own vocabulary from scratch using our dataset, but will instead use the pre-trained BERT tokenizer and its corresponding word-to-index mapping.
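
Putting the tokenizer, fields, datasets, and iterators together could look roughly like the sketch below; the column names, batch size, and file paths continue the assumptions from the preprocessing sketch above:

```python
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

MAX_SEQ_LEN = 128
PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)

# Label field: plain numeric targets, no vocabulary
label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
# Text field: tokenize with BERT's tokenizer, pad/truncate to 128 token ids
text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False,
                   include_lengths=False, batch_first=True, fix_length=MAX_SEQ_LEN,
                   pad_token=PAD_INDEX, unk_token=UNK_INDEX)
fields = [('label', label_field), ('titletext', text_field)]

# Build datasets from the CSV files written during preprocessing
train, valid, test = TabularDataset.splits(
    path='data', train='train.csv', validation='valid.csv', test='test.csv',
    format='CSV', fields=fields, skip_header=True)

# Iterators that batch (and, for training, bucket by length) the examples
train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.titletext),
                            device=device, train=True, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.titletext),
                            device=device, train=True, sort=True, sort_within_batch=True)
test_iter = Iterator(test, batch_size=16, device=device, train=False,
                     shuffle=False, sort=False)
```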

Step 3: Build Model

We are using the “bert-base-uncased” version of BERT, which is the smaller model trained on lower-cased English text (12 layers, 768 hidden units, 12 attention heads, 110M parameters). Check out Hugging Face's documentation for other versions of BERT or other transformer models.
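
A lightweight wrapper around the pre-trained classifier might look like the sketch below. Treating the task with a single output logit (num_labels=1) plus an external sigmoid and binary cross-entropy loss, as described in Step 4, is an assumption; BertForSequenceClassification equally supports a two-logit head with its built-in cross-entropy loss.

```python
class FakeNewsClassifier(nn.Module):
    """Thin wrapper that returns one logit per article."""

    def __init__(self):
        super().__init__()
        # num_labels=1: the classification head emits a single score per article
        self.encoder = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', num_labels=1)

    def forward(self, input_ids):
        # Recent transformers versions expose .logits on the returned output object
        logits = self.encoder(input_ids).logits   # shape: (batch_size, 1)
        return logits.squeeze(-1)                 # shape: (batch_size,)


model = FakeNewsClassifier().to(device)
```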

Step 4: Training

We write save and load functions for model checkpoints and for training metrics. Note that the save function for the model checkpoint does not save the optimizer: the optimizer state normally takes up a lot of storage space, and we assume that no training from a previous checkpoint is needed. The training metrics store the training loss, validation loss, and global steps so that visualizations of the training process can be made later.
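
These helpers might look roughly like this (the function names, file layout, and dictionary keys are assumptions):

```python
def save_checkpoint(path, model, valid_loss):
    # Model weights only; the optimizer state is deliberately not saved
    torch.save({'model_state_dict': model.state_dict(), 'valid_loss': valid_loss}, path)


def load_checkpoint(path, model):
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    return state['valid_loss']


def save_metrics(path, train_loss_list, valid_loss_list, global_steps_list):
    torch.save({'train_loss_list': train_loss_list,
                'valid_loss_list': valid_loss_list,
                'global_steps_list': global_steps_list}, path)


def load_metrics(path):
    state = torch.load(path, map_location=device)
    return state['train_loss_list'], state['valid_loss_list'], state['global_steps_list']
```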

We use the Adam optimizer and a suitable learning rate to fine-tune BERT for 5 epochs.

We use binary cross-entropy as the loss function, since fake news detection is a two-class problem. Make sure the model output is passed through a sigmoid before the loss is computed between the prediction and the target.

During training, we evaluate the model against the validation set. We save the model each time the validation loss decreases, so that we end up with the model with the lowest validation loss, which can be considered the best model. A condensed sketch of this training loop is shown below, followed by the outputs during training.
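
The sketch uses the sigmoid + binary cross-entropy setup and the checkpoint/metrics helpers from above; the learning rate, per-epoch validation schedule, and batch field names are assumptions:

```python
def train_loop(model, train_iter, valid_iter, num_epochs=5, lr=2e-5):
    optimizer = Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()
    best_valid_loss = float('inf')
    train_loss_list, valid_loss_list, global_steps_list = [], [], []
    global_step = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch in train_iter:
            logits = model(batch.titletext)
            loss = criterion(torch.sigmoid(logits), batch.label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            global_step += 1

        # Evaluate on the validation set
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for batch in valid_iter:
                logits = model(batch.titletext)
                valid_loss += criterion(torch.sigmoid(logits), batch.label).item()
        valid_loss /= len(valid_iter)

        # Record metrics for later visualization
        train_loss_list.append(running_loss / len(train_iter))
        valid_loss_list.append(valid_loss)
        global_steps_list.append(global_step)

        # Keep only the checkpoint with the lowest validation loss so far
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            save_checkpoint('model.pt', model, valid_loss)

        print(f'Epoch {epoch + 1}/{num_epochs}: '
              f'train loss {train_loss_list[-1]:.4f}, valid loss {valid_loss:.4f}')

    save_metrics('metrics.pt', train_loss_list, valid_loss_list, global_steps_list)


train_loop(model, train_iter, valid_iter)
```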

[Figure: training output log (image by author)]

After training, we can plot a diagram of the training and validation loss using the code below:
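
For example, something along these lines, which reads back the metrics file written by the training loop sketched above:

```python
import matplotlib.pyplot as plt

# Plot the loss curves stored by save_metrics during training
train_loss_list, valid_loss_list, global_steps_list = load_metrics('metrics.pt')

plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
```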

[Figure: training and validation loss curves (image by author)]

Step 5: Evaluation

For evaluation, we predict the articles using our trained model and compare the predictions against the true labels. We print out the classification report, which includes test accuracy, precision, recall, and F1-score. We also print out the confusion matrix to see how many examples our model predicts correctly and incorrectly for each class.
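
A sketch of this evaluation step using scikit-learn's metrics is shown below; the 0.5 decision threshold and the REAL/FAKE class ordering follow the label encoding assumed earlier:

```python
from sklearn.metrics import classification_report, confusion_matrix


def evaluate(model, test_iter):
    y_pred, y_true = [], []
    model.eval()
    with torch.no_grad():
        for batch in test_iter:
            probs = torch.sigmoid(model(batch.titletext))
            y_pred.extend((probs > 0.5).long().tolist())
            y_true.extend(batch.label.long().tolist())

    print(classification_report(y_true, y_pred, target_names=['REAL', 'FAKE'], digits=4))
    print(confusion_matrix(y_true, y_pred))


load_checkpoint('model.pt', model)   # restore the best checkpoint from Step 4
evaluate(model, test_iter)
```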

[Figure: classification report and confusion matrix (image by author)]

After evaluation, we find that our model achieves an impressive accuracy of 96.99%!

Conclusion

We find that fine-tuning BERT performs extremely well on our dataset and is really simple to implement thanks to the open-source Hugging Face Transformers library. This approach can be extended to any text classification dataset without any hassle.

References

[1] A. Vaswani, N. Shazeer, N. Parmar, et al., Attention Is All You Need (2017), 31st Conference on Neural Information Processing Systems

[2] J. Devlin, M. Chang, K. Lee and K. Toutanova, BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (2019), 2019 Annual Conference of the North American Chapter of the Association for Computational Linguistics

