Modèle de résolution du problème de l'annotation de type colonne avec Bert, formé sur un ensemble de données RWT-Rutabert.
L'ensemble de données RWT-RUTABERT contient 1 441 349 colonnes à partir de tables Wikipedia en langue russe. Avec des en-têtes correspondant aux types sémantiques de 170 dbpedia. Il a une fraction de train / test fixe:
| Diviser | Colonnes | Tables | Avg. colonnes par table |
|---|---|---|---|
| Test | 115 448 | 55 080 | 2.096 |
| Former | 1 325 901 | 633 426 | 2.093 |
Nous avons formé Rutabert avec deux stratégies de sérialisation de table:
Résultats de référence sur l'ensemble de données RWT-Rutabert:
| Stratégie de sérialisation | micro-f1 | macro-f1 | pondéré-F1 |
|---|---|---|---|
| Multicolonne | 0,962 | 0,891 | 0,9621 |
| Colonne voisine | 0,964 | 0,904 | 0,9639 |
Paramètres de formation:
| Paramètre | Valeur |
|---|---|
| taille de lot | 32 |
| époques | 30 |
| Fonction de perte | Entropie croisée |
| Optimiseur GD | ADAMW (LR = 5E-5, EPS = 1E-8) |
| GPU | 4 Nvidia A100 (80 Go) |
| semences aléatoires | 2024 |
| Split de validation | 5% |
?RuTaBERT
┣ checkpoints
┃ ┗ Saved PyTorch models `.pt`
┣ data
┃ ┣ inference
┃ ┃ ┗ Tabels to inference `.csv`
┃ ┣ test
┃ ┃ ┗ Test dataset files `.csv`
┃ ┣ train
┃ ┃ ┗ Train dataset files `.csv`
┃ ┗ Directory for storing dataset files.
┣ dataset
┃ ┗ Dataset wrapper classes, dataloaders
┣ logs
┃ ┗ Log files (train / test / error)
┣ model
┃ ┗ Model and metrics
┣ trainer
┃ ┗ Trainer
┣ utils
┃ ┗ Helper functions
┗ Entry points (train.py, test.py, inference.py), configuration, etc.
La configuration du modèle peut être trouvée dans le fichier config.json .
Les paramètres d'argument Configuratoin sont répertoriés ci-dessous:
| argument | description |
|---|---|
| num_labels | Nombre d'étiquettes utilisées pour la classification |
| num_gpu | Nombre de GPU à utiliser |
| Save_period_in_epochs | Numéro caractérisant avec la périodicité du point de contrôle enregistré (en époques) |
| métrique | Les mesures de classification utilisées sont |
| Pretrained_Model_name | Nom de raccourci Bert de Huggingface |
| table_serialization_type | Méthode de sérialisation d'une table en une séquence |
| batch_size | Taille de lot |
| num_pochs | Nombre d'époches de formation |
| Random_seed | Semences aléatoires |
| LOGS_DIR | Répertoire de l'exploitation forestière |
| train_log_filename | Nom de fichier pour journalisation du train |
| test_log_filename | Nom de fichier pour la journalisation du test |
| start_from_checkpoint | Drapeau pour commencer la formation à partir de Checkpoint |
| Checkpoint_dir | Répertoire pour stocker les points de contrôle du modèle |
| Checkpoint_name | Nom de fichier d'un point de contrôle (État du modèle) |
| INFERGE_MODEL_NAME | Nom de fichier d'un modèle d'inférence |
| inférence_dir | Répertoire pour stocker les tableaux d'inférence .csv |
| dataloder.valid_split | Montant de la division du sous-ensemble de validation |
| dataloder.num_workers | Nombre de travailleurs de dataloader |
| dataset.num_rows | Nombre de lignes lisibles dans l'ensemble de données, si null lisez toutes les lignes dans les fichiers |
| dataset.data_dir | Répertoire pour stocker des fichiers de train / test / inférence |
| dataset.train_path | Répertoire pour stocker les fichiers de jeu de données de train .csv |
| dataset.test_path | Direcotry pour stocker les fichiers de jeu de données de test .csv |
Nous recommandons pour changer uniquement les paramètres:
num_gpu - tout numéro d'ingéter positif + {0}. 0 défendre la formation / les tests sur le processeur.save_period_in_epochs - Tout numéro entier positif, mesures en époques.table_serialization_type - "Column_wise" ou "Table_wise".pretrained_model_name - Noms Shorcut Bert à partir de modèles de prétraitement pytorch HuggingFace.batch_size - tout numéro entier positif.num_epochs - Tout numéro entier positif.random_seed - tout numéro entier.start_from_checkpoint - "true" ou "false".checkpoint_name - Tout nom du modèle, enregistré dans le répertoire checkpoint .inference_model_name - Tout nom du modèle, enregistré dans le répertoire checkpoint . Mais nous vous recommandons d'utiliser les meilleurs modèles: [Model_Best_F1_WeEmple.PT, Model_Best_F1_Macro.PT, Model_Best_F1_Micro.PT].dataloader.valid_split - Nombre réel dans la plage [0,0, 1,0] (0,0 signifie 0% du sous-ensemble de train, 0,5 représente 50% du sous-ensemble de train). Ou numéro entier positif (indiquant un nombre fixe d'un sous-ensemble de validation).dataset.num_rows - "NULL" signifie Lire toutes les lignes dans les fichiers de jeu de données. Un entier positif signifie le nombre de lignes à lire dans les fichiers de l'ensemble de données. Avant la formation / tester le modèle dont vous avez besoin:
├── src
│ ├── RuTaBERT
│ ├── RuTaBERT-Dataset
│ │ ├── move_dataset.sh
move_dataset.sh à partir du référentiel de jeu de données, pour déplacer les fichiers de jeu de données dans le répertoire data Rutabert: RuTaBERT-Dataset$ ./move_dataset.shconfig.json avant la formation. Rutabert prend en charge la formation / les tests localement et à l'intérieur du conteneur Docker. Prend également en charge Slurm Workload Manager.
RuTaBERT$ virtualenv venvou
RuTaBERT$ python -m virtualenv venvRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 train.py 2> logs/error_train.log &&
python3 test.py 2> logs/error_test.logcheckpoint .logs/ répertoire ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).Exigences:
RuTaBERT$ sudo docker build -t rutabert .RuTaBERT$ sudo docker run -d --runtime=nvidia --gpus=all
--mount source=rutabert_logs,target=/app/rutabert/logs
--mount source=rutabert_checkpoints,target=/app/rutabert/checkpoints
rutabertRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_checkpoints/_data ./checkpointsRuTaBERT$ sudo cp -r /var/lib/docker/volumes/rutabert_logs/_data ./logscheckpoint .logs/ répertoire ( training_results.csv , train.log , test.log , error_train.log , error_test.log ).RuTaBERT$ virtualenv venvou
RuTaBERT$ python -m virtualenv venvRuTaBERT$ sbatch run.slurmRuTaBERT$ squeuecheckpoint .logs/ répertoire ( train.log , test.log , error_train.log , error_test.log ). data/test .RuTaBERT$ ./download.sh table_wiseou
RuTaBERT$ ./download.sh column_wiseconfig.json .RuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 test.py 2> logs/error_test.loglogs/ répertoire ( test.log , error_test.log ). data/inference .RuTaBERT$ ./download.sh table_wiseou
RuTaBERT$ ./download.sh column_wiseconfig.jsonRuTaBERT$ source venv/bin/activate &&
pip install -r requirements.txt &&
python3 inference.pydata/inference/result.csv