中文文本分類,基於pytorch,開箱即用。
神經網絡模型:TextCNN,TextRNN,FastText,TextRCNN,BiLSTM_Attention, DPCNN, Transformer
預訓練模型:Bert,ERNIE
模型介紹、數據流動過程:參考
數據以字為單位輸入模型,預訓練詞向量使用搜狗新聞Word+Character 300d,點這裡下載
| 模型 | 介紹 |
|---|---|
| TextCNN | Kim 2014 經典的CNN文本分類 |
| TextRNN | BiLSTM |
| TextRNN_Att | BiLSTM+Attention |
| TextRCNN | BiLSTM+池化 |
| FastText | bow+bigram+trigram, 效果出奇的好 |
| DPCNN | 深層金字塔CNN |
| Transformer | 效果較差 |
| 模型 | 介紹 | 備註 |
|---|---|---|
| bert | 原始的bert | |
| ERNIE | ERNIE | |
| bert_CNN | bert作為Embedding層,接入三種卷積核的CNN | bert + CNN |
| bert_RNN | bert作為Embedding層,接入LSTM | bert + RNN |
| bert_RCNN | bert作為Embedding層,通過LSTM與bert輸出拼接,經過一層最大池化層 | bert + RCNN |
| bert_DPCNN | bert作為Embedding層,經過一個包含三個不同卷積特徵提取器的region embedding層,可以看作輸出的是embedding,然後經過兩層的等長卷積來為接下來的特徵抽取提供更寬的感受眼,(提高embdding的豐富性),然後會重複通過一個1/2池化的殘差塊,1/2池化不斷提高詞位的語義,其中固定了feature_maps,殘差網絡的引入是為了解決在訓練的過程中梯度消失和梯度爆炸的問題。 | bert + DPCNN |
參考:
python 3.7
pytorch 1.1
tqdm
sklearn
tensorboardX pytorch_pretrained_bert (預訓練代碼也上傳了, 不需要這個庫了)
我從THUCNews中抽取了20萬條新聞標題,已上傳至github,文本長度在20到30之間。一共10個類別,每類2萬條。數據以字為單位輸入模型。
類別:財經、房產、股票、教育、科技、社會、時政、體育、遊戲、娛樂。
數據集劃分:
| 數據集 | 數據量 |
|---|---|
| 訓練集 | 18萬 |
| 驗證集 | 1萬 |
| 測試集 | 1萬 |
python run.py --model TextCNN --word True機器:一塊2080Ti , 訓練時間:30分鐘。
| 模型 | acc | 備註 |
|---|---|---|
| TextCNN | 91.22% | Kim 2014 經典的CNN文本分類 |
| TextRNN | 91.12% | BiLSTM |
| TextRNN_Att | 90.90% | BiLSTM+Attention |
| TextRCNN | 91.54% | BiLSTM+池化 |
| FastText | 92.23% | bow+bigram+trigram, 效果出奇的好 |
| DPCNN | 91.25% | 深層金字塔CNN |
| Transformer | 89.91% | 效果較差 |
| bert | 94.83% | 單純的bert |
| ERNIE | 94.61% | 說好的中文碾壓bert呢 |
| bert_CNN | 94.44% | bert + CNN |
| bert_RNN | 94.57% | bert + RNN |
| bert_RCNN | 94.51% | bert + RCNN |
| bert_DPCNN | 94.47% | bert + DPCNN |
原始的bert效果就很好了,把bert當作embedding層送入其它模型,效果反而降了,之後會嘗試長文本的效果對比。
bert模型放在bert_pretain目錄下,ERNIE模型放在ERNIE_pretrain目錄下,每個目錄下都是三個文件:
預訓練模型下載地址:
bert_Chinese: 模型https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
詞表https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt
來自這裡
備用:模型的網盤地址:https://pan.baidu.com/s/1qSAD5gwClq7xlgzl_4W3Pw
ERNIE_Chinese: https://pan.baidu.com/s/1lEPdDN1-YQJmKEd_g9rLgw
來自這裡
解壓後,按照上面說的放在對應目錄下,文件名稱確認無誤即可。
# 训练并测试:
# TextCNN
python run.py --model TextCNN
# TextRNN
python run.py --model TextRNN
# TextRNN_Att
python run.py --model TextRNN_Att
# TextRCNN
python run.py --model TextRCNN
# FastText, embedding层是随机初始化的
python run.py --model FastText --embedding random
# DPCNN
python run.py --model DPCNN
# Transformer
python run.py --model Transformer
下載好預訓練模型就可以跑了:
# 预训练模型训练并测试:
# bert
python pretrain_run.py --model bert
# bert + 其它
python pretrain_run.py --model bert_CNN
# ERNIE
python pretrain_run.py --model ERNIE
預訓練模型:
python pretrain_predict.py
神經網絡模型:
python predict.py
模型都在models目錄下,超參定義和模型定義在同一文件中。
[1] Convolutional Neural Networks for Sentence Classification
[2] Recurrent Neural Network for Text Classification with Multi-Task Learning
[3] Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification
[4] Recurrent Convolutional Neural Networks for Text Classification
[5] Bag of Tricks for Efficient Text Classification
[6] Deep Pyramid Convolutional Neural Networks for Text Categorization
[7] Attention Is All You Need
[8] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
[9] ERNIE: Enhanced Representation through Knowledge Integration
本項目基於以下倉庫繼續開發優化: