Modelo para resolver el problema de la anotación de tipo columna con Bert, entrenado en el conjunto de datos RWT-Rutabert.
El conjunto de datos RWT-Rutabert contiene 1 441 349 columnas de las tablas de Wikipedia de idioma ruso. Con encabezados que coinciden con 170 tipos semánticos dbpedia. Tiene una división de tren / prueba fija:
| Dividir | Columnas | Mesas | Avg. columnas por tabla |
|---|---|---|---|
| Prueba | 115 448 | 55 080 | 2.096 |
| Tren | 1 325 901 | 633 426 | 2.093 |
Entrenamos a Rutabert con dos estrategias de serialización de mesa:
Resultados de referencia en el conjunto de datos RWT-Rutabert:
| Estrategia de serialización | micro-F1 | macro-F1 | pesado-f1 |
|---|---|---|---|
| Multicolumn | 0.962 | 0.891 | 0.9621 |
| Columna vecina | 0.964 | 0.904 | 0.9639 |
Parámetros de entrenamiento:
| Parámetro | Valor |
|---|---|
| tamaño por lotes | 32 |
| épocas | 30 |
| Función de pérdida | Entrelazamiento |
| Optimizador de GD | ADAMW (LR = 5E-5, EPS = 1E-8) |
| GPU | 4 NVIDIA A100 (80 GB) |
| semilla aleatoria | 2024 |
| división de validación | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
La configuración del modelo se puede encontrar en el archivo config.json .
Los parámetros de argumento configuratatoin se enumeran a continuación:
| argumento | descripción |
|---|---|
| num_labels | Número de etiquetas utilizadas para la clasificación |
| num_gpu | Número de GPU para usar |
| save_period_in_epochs | Número que caracteriza con qué periodicidad se guarda el punto de control (en épocas) |
| métrica | Las métricas de clasificación utilizadas son |
| Pretrense_model_name | Nombre de acceso directo de Bert de Huggingface |
| table_serialization_type | Método de serializar una tabla en una secuencia |
| lote_size | Tamaño por lotes |
| num_epochs | Número de épocas de entrenamiento |
| sweat | Semilla aleatoria |
| logs_dir | Directorio para registro |
| Train_log_filename | Nombre del archivo para registro de trenes |
| test_log_filename | Nombre del archivo para el registro de pruebas |
| start_from_checkpoint | Bandera para comenzar a entrenar desde el punto de control |
| Checkpoint_dir | Directorio para almacenar puntos de control del modelo |
| checkpoint_name | Nombre del archivo de un punto de control (estado del modelo) |
| inferencia_model_name | Nombre de archivo de un modelo de inferencia |
| inferencia_dir | Directorio para almacenar tablas de inferencia .csv |
| dataloader.valid_split | Cantidad de división de subconjunto de validación |
| dataLoader.num_workers | Número de trabajadores de dataloader |
| DataSet.num_rows | Número de filas legibles en el conjunto de datos, si null lee todas las filas en los archivos |
| dataSet.data_dir | Directorio para almacenar archivos de tren/prueba/inferencia |
| DataSet.train_path | Directorio para almacenar archivos del conjunto de datos de trenes .csv |
| DataSet.test_path | Direcotry para almacenar archivos de conjunto de datos de prueba .csv |
Recomendamos para cambiar solo los parámetros:
num_gpu - Cualquier número positivo de Ingeter + {0}. 0 Poner en capacitación / prueba en CPU.save_period_in_epochs : cualquier número entero positivo, mide en épocas.table_serialization_type - "column_wise" o "table_wise".pretrained_model_name - Nombres de Bert Shorcut de Huggingface Pytorch Modelos Pretrados.batch_size : cualquier número entero positivo.num_epochs : cualquier número entero positivo.random_seed : cualquier número entero.start_from_checkpoint - "verdadero" o "falso".checkpoint_name : cualquier nombre del modelo, guardado en el directorio checkpoint .inference_model_name : cualquier nombre del modelo, guardado en el directorio checkpoint . Pero recomendamos usar los mejores modelos: [model_best_f1_weuth.pt, model_best_f1_macro.pt, model_best_f1_micro.pt].dataloader.valid_split - Número real dentro del rango [0.0, 1.0] (0.0 representa el 0 % del subconjunto de trenes, 0.5 representa el 50 % del subconjunto del tren). O número entero positivo (denotando un número fijo de un subconjunto de validación).dataset.num_rows - "NULL" significa leer todas las líneas en archivos de conjunto de datos. Integer positivo significa el número de líneas para leer en los archivos del conjunto de datos. Antes de entrenar / probar el modelo que necesita:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh desde el repositorio de datos de datos, para mover los archivos de conjunto de datos al directorio data de Rutabert: RuTaBERT-Dataset$ ./move_dataset.shconfig.json antes del entrenamiento. Rutabert admite capacitación / pruebas localmente y dentro del contenedor Docker. También es compatible con Slurm Workload Manager.
RuTaBERT$ virtualenv venvo
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint .logs/ directorio ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).Requisitos:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint .logs/ directorio ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).RuTaBERT$ virtualenv venvo
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint .logs/ directorio ( train.log , test.log , error_train.log , error_test.log ). data/test .RuTaBERT$ ./download.sh table_wiseo
RuTaBERT$ ./download.sh column_wiseconfig.json .RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ directorio ( test.log , error_test.log ). data/inference .RuTaBERT$ ./download.sh table_wiseo
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv