Pytorchに基づいた中国のテキスト分類は、箱から出して利用できます。
ニューラルネットワークモデル:TextCnn、Textrnn、FastText、Textrcnn、bilstm_attention、dpcnn、トランス
事前訓練を受けたモデル:バート、アーニー
モデルの紹介、データフロープロセス:参照
データは単語の単位でモデルに入力され、事前に訓練された単語ベクトルはSogou News Word+Character300Dを使用しています。ここをクリックしてダウンロードしてください
| モデル | 導入 |
|---|---|
| textcnn | キム2014クラシックCNNテキスト分類 |
| Textrnn | bilstm |
| textrnn_att | bilstm+注意 |
| textrcnn | bilstm+プーリング |
| fastText | Bow+Bigram+Trigram、その効果は驚くほど良いです |
| dpcnn | 深いピラミッドCNN |
| トランス | 結果が悪い |
| モデル | 導入 | 述べる |
|---|---|---|
| バート | オリジナルのバート | |
| アーニー | アーニー | |
| bert_cnn | 埋め込み層として、Bertは3つの畳み込みカーネルのCNNに接続します。 | bert + cnn |
| bert_rnn | 埋め込み層としてのバート、LSTMへのアクセス | bert + rnn |
| bert_rcnn | 埋め込み層として、BertはLSTMを介してBert出力でスプライスされ、最大プーリング層を通過します。 | bert + rcnn |
| bert_dpcnn | 埋め込み層として、BERTは、出力埋め込み剤と見なすことができる3つの異なる畳み込み特徴抽出器を含む領域埋め込み層を通過し、次に2つの層の等しい長さの畳み込みを通過して、その後の特徴抽出に広い感覚目を提供します(胚の豊富さを改善します)。 1/2プーリングは、単語の位置のセマンティクスを継続的に改善し、feature_mapsが修正されます。残留ネットワークの導入は、トレーニングプロセス中の勾配消失と勾配爆発の問題を解決することです。 | bert + dpcnn |
参照:
Python 3.7
Pytorch 1.1
TQDM
Sklearn
tensorboardx pytorch_pretrained_bert (トレーニング前コードもアップロードされており、このライブラリは必要ありません)
Thucnewsから200,000のニュースタイトルを描きました。Githubにアップロードされ、テキストの長さは20〜30です。合計10のカテゴリがあり、各カテゴリには20,000のアイテムがあります。データは言葉でモデルに入力されます。
カテゴリ:財務、不動産、株式、教育、科学技術、社会、時事問題、スポーツ、ゲーム、エンターテイメント。
データセット部門:
| データセット | データボリューム |
|---|---|
| トレーニングセット | 180,000 |
| 検証セット | 10,000 |
| テストセット | 10,000 |
python run.py --model TextCNN --word Trueマシン:2080tiの1つ、トレーニング時間:30分。
| モデル | acc | 述べる |
|---|---|---|
| textcnn | 91.22% | キム2014クラシックCNNテキスト分類 |
| Textrnn | 91.12% | bilstm |
| textrnn_att | 90.90% | bilstm+注意 |
| textrcnn | 91.54% | bilstm+プーリング |
| fastText | 92.23% | Bow+Bigram+Trigram、その効果は驚くほど良いです |
| dpcnn | 91.25% | 深いピラミッドCNN |
| トランス | 89.91% | 結果が悪い |
| バート | 94.83% | シンプルなバート |
| アーニー | 94.61% | 約束された中国の粉砕バートは何ですか |
| bert_cnn | 94.44% | bert + cnn |
| bert_rnn | 94.57% | bert + rnn |
| bert_rcnn | 94.51% | bert + rcnn |
| bert_dpcnn | 94.47% | bert + dpcnn |
元のバート効果はとても良いです。 Bertを埋め込み層として使用し、他のモデルに送信すると、効果が低下します。後で、長いテキストの効果を比較しようとします。
BERTモデルはbert_pretainディレクトリに配置され、ernieモデルはernie_pretrainディレクトリに配置されます。各ディレクトリには3つのファイルがあります。
事前に訓練されたモデルのダウンロードアドレス:
bert_chinese:モデルhttps://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
語彙https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt
ここから
代替:モデルのネットワークディスクアドレス:https://pan.baidu.com/s/1qsad5gwclq7xlgzl_4w3pw
ernie_chinese:https://pan.baidu.com/s/1lepddn1-yqjmked_g9rlgw
ここから
減圧後、上記のように対応するディレクトリに入れ、ファイル名が正しいことを確認します。
# 训练并测试:
# TextCNN
python run.py --model TextCNN
# TextRNN
python run.py --model TextRNN
# TextRNN_Att
python run.py --model TextRNN_Att
# TextRCNN
python run.py --model TextRCNN
# FastText, embedding层是随机初始化的
python run.py --model FastText --embedding random
# DPCNN
python run.py --model DPCNN
# Transformer
python run.py --model Transformer
事前に訓練されたモデルをダウンロードして実行します。
# 预训练模型训练并测试:
# bert
python pretrain_run.py --model bert
# bert + 其它
python pretrain_run.py --model bert_CNN
# ERNIE
python pretrain_run.py --model ERNIE
事前に訓練されたモデル:
python pretrain_predict.py
ニューラルネットワークモデル:
python predict.py
モデルはすべてモデルディレクトリにあり、ハイパーパラメーター定義とモデル定義は同じファイルにあります。
[1]文化分類のための畳み込みニューラルネットワーク
[2]マルチタスク学習によるテキスト分類のための再発ニューラルネットワーク
[3]関係分類のための注意ベースの双方向短期メモリネットワーク
[4]テキスト分類のための再発畳み込みニューラルネットワーク
[5]効率的なテキスト分類のためのトリックの袋
[6]テキスト分類のための深いピラミッド畳み込みニューラルネットワーク
[7]注意が必要です
[8]バート:言語理解のための深い双方向変圧器の事前訓練
[9]アーニー:知識統合による表現の強化
このプロジェクトは、次の倉庫に基づいて開発および最適化を続けています。