cnn text classification pytorch
1.0.0
Esta es la implementación de las redes neuronales convolucionales de Kim para el documento de clasificación de oraciones en Pytorch.
Acabo de probar dos conjuntos de datos, Mr y SST.
| Conjunto de datos | Tamaño de clase | Mejor resultado | Resultado del papel de Kim |
|---|---|---|---|
| SEÑOR | 2 | 77.5%(CNN-Rand-Static) | 76.1%(CNN-rand-ostatic) |
| SST | 5 | 37.2%(CNN-Rand-Static) | 45.0%(CNN-rand-nostatic) |
No he ajustado los hiperparametros para SST en serio.
./main.py -h
o
python3 main.py -h
Obtendrás:
CNN text classificer
optional arguments:
-h, --help show this help message and exit
-batch-size N batch size for training [default: 50]
-lr LR initial learning rate [default: 0.01]
-epochs N number of epochs for train [default: 10]
-dropout the probability for dropout [default: 0.5]
-max_norm MAX_NORM l2 constraint of parameters
-cpu disable the gpu
-device DEVICE device to use for iterate data
-embed-dim EMBED_DIM
-static fix the embedding
-kernel-sizes KERNEL_SIZES
Comma-separated kernel size to use for convolution
-kernel-num KERNEL_NUM
number of each kind of kernel
-class-num CLASS_NUM number of class
-shuffle shuffle the data every epoch
-num-workers NUM_WORKERS
how many subprocesses to use for data loading
[default: 0]
-log-interval LOG_INTERVAL
how many batches to wait before logging training
status
-test-interval TEST_INTERVAL
how many epochs to wait before testing
-save-interval SAVE_INTERVAL
how many epochs to wait before saving
-predict PREDICT predict the sentence given
-snapshot SNAPSHOT filename of model snapshot [default: None]
-save-dir SAVE_DIR where to save the checkpoint
./main.py
Obtendrás:
Batch[100] - loss: 0.655424 acc: 59.3750%
Evaluation - loss: 0.672396 acc: 57.6923%(615/1066)
Si ha construido su conjunto de pruebas, realiza pruebas como:
/main.py -test -snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt
La opción de instantánea significa de dónde se carga su modelo. Si no lo asigna, el modelo comenzará desde cero.
Ejemplo1
./main.py -predict="Hello my dear , I love you so much ."
-snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt"
Obtendrás:
Loading model from [./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt]...
[Text] Hello my dear , I love you so much .
[Label] positive
Ejemplo2
./main.py -predict="You just make me so sad and I have to leave you ."
-snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt"
Obtendrás:
Loading model from [./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt]...
[Text] You just make me so sad and I have to leave you .
[Label] negative
Su texto debe estar separado por el espacio, incluso la puntuación. Y su texto debe más tiempo que el tamaño máximo del núcleo.