Dassl ist eine Pytorch-Toolbox, die ursprünglich für unser Projektdomänen-Adaptive Ensemble Learning (DAEL) entwickelt wurde, um die Forschung in der Anpassung und Verallgemeinerung von Domänen zu unterstützen-da wir in DAEL untersuchen, wie diese beiden Probleme in einem einzigen Lernrahmen vereinen können. Angesichts der Tatsache, dass die Anpassung der Domänen eng mit dem semi-vortrimierten Lernen zusammenhängt-beide untersuchen, wie nicht veröffentlichte Daten ausgenommen werden können-wir integrieren auch Komponenten, die die Forschung für letztere unterstützen.
Warum der Name "Dassl"? Dassl kombiniert die Initialen der Domänenanpassung (DA) und des halbübergreifenden Lernens (SSL), was natürlich und informativ klingt.
Dassl verfügt über ein modulares Design und ein einheitliches Schnittstellen, das schnelle Prototyping und Experimentieren neuer DA/DG/SSL -Methoden ermöglicht. Mit DassL kann eine neue Methode mit nur wenigen Codezeilen implementiert werden. Glauben Sie nicht? Schauen Sie sich den Motorordner an, der die Implementierungen vieler vorhandener Methoden enthält (dann kommen Sie zurück und spielen dieses Repo). :-)
Grundsätzlich eignet sich dassl perfekt für die Forschung in den folgenden Bereichen:
Dank des ordentlichen Designs kann Dassl aber auch als Codebasis verwendet werden, um alle Deep -Learning -Projekte wie diese zu entwickeln. :-)
Ein Nachteil von DASSL ist, dass es nicht (noch? HMM) verteiltes Multi-GPU-Training unterstützt (DASSL verwendet DataParallel , um ein Modell zu wickeln, das weniger effizient ist als DistributedDataParallel ).
Im Gegensatz zu einem anderen Projekt werden wir keine detaillierten Dokumentationen für DASSL bereitstellen. Dies liegt daran, dass Dassl für Forschungszwecke entwickelt wurde und als Forscher es wichtig ist, Quellcode lesen zu können, und wir empfehlen Ihnen dringend-dies nicht, weil wir faul sind. :-)
v0.6.0 : Machen Sie cfg.TRAINER.METHOD_NAME im Einklang mit dem Namen der Methodenklassen.v0.5.0 : Wichtige Änderungen an transforms.py . 1) center_crop wird zu einer Standardtransformation in der Tests (angewendet nach der Größe der kleineren Kante zu einer bestimmten Größe, um das Bild -Seitenverhältnis zu halten). 2) Für das Training wird Resize(cfg.INPUT.SIZE) deaktiviert, wenn random_crop oder random_resized_crop verwendet wird. Diese Änderungen machen keinen Unterschied zu den in vorhandenen Konfigurationsdateien verwendeten Trainingstransformationen, noch zu den Testtransformationen, es sei denn, die Rohbilder sind nicht quadratisch (der einzige Unterschied besteht darin, dass das Bild -Seitenverhältnis jetzt respektiert wird).v0.4.3 : Kopieren Sie die Attribute in self.dm (Data Manager) in SimpleTrainer und machen Sie self.dm optional, dh von nun an können Sie Datenlader aus jeder Quelle erstellen, die Sie mögen, anstatt gezwungen zu werden, DataManager zu verwenden.v0.4.2 : Ein wichtiges Update ist das Festlegen drop_last=is_train and len(data_source)>=batch_size beim Erstellen eines Datenladers, um 0 Länge zu vermeiden. Dassl hat die folgenden Methoden implementiert:
Ein-Source-Domänenanpassung
Multi-Source-Domänenanpassung
Domänenverallgemeinerung
Semi-betriebliches Lernen
Fühlen Sie sich frei, eine PR zu machen, um Ihre Methoden hier hinzuzufügen, um es anderen zu erleichtern, den Benchmark zu erleichtern!
DassL unterstützt die folgenden Datensätze:
Domänenanpassung
Domänenverallgemeinerung
Semi-betriebliches Lernen
Stellen Sie sicher, dass Conda ordnungsgemäß installiert ist.
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/
# Create a conda environment
conda create -y -n dassl python=3.8
# Activate the environment
conda activate dassl
# Install torch (requires version >= 1.8.1) and torchvision
# Please refer to https://pytorch.org/ if you need a different cuda version
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
# Install dependencies
pip install -r requirements.txt
# Install this library (no need to re-build if the source code is modified)
python setup.py developBefolgen Sie die Anweisungen in Datasets.md, um die Datensätze vorzubereiten.
Die Hauptschnittstelle ist in tools/train.py implementiert, was im Grunde genommen tut
cfg = setup_cfg(args) wobei args die Befehlszeileneingabe enthält (siehe tools/train.py für die Liste der Eingabebedingungen);trainer mit build_trainer(cfg) der den Datensatz lädt und ein tiefer neuronales Netzwerkmodell erstellt.trainer.train() an, um das Modell auszubilden und zu bewerten.Im Folgenden geben wir ein Beispiel für die Schulung einer nur Quellenbasis auf dem beliebten Domänenanpassungsdatensatz Office-31,
CUDA_VISIBLE_DEVICES=0 python tools/train.py
--root $DATA
--trainer SourceOnly
--source-domains amazon
--target-domains webcam
--dataset-config-file configs/datasets/da/office31.yaml
--config-file configs/trainers/da/source_only/office31.yaml
--output-dir output/source_only_office31 $DATA bezeichnet den Ort, an dem Datensätze installiert sind. --dataset-config-file lädt die gemeinsame Einstellung für den Datensatz (Office-31 in diesem Fall) wie Bildgröße und Modellarchitektur. --config-file lädt die algorithmisch-spezifische Einstellung wie Hyperparameter und Optimierungsparameter.
Um mehrere Quellen zu verwenden, nämlich die Multi-Source-Domänenanpassungsaufgabe, muss man nur weitere Quellen für --source-domains hinzufügen. Zum Beispiel kann man, um eine nur Quellenbasis auf Minidomainnet zu trainieren
CUDA_VISIBLE_DEVICES=0 python tools/train.py
--root $DATA
--trainer SourceOnly
--source-domains clipart painting real
--target-domains sketch
--dataset-config-file configs/datasets/da/mini_domainnet.yaml
--config-file configs/trainers/da/source_only/mini_domainnet.yaml
--output-dir output/source_only_minidnNach dem Training werden die Modellgewichte zusammen mit einer Protokolldatei und einer Tensorboard -Datei zur Visualisierung unter dem angegebenen Ausgangsverzeichnis gespeichert.
Um die in der Protokolldatei gespeicherten Ergebnisse auszudrucken (Sie müssen daher nicht alle Protokolldateien durchgehen und den Mittelwert/die STD selbst berechnen), können Sie tools/parse_test_res.py verwenden. Die Anweisung kann im Code gefunden werden.
Für andere Trainer wie MCD können Sie --trainer MCD , während der Konfigurationsdatei unverändert festhalten und dieselben Trainingsparameter wie SourceOnly verwenden (im einfachsten Fall). Um die Hyperparameter in MCD wie N_STEP_F (Anzahl der Schritte zur Aktualisierung des Feature-Extraktors) zu TRAINER.MCD.N_STEP_F 4 . Alternativ können Sie eine neue .yaml -Konfigurationsdatei erstellen, um Ihre benutzerdefinierte Einstellung zu speichern. Eine vollständige Liste der algorithmisch spezifischen Hyperparameter finden Sie hier.
Modelltests können durch Verwendung --eval-only TEST, auf die der Code auffordert, trainer.test() auszuführen. Sie müssen auch das geschulte Modell bereitstellen und angeben, welche Modelldatei (dh gespeichert in welcher Epoche) verwendet werden soll. Zum Beispiel, um model.pth.tar-20 zu verwenden, gespeichert bei output/source_only_office31/model , können Sie dies tun
CUDA_VISIBLE_DEVICES=0 python tools/train.py
--root $DATA
--trainer SourceOnly
--source-domains amazon
--target-domains webcam
--dataset-config-file configs/datasets/da/office31.yaml
--config-file configs/trainers/da/source_only/office31.yaml
--output-dir output/source_only_office31_test
--eval-only
--model-dir output/source_only_office31
--load-epoch 20 Beachten Sie, dass --model-dir als Eingabe den Verzeichnispfad nimmt, der in der Trainingsphase in --output-dir angegeben wurde.
Eine gute Praxis ist es, dassl/engine/trainer.py zu durchlaufen, um die Basistrainerklassen zu befriedigen, die generische Funktionen und Trainingsschleifen liefern. Um eine Trainerklasse für Domänenanpassung oder halbübergreifendes Lernen zu schreiben, kann die neue Klasse TrainerXU unterklassigen. Für die Domänenverallgemeinerung kann die neue Klasse TrainerX subklassen. Insbesondere unterscheiden sich TrainerXU und TrainerX hauptsächlich darin, ob die Verwendung eines Datenladers für nicht beliebige Daten verwendet wird. Bei den Basisklassen muss ein neuer Trainer möglicherweise nur die forward_backward() -Methode implementieren, die Verlustberechnung und Modellaktualisierung durchführt. Siehe dassl/enigne/da/source_only.py zum Beispiel.
backbone entspricht einem Faltungsmodell für neuronale Netzwerke, das die Feature -Extraktion durchführt. head (ein optionales Modul) ist zur weiteren Verarbeitung auf backbone montiert, was beispielsweise ein MLP sein kann. backbone und head sind grundlegende Bausteine für den Bau eines SimpleNet() (siehe dassl/engine/trainer.py ), das als Hauptmodell für eine Aufgabe dient. network enthält benutzerdefinierte neuronale Netzwerkmodelle wie einen Bildgenerator.
Um ein neues Modul hinzuzufügen, nämlich ein Backbone/Head/Network, müssen Sie das Modul zuerst mit der entsprechenden registry registrieren, dh BACKBONE_REGISTRY für backbone , HEAD_REGISTRY für head und NETWORK_RESIGTRY für network . Beachten Sie, dass wir für ein neues backbone das Modell für Backbone benötigen, wie in dassl/modeling/backbone/backbone.py definiert und das Attribut self._out_features angeben.
Wir geben unten ein Beispiel für das Hinzufügen eines neuen backbone .
from dassl . modeling import Backbone , BACKBONE_REGISTRY
class MyBackbone ( Backbone ):
def __init__ ( self ):
super (). __init__ ()
# Create layers
self . conv = ...
self . _out_features = 2048
def forward ( self , x ):
# Extract and return features
@ BACKBONE_REGISTRY . register ()
def my_backbone ( ** kwargs ):
return MyBackbone () Dann können Sie MODEL.BACKBONE.NAME an my_backbone festlegen, um Ihre eigene Architektur zu verwenden. Weitere Informationen finden Sie im Quellcode in dassl/modeling .
Eine Beispiel -Codestruktur ist unten angezeigt. Stellen Sie sicher, dass Sie DatasetBase unterklagen und den Datensatz unter @DATASET_REGISTRY.register() registrieren. Alles, was Sie brauchen, ist, train_x , train_u (optional), val (optional) und test zu laden, unter denen train_u und val None oder einfach ignoriert haben könnten. Jede dieser Variablen enthält eine Liste von Datum -Objekten. Ein Datum -Objekt (hier implementiert) enthält Informationen für ein einzelnes Bild wie impath (String) und label (int).
from dassl . data . datasets import DATASET_REGISTRY , Datum , DatasetBase
@ DATASET_REGISTRY . register ()
class NewDataset ( DatasetBase ):
dataset_dir = ''
def __init__ ( self , cfg ):
train_x = ...
train_u = ... # optional, can be None
val = ... # optional, can be None
test = ...
super (). __init__ ( train_x = train_x , train_u = train_u , val = val , test = test )Wir empfehlen Ihnen, sich den Datasets -Code in einigen Projekten wie diesen anzusehen, die auf DASSL basieren.
Wir möchten hier unsere für DassL relevante Forschung teilen.
Wenn Sie diesen Code für Ihre Forschung nützlich finden, geben Sie bitte das folgende Papier an,
@article{zhou2022domain,
title={Domain generalization: A survey},
author={Zhou, Kaiyang and Liu, Ziwei and Qiao, Yu and Xiang, Tao and Loy, Chen Change},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2022},
publisher={IEEE}
}
@article{zhou2021domain,
title={Domain adaptive ensemble learning},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
journal={IEEE Transactions on Image Processing},
volume={30},
pages={8008--8018},
year={2021},
publisher={IEEE}
}