DASSL es una caja de herramientas de Pytorch desarrollada inicialmente para nuestro Proyecto Dominio Adaptive Ensemble Learning (DAEL) para apoyar la investigación en la adaptación y generalización del dominio, ya que en Dael estudiamos cómo unificar estos dos problemas en un solo marco de aprendizaje. Dado que la adaptación del dominio está estrechamente relacionada con el aprendizaje semi-supervisado, ambos estudian cómo explotar los datos no etiquetados, también incorporamos componentes que apoyan la investigación para este último.
¿Por qué el nombre "Dassl"? DASSL combina las iniciales de adaptación de dominio (DA) y aprendizaje semi-supervisado (SSL), que suena natural e informativo.
DASSL tiene un diseño modular e interfaces unificadas, lo que permite la prototipos rápidos y la experimentación de nuevos métodos DA/DG/SSL. Con DASSL, se puede implementar un nuevo método con solo unas pocas líneas de código. ¿No crees? Eche un vistazo a la carpeta del motor, que contiene las implementaciones de muchos métodos existentes (luego volverá y protagonizará este repositorio). :-)
Básicamente, DASSL es perfecto para investigar en las siguientes áreas:
Pero, gracias al diseño ordenado, DASSL también se puede usar como una base de código para desarrollar cualquier proyecto de aprendizaje profundo, como este. :-)
Un inconveniente de DASSL es que no es (todavía? Hmm) admite el entrenamiento multi-GPU distribuido (DASSL usa DataParallel para envolver un modelo, que es menos eficiente que DistributedDataParallel ).
No proporcionamos documentos detallados para DASSL, a diferencia de otro proyecto nuestro. Esto se debe a que DASSL se desarrolla con fines de investigación y, como investigador, creemos que es importante poder leer el código fuente y lo recomendamos que lo haga, definitivamente no porque seamos perezosos. :-)
v0.6.0 : hacer cfg.TRAINER.METHOD_NAME consistente con el nombre de la clase de método.v0.5.0 : cambios importantes realizados a transforms.py . 1) center_crop se convierte en una transformación predeterminada en las pruebas (aplicada después de cambiar el tamaño del borde más pequeño a un determinado tamaño para mantener la relación de aspecto de la imagen). 2) Para el entrenamiento, Resize(cfg.INPUT.SIZE) se desactiva cuando se usa random_crop o random_resized_crop . Estos cambios no harán ninguna diferencia en las transformaciones de entrenamiento utilizadas en los archivos de configuración existentes, ni en las transformaciones de prueba a menos que las imágenes sin procesar no estén al cuadrado (la única diferencia es que ahora se respeta la relación de aspecto de la imagen).v0.4.3 : Copie los atributos en self.dm (administrador de datos) a SimpleTrainer y haga que self.dm sea opcional, lo que significa que a partir de ahora puede crear cargadores de datos de cualquier fuente que desee en lugar de ser obligado a usar DataManager .v0.4.2 : Una actualización importante es establecer drop_last=is_train and len(data_source)>=batch_size al construir un cargador de datos para evitar 0 de longitud. DASSL ha implementado los siguientes métodos:
Adaptación de dominio de fuente única
Adaptación de dominio de múltiples fuentes
Generalización del dominio
Aprendizaje semi-supervisado
¡Siéntase libre de hacer un PR para agregar sus métodos aquí para facilitar que otros comparen!
Dassl admite los siguientes conjuntos de datos:
Adaptación de dominio
Generalización del dominio
Aprendizaje semi-supervisado
Asegúrese de que Conda se instale correctamente.
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/
# Create a conda environment
conda create -y -n dassl python=3.8
# Activate the environment
conda activate dassl
# Install torch (requires version >= 1.8.1) and torchvision
# Please refer to https://pytorch.org/ if you need a different cuda version
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
# Install dependencies
pip install -r requirements.txt
# Install this library (no need to re-build if the source code is modified)
python setup.py developSiga las instrucciones en los conjuntos de datos.md para preprocesar los conjuntos de datos.
La interfaz principal se implementa en tools/train.py , que básicamente hace
cfg = setup_cfg(args) donde args contiene la entrada de línea de comandos (ver tools/train.py para la lista de argumentos de entrada);trainer con build_trainer(cfg) que carga el conjunto de datos y construye un modelo de red neuronal profunda;trainer.train() para capacitar y evaluar el modelo.A continuación, proporcionamos un ejemplo para capacitar una línea de base exclusiva para la fuente en el popular conjunto de datos de adaptación de dominio, Office-31,
CUDA_VISIBLE_DEVICES=0 python tools/train.py
--root $DATA
--trainer SourceOnly
--source-domains amazon
--target-domains webcam
--dataset-config-file configs/datasets/da/office31.yaml
--config-file configs/trainers/da/source_only/office31.yaml
--output-dir output/source_only_office31 $DATA denota la ubicación donde se instalan conjuntos de datos. --dataset-config-file carga la configuración común para el conjunto de datos (Office-31 en este caso), como el tamaño de la imagen y la arquitectura del modelo. --config-file carga la configuración específica del algoritmo, como los hiperparámetros y los parámetros de optimización.
Para usar múltiples fuentes, a saber, la tarea de adaptación de dominio de múltiples fuentes, uno solo necesita agregar más fuentes a --source-domains . Por ejemplo, para entrenar una línea de base exclusiva de origen en minidomainnet, uno puede hacer
CUDA_VISIBLE_DEVICES=0 python tools/train.py
--root $DATA
--trainer SourceOnly
--source-domains clipart painting real
--target-domains sketch
--dataset-config-file configs/datasets/da/mini_domainnet.yaml
--config-file configs/trainers/da/source_only/mini_domainnet.yaml
--output-dir output/source_only_minidnDespués de que termine el entrenamiento, los pesos del modelo se guardarán en el directorio de salida especificado, junto con un archivo de registro y un archivo de tablero tensor para la visualización.
Para imprimir los resultados guardados en el archivo de registro (para que no sea necesario revisar exhaustivamente todos los archivos de registro y calcular la media/ETS por sí misma), puede usar tools/parse_test_res.py . La instrucción se puede encontrar en el código.
Para otros entrenadores, como MCD , puede establecer --trainer MCD mientras mantiene el archivo de configuración sin cambios, es decir, utilizando los mismos parámetros de entrenamiento que SourceOnly (en el caso más simple). Para modificar los hiper-parámetros en MCD, como N_STEP_F (número de pasos para actualizar el extractor de funciones), puede agregar TRAINER.MCD.N_STEP_F 4 a los argumentos de entrada existentes (de lo contrario, el valor predeterminado se usará). Alternativamente, puede crear un nuevo archivo de configuración .yaml para almacenar su configuración personalizada. Consulte aquí para obtener una lista completa de hiper-parametros específicos de algoritmo.
Las pruebas de modelo se pueden realizar usando --eval-only , que le pide al código que ejecute trainer.test() . También debe proporcionar el modelo capacitado y especificar qué archivo de modelo (es decir, guardado en qué época) usar. Por ejemplo, para usar model.pth.tar-20 guardado en output/source_only_office31/model , puede hacer
CUDA_VISIBLE_DEVICES=0 python tools/train.py
--root $DATA
--trainer SourceOnly
--source-domains amazon
--target-domains webcam
--dataset-config-file configs/datasets/da/office31.yaml
--config-file configs/trainers/da/source_only/office31.yaml
--output-dir output/source_only_office31_test
--eval-only
--model-dir output/source_only_office31
--load-epoch 20 Tenga en cuenta que --model-dir toma como entrada la ruta del directorio que se especificó en --output-dir en la etapa de entrenamiento.
Una buena práctica es pasar por dassl/engine/trainer.py para ser familiar con las clases de entrenadores base, que proporcionan funciones genéricas y bucles de capacitación. Para escribir una clase de entrenador para la adaptación de dominio o el aprendizaje semi-supervisado, la nueva clase puede subclase TrainerXU . Para la generalización del dominio, la nueva clase puede subclase TrainerX . En particular, TrainerXU y TrainerX difieren principalmente en si el uso de un cargador de datos para datos no etiquetados. Con las clases base, un nuevo entrenador solo puede necesitar implementar el método forward_backward() , que realiza el cálculo de pérdidas y la actualización del modelo. Consulte dassl/enigne/da/source_only.py por ejemplo.
backbone corresponde a un modelo de red neuronal convolucional que realiza una extracción de características. head (que es un módulo opcional) se monta en la parte posterior de backbone para su posterior procesamiento, lo que puede ser, por ejemplo, un MLP. backbone y head son bloques de construcción básicos para construir un SimpleNet() (ver dassl/engine/trainer.py ) que sirve como el modelo principal para una tarea. network contiene modelos de red neuronales personalizados, como un generador de imágenes.
Para agregar un nuevo módulo, a saber, una columna vertebral/cabeza/red, primero debe registrar el módulo utilizando el registry correspondiente, es decir, BACKBONE_REGISTRY para backbone , HEAD_REGISTRY para head y NETWORK_RESIGTRY para network . Tenga en cuenta que para una nueva backbone , requerimos que el modelo subclase Backbone como se define en dassl/modeling/backbone/backbone.py y especifique el atributo self._out_features .
Proporcionamos un ejemplo a continuación sobre cómo agregar una nueva backbone .
from dassl . modeling import Backbone , BACKBONE_REGISTRY
class MyBackbone ( Backbone ):
def __init__ ( self ):
super (). __init__ ()
# Create layers
self . conv = ...
self . _out_features = 2048
def forward ( self , x ):
# Extract and return features
@ BACKBONE_REGISTRY . register ()
def my_backbone ( ** kwargs ):
return MyBackbone () Luego, puede establecer MODEL.BACKBONE.NAME a my_backbone para usar su propia arquitectura. Para obtener más detalles, consulte el código fuente en dassl/modeling .
A continuación se muestra una estructura de código de ejemplo. Asegúrese de subclase DatasetBase y registre el conjunto de datos con @DATASET_REGISTRY.register() . Todo lo que necesita es cargar train_x , train_u (Opcional), val (opcional) y test , entre los cuales train_u y val podrían ser None o simplemente ignorarse. Cada una de estas variables contiene una lista de objetos Datum . Un objeto Datum (implementado aquí) contiene información para una sola imagen, como impath (String) y label (int).
from dassl . data . datasets import DATASET_REGISTRY , Datum , DatasetBase
@ DATASET_REGISTRY . register ()
class NewDataset ( DatasetBase ):
dataset_dir = ''
def __init__ ( self , cfg ):
train_x = ...
train_u = ... # optional, can be None
val = ... # optional, can be None
test = ...
super (). __init__ ( train_x = train_x , train_u = train_u , val = val , test = test )Le sugerimos que eche un vistazo al código de conjuntos de datos en algunos proyectos como este, que se basa en DASSL.
Nos gustaría compartir aquí nuestra investigación relevante para DASSL.
Si encuentra este código útil para su investigación, dé crédito al siguiente documento
@article{zhou2022domain,
title={Domain generalization: A survey},
author={Zhou, Kaiyang and Liu, Ziwei and Qiao, Yu and Xiang, Tao and Loy, Chen Change},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2022},
publisher={IEEE}
}
@article{zhou2021domain,
title={Domain adaptive ensemble learning},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
journal={IEEE Transactions on Image Processing},
volume={30},
pages={8008--8018},
year={2021},
publisher={IEEE}
}