Chinese text classification based on PyTorch, available out of the box.
Neural network models: TextCNN, TextRNN, FastText, TextRCNN, BiLSTM_Attention, DPCNN, Transformer
Pre-trained models: BERT, ERNIE
For model introductions and the data flow through each model, see the references below.
Data is fed into the model character by character; the pre-trained vectors are Sogou News Word+Character 300d (click here to download).
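As an illustration, here is a minimal sketch of loading such pre-trained vectors into an embedding layer. The file name `embedding_SougouNews.npz` and the array key `embeddings` are assumptions about how the downloaded vectors are stored; adjust them to match the actual file.

```python
# Minimal sketch: load pre-trained 300-d vectors into an nn.Embedding layer.
# File name and key below are assumptions, not taken from the repository.
import numpy as np
import torch
import torch.nn as nn

vectors = np.load("embedding_SougouNews.npz")["embeddings"].astype("float32")  # (vocab_size, 300)
embedding = nn.Embedding.from_pretrained(torch.tensor(vectors), freeze=False)

ids = torch.LongTensor([[5, 12, 7]])   # a batch of character ids, shape (1, seq_len)
print(embedding(ids).shape)            # -> torch.Size([1, 3, 300])
```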
| Model | Description |
|---|---|
| TextCNN | Kim (2014) classic CNN for text classification (see the sketch after this table) |
| TextRNN | BiLSTM |
| TextRNN_Att | BiLSTM+Attention |
| TextRCNN | BiLSTM+pooling |
| FastText | bow+bigram+trigram, the effect is surprisingly good |
| DPCNN | Deep Pyramid CNN |
| Transformer | Poor results |
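For reference, a minimal sketch of the TextCNN listed above (Kim 2014): parallel convolutions with several kernel sizes over the embedded sequence, max-pooled per filter, concatenated, and fed to a linear classifier. The vocabulary size, dimensions, kernel sizes, and class count are illustrative, not the repository's actual hyperparameters.

```python
# Minimal TextCNN sketch (Kim 2014); sizes below are illustrative only.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    def __init__(self, vocab_size=5000, embed_dim=300, num_filters=256,
                 kernel_sizes=(2, 3, 4), num_classes=10, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, num_filters, (k, embed_dim)) for k in kernel_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)

    def forward(self, x):                       # x: (batch, seq_len) of token ids
        out = self.embedding(x).unsqueeze(1)    # (batch, 1, seq_len, embed_dim)
        pooled = []
        for conv in self.convs:
            h = F.relu(conv(out)).squeeze(3)                      # (batch, num_filters, L)
            pooled.append(F.max_pool1d(h, h.size(2)).squeeze(2))  # (batch, num_filters)
        out = torch.cat(pooled, dim=1)
        return self.fc(self.dropout(out))       # (batch, num_classes)

logits = TextCNN()(torch.randint(0, 5000, (8, 32)))   # -> shape (8, 10)
```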
| Model | Description | Remark |
|---|---|---|
| bert | Original bert | |
| ERNIE | ERNIE | |
| bert_CNN | BERT as the embedding layer, followed by a CNN with three convolution kernel sizes. | bert + CNN |
| bert_RNN | BERT as the embedding layer, followed by an LSTM. | bert + RNN |
| bert_RCNN | BERT as the embedding layer; the LSTM output is concatenated with the BERT output and passed through a max-pooling layer (see the sketch after this table). | bert + RCNN |
| bert_DPCNN | BERT as the embedding layer, followed by a region embedding layer containing three different convolutional feature extractors (which can be regarded as producing the output embedding), then two layers of equal-length convolutions that widen the receptive field for the subsequent feature extraction (enriching the embedding), and then repeated residual blocks with 1/2 pooling. The 1/2 pooling progressively raises the semantic level of word positions while the number of feature maps stays fixed; the residual connections are introduced to counter vanishing and exploding gradients during training. | bert + DPCNN |
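As a concrete illustration of how BERT is used as the embedding layer, here is a rough sketch of the bert_RCNN idea from the table: the BERT sequence output runs through a BiLSTM, the LSTM output is concatenated with the BERT output, and max-pooling over the sequence feeds the classifier. The directory ./bert_pretain follows the layout described later; the hidden sizes and class count are illustrative, and the repository's actual implementation may differ in details.

```python
# Rough bert_RCNN sketch: BERT output -> BiLSTM -> concat with BERT output -> max pool -> fc.
# Sizes and the pre-trained directory path are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel

class BertRCNN(nn.Module):
    def __init__(self, bert_dir="./bert_pretain", rnn_hidden=256, num_classes=10):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_dir)
        self.lstm = nn.LSTM(768, rnn_hidden, num_layers=1,
                            bidirectional=True, batch_first=True)
        self.fc = nn.Linear(768 + 2 * rnn_hidden, num_classes)

    def forward(self, input_ids, attention_mask):
        seq_out, _ = self.bert(input_ids, attention_mask=attention_mask,
                               output_all_encoded_layers=False)   # (batch, seq, 768)
        rnn_out, _ = self.lstm(seq_out)                            # (batch, seq, 2*rnn_hidden)
        out = torch.cat((seq_out, rnn_out), dim=2)                 # (batch, seq, 768 + 2*rnn_hidden)
        out = F.relu(out).permute(0, 2, 1)                         # (batch, features, seq)
        out = F.max_pool1d(out, out.size(2)).squeeze(2)            # (batch, features)
        return self.fc(out)                                        # (batch, num_classes)
```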
Environment:
python 3.7
pytorch 1.1
tqdm
sklearn
tensorboardX
pytorch_pretrained_bert (the pre-training code has also been uploaded, so this library is no longer required)
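One way to install the dependencies above (a suggested command, not taken from the repository; sklearn is installed via the scikit-learn package, and the exact torch 1.1 wheel depends on platform and CUDA version):
pip install torch==1.1.0 tqdm scikit-learn tensorboardX pytorch_pretrained_bert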
200,000 news headlines were sampled from THUCNews and uploaded to GitHub; text lengths are between 20 and 30 characters. There are 10 categories with 20,000 samples each. Data is fed into the model character by character.
Categories: Finance, Real Estate, Stocks, Education, Science and Technology, Society, Current Affairs, Sports, Games, Entertainment.
Dataset division (a split sketch follows the table):
| Dataset | Data volume |
|---|---|
| Training set | 180,000 |
| Validation set | 10,000 |
| Test set | 10,000 |
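For illustration, here is a sketch of producing the 180,000 / 10,000 / 10,000 split above, assuming one sample per line in the form "text<TAB>label"; the input and output file names are illustrative, not necessarily those used by the repository.

```python
# Sketch: stratified 18k/1k/1k split per category (assumed "text<TAB>label" lines).
import random
from collections import defaultdict

random.seed(1)
by_label = defaultdict(list)
with open("THUCNews/all.txt", encoding="utf-8") as f:     # illustrative input file
    for line in f:
        text, label = line.rstrip("\n").rsplit("\t", 1)
        by_label[label].append(line)

train, dev, test = [], [], []
for label, samples in by_label.items():        # 20,000 samples per category
    random.shuffle(samples)
    train += samples[:18000]                    # 18,000 x 10 = 180,000
    dev   += samples[18000:19000]               #  1,000 x 10 =  10,000
    test  += samples[19000:20000]               #  1,000 x 10 =  10,000

for name, part in [("train.txt", train), ("dev.txt", dev), ("test.txt", test)]:
    with open(name, "w", encoding="utf-8") as out:
        out.writelines(part)
```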
# To train with word-level input instead of the default character-level input:
python run.py --model TextCNN --word True
Machine: one 2080Ti GPU; training time: 30 minutes.
| Model | acc | Remark |
|---|---|---|
| TextCNN | 91.22% | Kim 2014 Classic CNN Text Classification |
| TextRNN | 91.12% | BiLSTM |
| TextRNN_Att | 90.90% | BiLSTM+Attention |
| TextRCNN | 91.54% | BiLSTM+pooling |
| FastText | 92.23% | bow+bigram+trigram, the effect is surprisingly good |
| DPCNN | 91.25% | Deep Pyramid CNN |
| Transformer | 89.91% | Poor results |
| bert | 94.83% | Plain BERT |
| ERNIE | 94.61% | So much for the promised "crushes BERT on Chinese" |
| bert_CNN | 94.44% | bert + CNN |
| bert_RNN | 94.57% | bert + RNN |
| bert_RCNN | 94.51% | bert + RCNN |
| bert_DPCNN | 94.47% | bert + DPCNN |
The original BERT alone already performs very well; using BERT only as an embedding layer feeding other models actually lowers accuracy. A comparison on long texts will be tried later.
The BERT model goes in the bert_pretain directory and the ERNIE model in the ERNIE_pretrain directory; each directory contains three files:
Pre-trained model download address:
bert_Chinese: Model https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
Vocabulary https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt
From here
Alternative (Baidu netdisk mirror of the model): https://pan.baidu.com/s/1qSAD5gwClq7xlgzl_4W3Pw
ERNIE_Chinese: https://pan.baidu.com/s/1lEPdDN1-YQJmKEd_g9rLgw
From here
After decompressing, place the files in the corresponding directories described above and confirm that the file names are correct.
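As a quick sanity check (a sketch, not part of the repository), the downloaded files can be loaded with pytorch_pretrained_bert; the directory names follow the layout described above.

```python
# Sanity-check sketch: load the downloaded weights and vocabulary.
from pytorch_pretrained_bert import BertModel, BertTokenizer

for path in ("./bert_pretain", "./ERNIE_pretrain"):
    tokenizer = BertTokenizer.from_pretrained(path)   # expects vocab.txt in the directory
    model = BertModel.from_pretrained(path)           # expects bert_config.json and pytorch_model.bin
    print(path, "loaded; vocab size:", len(tokenizer.vocab))
```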
# Train and test:
# TextCNN
python run.py --model TextCNN
# TextRNN
python run.py --model TextRNN
# TextRNN_Att
python run.py --model TextRNN_Att
# TextRCNN
python run.py --model TextRCNN
# FastText; the embedding layer is randomly initialized
python run.py --model FastText --embedding random
# DPCNN
python run.py --model DPCNN
# Transformer
python run.py --model Transformer
Download the pre-trained model and run:
# Train and test with a pre-trained model:
# bert
python pretrain_run.py --model bert
# bert + others
python pretrain_run.py --model bert_CNN
# ERNIE
python pretrain_run.py --model ERNIE
Predict with a pre-trained model:
python pretrain_predict.py
Predict with a neural network model:
python predict.py
All models are in the models directory; each model's hyperparameters are defined in the same file as the model itself.
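To illustrate that layout, here is a minimal sketch of how a file under models/ might keep its hyperparameters next to the model definition; the class names, fields, and values are illustrative only, not copied from the repository.

```python
# Illustrative pattern: hyperparameters (Config) and the model live in one file.
import torch.nn as nn

class Config:
    """Hyperparameters for this model, defined alongside the model."""
    def __init__(self):
        self.num_classes = 10
        self.embed_dim = 300
        self.hidden_size = 128
        self.learning_rate = 1e-3
        self.num_epochs = 20

class Model(nn.Module):
    def __init__(self, config: Config, vocab_size: int = 5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, config.embed_dim)
        self.lstm = nn.LSTM(config.embed_dim, config.hidden_size,
                            bidirectional=True, batch_first=True)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):                        # x: (batch, seq_len) of token ids
        out, _ = self.lstm(self.embedding(x))
        return self.fc(out[:, -1, :])            # last time step -> class logits
```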
[1] Convolutional Neural Networks for Sentence Classification
[2] Recurrent Neural Network for Text Classification with Multi-Task Learning
[3] Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification
[4] Recurrent Convolutional Neural Networks for Text Classification
[5] Bag of Tricks for Efficient Text Classification
[6] Deep Pyramid Convolutional Neural Networks for Text Categorization
[7] Attention Is All You Need
[8] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
[9] ERNIE: Enhanced Representation through Knowledge Integration
This project is continuously developed and optimized on the basis of the following repositories: