Versión china 中文版
Pytorch-Lightning es una biblioteca muy conveniente. Se puede ver como una abstracción y envasado de Pytorch. Sus ventajas son una fuerte reutilización, fácil mantenimiento, lógica clara, etc. La desventaja es que es demasiado pesado y requiere bastante tiempo para aprender y comprender. Además, dado que vincula directamente el modelo y el código de capacitación, no es adecuado para proyectos reales con múltiples modelos y archivos de conjunto de datos. Lo mismo es cierto para el diseño del módulo de datos. El fuerte acoplamiento de cosas como DataLoader y conjuntos de datos personalizados también causa un problema similar: el mismo código se copia y se pegan de forma inelvisionada aquí y allá.
Después de mucha exploración y práctica, he resumido las siguientes plantillas, que también pueden ser una abstracción adicional de Pytorch-Lightning. En la primera versión, todo el contenido de la plantilla está en la carpeta raíz. Sin embargo, después de usarlo durante más de un mes, descubrí que las plantillas más especificadas para diferentes tipos de proyectos pueden aumentar la eficiencia de codificación. Por ejemplo, las tareas de clasificación y súper resolución tienen algunos puntos de demanda fijos. El código del proyecto se puede implementar más rápido modificando directamente plantillas especializadas, y también se han reducido algunos errores evitables.
** Actualmente, dado que esta sigue siendo una nueva biblioteca, solo hay estas dos plantillas. Sin embargo, más tarde, como lo aplico a otros proyectos, también se agregarán nuevas plantillas especializadas. Si ha utilizado esta plantilla para sus tareas (como PNL, GaN, reconocimiento de voz, etc.), puede enviar un PR para que pueda integrar su plantilla en la biblioteca para que más personas los usen. Si su tarea aún no está en la lista, comenzar desde la plantilla classification es una buena opción. Dado que la mayoría de la lógica y el código de las plantillas subyacentes son los mismos, esto se puede hacer muy rápidamente. **
Todos son bienvenidos a probar este conjunto de estilos de código. Es bastante conveniente reutilizar si está acostumbrado, y no es fácil volver al agujero. Se puede encontrar una explicación más detallada y una guía completa para Pytorch-Lightning en el blog de este artículo Zhihu.
root-
|-data
|-__init__.py
|-data_interface.py
|-xxxdataset1.py
|-xxxdataset2.py
|-...
|-model
|-__init__.py
|-model_interface.py
|-xxxmodel1.py
|-xxxmodel2.py
|-...
|-main.py
|-utils.py
No se necesita instalación. Ejecute directamente git clone https://github.com/miracleyoo/pytorch-lightning-template.git para clonarlo a su posición local. Elija su tipo de problema como classification y copie la plantilla correspondiente a su directorio de proyecto.
Tres son solo main.py y utils.py en el directorio raíz. El primero es la entrada del código, y el segundo es un archivo de soporte.
Hay un archivo __init__.py en la carpeta data y modle para convertirlos en paquetes. De esta manera, la importación se vuelve más fácil.
Cree una class DInterface(pl.LightningDataModule): en data_interface para funcionar como la interfaz de todos los diferentes archivos de conjunto de datos customados. La clase de conjunto de datos correspondiente se importa en la función __init__() . La instancia se realiza en la setup() , y se crean train_dataloader , val_dataloader , test_dataloader funciones.
Del mismo modo, la clase class MInterface(pl.LightningModule): se crean en model_interface para funcionar como la interfaz de todos los archivos de su modelo. La clase modelo correspondiente se importa en la función __init__() . Las únicas cosas que necesita modificar en la interfaz son las funciones como configure_optimizers , training_step , validation_step que controlan su propio proceso de entrenamiento. Una interfaz para todos los modelos, y la diferencia se maneja en Args.
main.py solo es responsable de las siguientes tareas:
Interface , puede agregar directamente un elemento de análisis en el archivo main.py Por ejemplo, hay un argumento de cadena llamado random_arg , puede agregar parser.add_argument('--random_arg', default='test', type=str) al archivo main.pycallback necesarias, como Auto-Save, Early Stop y LR Scheduler。MInterface , DInterface , Trainer 。Aleta.
Una cosa a la que debe prestar atención es, para permitir que MInterface y DInterface puedan analizar sus modelos y conjuntos de datos recién agregados automáticamente simplemente especificando el argumento --model_name y --dataset , utilizamos una caja de serpiente (como standard_net.py ) para el archivo modelo/conjunto de datos, y usamos el mismo contenido con Camel Case para el nombre de la clase, como StandardNet .
Lo mismo es cierto para la carpeta data .
Aunque esto parece restringir su nombre de modelos y conjuntos de datos, también puede hacer que su código sea más fácil de leer y comprender. Preste atención a este punto para evitar problemas de análisis.
Si usó esta plantilla y le resulta útil para su investigación, considere citar nuestro documento:
@article{ZHANG2023126388,
title = {Neuromorphic high-frequency 3D dancing pose estimation in dynamic environment},
journal = {Neurocomputing},
volume = {547},
pages = {126388},
year = {2023},
issn = {0925-2312},
doi = {https://doi.org/10.1016/j.neucom.2023.126388},
url = {https://www.sciencedirect.com/science/article/pii/S0925231223005118},
author = {Zhongyang Zhang and Kaidong Chai and Haowen Yu and Ramzi Majaj and Francesca Walsh and Edward Wang and Upal Mahbub and Hava Siegelmann and Donghyun Kim and Tauhidur Rahman},
keywords = {Event Camera, Dynamic Vision Sensor, Neuromorphic Camera, Simulator, Dataset, Deep Learning, Human Pose Estimation, 3D Human Pose Estimation, Technology-Mediated Dancing},
}
@InProceedings{Zhang_2022_WACV,
author = {Zhang, Zhongyang and Xu, Zhiyang and Ahmed, Zia and Salekin, Asif and Rahman, Tauhidur},
title = {Hyperspectral Image Super-Resolution in Arbitrary Input-Output Band Settings},
booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) Workshops},
month = {January},
year = {2022},
pages = {749-759}
}