Licencia
[中文 | Inglés]
Implementación no oficial de XLNet. La extracción de incrustación y el extracto de incrustación con memoria muestran cómo obtener las salidas de la última capa del transformador utilizando puntos de control previamente capacitados.
pip install keras-xlnetHaga clic en el nombre de la tarea para ver las demostraciones con el modelo base:
| Nombre de la tarea | Métrica | Resultados aproximados en el conjunto de desarrollo |
|---|---|---|
| Reajuste salarial | Matthew Corr. | 52 |
| SST-2 | Exactitud | 93 |
| MRPC | Precisión/F1 | 86/89 |
| STS-B | Pearson Corr. / Spearman Corr. | 86/87 |
| QQP | Precisión/F1 | 90/86 |
| Mnli | Exactitud | 84/84 |
| Qnli | Exactitud | 86 |
| RTE | Exactitud | 64 |
| WNLI | Exactitud | 56 |
(Solo se predicen 0s en el conjunto de datos WNLI)
import os
from keras_xlnet import Tokenizer , load_trained_model_from_checkpoint , ATTENTION_TYPE_BI
checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'
tokenizer = Tokenizer ( os . path . join ( checkpoint_path , 'spiece.model' ))
model = load_trained_model_from_checkpoint (
config_path = os . path . join ( checkpoint_path , 'xlnet_config.json' ),
checkpoint_path = os . path . join ( checkpoint_path , 'xlnet_model.ckpt' ),
batch_size = 16 ,
memory_len = 512 ,
target_len = 128 ,
in_train_phase = False ,
attention_type = ATTENTION_TYPE_BI ,
)
model . summary () Los argumentos batch_size , memory_len y target_len son los tamaños máximos utilizados para la inicialización de recuerdos. El modelo utilizado para capacitar a un modelo de lenguaje se devuelve si in_train_phase es True , de lo contrario, se devolverá un modelo utilizado para el ajuste fino.
Tenga en cuenta que shuffle debe ser False en fit o fit_generator si se usan recuerdos.
in_train_phase es False3 entradas:
(batch_size, target_len) .(batch_size, target_len) .(batch_size, 1) .1 salida:
(batch_size, target_len, units) . in_train_phase es True4 entradas:
(batch_size, target_len) .(batch_size, target_len) .(batch_size, 1) .(batch_size, target_len) .1 salida:
(batch_size, target_len, num_token) .