Implémentation simple pytorch de stylegan2 basé sur https://arxiv.org/abs/1912.04958 qui peut être complètement formé à partir de la ligne de commande, aucun codage nécessaire.
Voici quelques fleurs qui n'existent pas.
Ces mains non plus
Ni ces villes
Ni ces célébrités (formées par @yoniker)
Vous aurez besoin d'une machine avec un GPU et CUDA installées. Puis pip installe le package comme celui-ci
$ pip install stylegan2_pytorchSi vous utilisez une machine Windows, les commandes suivantes fonctionnent.
$ conda install pytorch torchvision -c python
$ pip install stylegan2_pytorch$ stylegan2_pytorch --data /path/to/images C'est ça. Les exemples d'images seront enregistrés sur results/default et les modèles seront enregistrés périodiquement sur models/default .
Vous pouvez spécifier le nom de votre projet avec
$ stylegan2_pytorch --data /path/to/images --name my-project-nameVous pouvez également spécifier l'emplacement où les résultats intermédiaires et les points de contrôle du modèle doivent être stockés avec
$ stylegan2_pytorch --data /path/to/images --name my-project-name --results_dir /path/to/results/dir --models_dir /path/to/models/dir Vous pouvez augmenter la capacité du réseau (qui par défaut 16 ) pour améliorer les résultats de la génération, au prix de plus de mémoire.
$ stylegan2_pytorch --data /path/to/images --network-capacity 256 Par défaut, si la formation est coupée, elle reprendra automatiquement à partir du dernier fichier de contrôle. Si vous souhaitez redémarrer avec de nouveaux paramètres, ajoutez simplement un new drapeau
$ stylegan2_pytorch --new --data /path/to/images --name my-project-name --image-size 512 --batch-size 1 --gradient-accumulate-every 16 --network-capacity 10Une fois l'entraînement terminé, vous pouvez générer des images à partir de votre dernier point de contrôle comme ainsi.
$ stylegan2_pytorch --generatePour générer une vidéo d'une interpolation à travers deux points aléatoires dans l'espace latent.
$ stylegan2_pytorch --generate-interpolation --interpolation-num-steps 100Pour sauver chaque cadre individuel de l'interpolation
$ stylegan2_pytorch --generate-interpolation --save-framesSi un point de contrôle précédent contenait un meilleur générateur, (ce qui se produit souvent lorsque les générateurs commencent à se dégrader vers la fin de la formation), vous pouvez charger à partir d'un point de contrôle précédent avec un autre drapeau
$ stylegan2_pytorch --generate --load-from {checkpoint number} Une technique utilisée dans Stylegan et Biggan tronque les valeurs latentes afin que leurs valeurs se rapprochent de la moyenne. Plus la valeur de troncature est petite, meilleur sera les échantillons au prix de la variété de l'échantillon. Vous pouvez contrôler cela avec le --trunc-psi , où les valeurs se situent généralement entre 0.5 et 1 . Il est défini à 0.75 par défaut
$ stylegan2_pytorch --generate --trunc-psi 0.5Si vous avez une machine avec plusieurs GPU, le référentiel offre un moyen de les utiliser tous pour la formation. Avec plusieurs GPU, chaque lot sera divisé uniformément entre les GPU disponibles. Par exemple, pour 2 GPU, avec une taille de lot de 32, chaque GPU verra 16 échantillons.
Il vous suffit d'ajouter un drapeau --multi-gpus , tout le reste est pris en charge. Si vous souhaitez vous limiter à des GPU spécifiques, vous pouvez utiliser la variable d'environnement CUDA_VISIBLE_DEVICES pour contrôler quels appareils peuvent être utilisés. (Ex. CUDA_VISIBLE_DEVICES=0,2,3 uniquement les appareils 0, 2, 3 sont disponibles)
$ stylegan2_pytorch --data ./data --multi-gpus --batch-size 32 --gradient-accumulate-every 1Dans le passé, Gans avait besoin de beaucoup de données pour apprendre à bien générer. Le modèle des faces a pris des images de haute qualité 70k de Flickr, par exemple.
Cependant, au cours du mois de mai 2020, les chercheurs du monde entier ont convergé indépendamment une technique simple pour réduire ce nombre à aussi bas que 1-2K . Cette simple idée était d'augmenter de manière différente toutes les images, générées ou réelles, en entrant dans le discriminateur pendant la formation.
Si l'on devait augmenter à une probabilité suffisamment basse, les augmentations ne se «fuient» pas dans les générations.
Dans le réglage des données faibles, vous pouvez utiliser la fonction avec un indicateur simple.
# find a suitable probability between 0. -> 0.7 at maximum
$ stylegan2_pytorch --data ./data --aug-prob 0.25 Par défaut, les augmentations utilisées sont translation et cutout . Si vous souhaitez ajouter color , vous pouvez le faire avec l'argument --aug-types .
# make sure there are no spaces between items!
$ stylegan2_pytorch --data ./data --aug-prob 0.25 --aug-types [translation,cutout,color]Vous pouvez le personnaliser à n'importe quelle combinaison des trois que vous souhaitez. Le code d'augmentation différenciable a été copié et légèrement modifié à partir d'ici.
Pendant aussi longtemps que possible jusqu'à ce que le jeu contradictoire entre les deux filets neuronaux s'effondre (nous appelons cette divergence). Par défaut, le nombre d'étapes de formation est défini sur 150000 pour les images 128x128, mais vous voudrez certainement que ce nombre soit plus élevé si le GAN ne diverge pas d'ici la fin de la formation, ou si vous vous entraînez à une résolution plus élevée.
$ stylegan2_pytorch --data ./data --image-size 512 --num-train-steps 1000000Ce cadre vous permet également d'ajouter une forme efficace d'auto-agencement aux couches désignées du discriminateur (et la couche symétrique du générateur), ce qui améliorera considérablement les résultats. Plus vous pouvez vous permettre d'attention, mieux c'est!
# add self attention after the output of layer 1
$ stylegan2_pytorch - - data . / data - - attn - layers 1 # add self attention after the output of layers 1 and 2
# do not put a space after the comma in the list!
$ stylegan2_pytorch - - data . / data - - attn - layers [ 1 , 2 ]Formation sur des images transparentes
$ stylegan2_pytorch --data ./transparent/images/path --transparentPlus vous avez de mémoire GPU, plus la génération d'images sera grande et meilleure. Nvidia a recommandé d'avoir jusqu'à 16 Go pour la formation des images 1024x1024. Si vous en avez moins, il y a quelques paramètres avec lesquels vous pouvez jouer afin que le modèle s'adapte.
$ stylegan2_pytorch --data /path/to/data
--batch-size 3
--gradient-accumulate-every 5
--network-capacity 16 Taille du lot - Vous pouvez réduire la batch-size à 1, mais vous devez augmenter le gradient-accumulate-every en conséquence afin que le mini-lot que le réseau voit ne soit pas trop petit. Cela peut être déroutant pour un profane, donc je vais réfléchir à la façon dont j'automatiserais le choix de gradient-accumulate-every l'avenir.
Capacité du réseau - Vous pouvez réduire la capacité du réseau neuronal pour réduire les exigences de la mémoire. Sachez simplement que cela s'est avéré dégrader les performances de la génération.
Si rien de tout cela ne fonctionne, vous pouvez vous contenter d'un GAn «léger», ce qui vous permettra de compromettre la qualité de s'entraîner à de plus grandes résolutions dans un délai raisonnable.
Vous trouverez ci-dessous quelques étapes qui peuvent être utiles pour le déploiement à l'aide des services Web d'Amazon. Pour l'utiliser, vous devrez provisionner une instance EC2 soutenue par GPU. Un type d'instance approprié proviendrait d'une série P2 ou P3. J'ai (Iboates) essayé un P2.xlarge (l'option la moins chère) et c'était assez lent, plus lent en fait que l'utilisation de Google Colab. Les types d'instance plus puissants peuvent être meilleurs mais ils sont plus chers. Vous pouvez en savoir plus à leur sujet ici.
sudo snap install aws-cli --classic
aws configureVous devrez ensuite saisir vos clés d'accès AWS, que vous pouvez récupérer à partir de la console de gestion sous la console de gestion AWS> Profil> Mes informations d'identification de sécurité> Clés d'accès
Ensuite, exécutez ces commandes, ou peut-être les mettre dans un script de shell et exécuter cela:
mkdir data
curl -O https://bootstrap.pypa.io/get-pip.py
sudo apt-get install python3-distutils
python3 get-pip.py
pip3 install stylegan2_pytorch
export PATH= $PATH :/home/ubuntu/.local/bin
aws s3 sync s3:// < Your bucket name > ~ /data
cd data
tar -xf ../train.tar.gz Maintenant, vous devriez pouvoir vous entraîner en appelant simplement stylegan2_pytorch [args] .
Notes:
screen afin qu'il ne se termine pas une fois que vous vous êtes déconnecté de la session SSH. Grâce à GetSeclectic, vous pouvez maintenant calculer périodiquement le score FID! Encore une fois, rendu super simple avec un argument supplémentaire, comme indiqué ci-dessous.
Tout d'abord, installez le package pytorch_fid
$ pip install pytorch-fidSuivi de
$ stylegan2_pytorch --data ./data --calculate-fid-every 5000 Les résultats de FID seront enregistrés à ./results/{name}/fid_scores.txt
Si vous souhaitez exemple d'échantillonnage d'images par programme, vous pouvez le faire avec la classe ModelLoader simple suivante.
import torch
from torchvision . utils import save_image
from stylegan2_pytorch import ModelLoader
loader = ModelLoader (
base_dir = '/path/to/directory' , # path to where you invoked the command line tool
name = 'default' # the project name, defaults to 'default'
)
noise = torch . randn ( 1 , 512 ). cuda () # noise
styles = loader . noise_to_styles ( noise , trunc_psi = 0.7 ) # pass through mapping network
images = loader . styles_to_images ( styles ) # call the generator on intermediate style vectors
save_image ( images , './sample.jpg' ) # save your images, or do whatever you desirePour enregistrer les pertes à un tracker d'expérience open source (AIM), il vous suffit de passer un drapeau supplémentaire comme ainsi.
$ stylegan2_pytorch --data ./data --logEnsuite, vous devez vous assurer que Docker est installé. Suivant les instructions à AIM, vous exécutez ce qui suit dans votre terminal.
$ aim upEnsuite, ouvrez votre navigateur à l'adresse et vous devriez voir

Un nouvel article a produit des preuves que, en faisant simplement naître les contributions du gradient d'échantillons jugés faux par le discriminateur, le générateur apprend beaucoup mieux, atteignant un nouvel état de l'art.
$ stylegan2_pytorch - - data . / data - - top - k - trainingLe gamma est un calendrier de décroissance qui diminue lentement le topk de la taille totale du lot à la fraction cible de 50% (également hyperparamètre modifiable).
$ stylegan2_pytorch - - data . / data - - top - k - training - - generate - top - k - frac 0.5 - - generate - top - k - gamma 0.99 Un article récent a rapporté des résultats améliorés si les représentations intermédiaires du discriminateur sont quantifiées vectorielles. Bien que je n'ai pas remarqué de changements dramatiques, j'ai décidé d'ajouter cela comme une fonctionnalité, afin que d'autres esprits puissent enquêter. Pour utiliser, vous devez spécifier la couche (s) que vous souhaitez quantifier vectoriel. La taille du dictionnaire par défaut est 256 et est également accordable.
# feature quantize layers 1 and 2, with a dictionary size of 512 each
# do not put a space after the comma in the list!
$ stylegan2_pytorch - - data . / data - - fq - layers [ 1 , 2 ] - - fq - dict - size 512J'ai essayé l'apprentissage contrastif sur le discriminateur (en étape avec la formation habituelle de GaN) et j'ai peut-être observé une amélioration de la stabilité et de la qualité des résultats finaux. Vous pouvez activer cette fonctionnalité expérimentale avec un indicateur simple comme indiqué ci-dessous.
$ stylegan2_pytorch - - data . / data - - cl - regCeci a été proposé dans le papier GaN relativiste pour stabiliser la formation. J'ai eu des résultats mitigés, mais j'inclurai la fonctionnalité pour ceux qui souhaitent l'expérimenter.
$ stylegan2_pytorch - - data . / data - - rel - disc - loss Par défaut, l'architecture Stylegan stylise un bloc 4x4 appris constant car il est progressivement échantillonné. Il s'agit d'une caractéristique expérimentale qui le fait, donc le bloc 4x4 est appris du vecteur de style w à la place.
$ stylegan2_pytorch - - data . / data - - no - constUn article récent a proposé qu'une nouvelle perte contrastée entre les logites réelles et fausses puisse améliorer la qualité par rapport à d'autres types de pertes. (La valeur par défaut de ce référentiel est la perte de charnière et le papier montre une légère amélioration)
$ stylegan2_pytorch - - data . / data - - dual - contrast - loss STYLEGAN2 + DISCRIMINATEUR UNET
J'ai obtenu de très bons résultats avec un discriminateur de l'ONU, mais le changement architectural était trop important pour s'adapter en option dans ce référentiel. Si vous visez la perfection, n'hésitez pas à l'essayer.
Si vous souhaitez que je donne le traitement royal à une autre architecture GAN (Biggan), n'hésitez pas à contacter mon e-mail. Heureux d'entendre votre terrain.
Merci à Matthew Mann pour son port simple inspirant pour Tensorflow 2.0
@article { Karras2019stylegan2 ,
title = { Analyzing and Improving the Image Quality of {StyleGAN} } ,
author = { Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila } ,
journal = { CoRR } ,
volume = { abs/1912.04958 } ,
year = { 2019 } ,
} @misc { zhao2020feature ,
title = { Feature Quantization Improves GAN Training } ,
author = { Yang Zhao and Chunyuan Li and Ping Yu and Jianfeng Gao and Changyou Chen } ,
year = { 2020 }
} @misc { chen2020simple ,
title = { A Simple Framework for Contrastive Learning of Visual Representations } ,
author = { Ting Chen and Simon Kornblith and Mohammad Norouzi and Geoffrey Hinton } ,
year = { 2020 }
} @article {,
title = { Oxford 102 Flowers } ,
author = { Nilsback, M-E. and Zisserman, A., 2008 } ,
abstract = { A 102 category dataset consisting of 102 flower categories, commonly occuring in the United Kingdom. Each class consists of 40 to 258 images. The images have large scale, pose and light variations. }
} @article { afifi201911k ,
title = { 11K Hands: gender recognition and biometric identification using a large dataset of hand images } ,
author = { Afifi, Mahmoud } ,
journal = { Multimedia Tools and Applications }
} @misc { zhang2018selfattention ,
title = { Self-Attention Generative Adversarial Networks } ,
author = { Han Zhang and Ian Goodfellow and Dimitris Metaxas and Augustus Odena } ,
year = { 2018 } ,
eprint = { 1805.08318 } ,
archivePrefix = { arXiv }
} @article { shen2019efficient ,
author = { Zhuoran Shen and
Mingyuan Zhang and
Haiyu Zhao and
Shuai Yi and
Hongsheng Li } ,
title = { Efficient Attention: Attention with Linear Complexities } ,
journal = { CoRR } ,
year = { 2018 } ,
url = { http://arxiv.org/abs/1812.01243 } ,
} @article { zhao2020diffaugment ,
title = { Differentiable Augmentation for Data-Efficient GAN Training } ,
author = { Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song } ,
journal = { arXiv preprint arXiv:2006.10738 } ,
year = { 2020 }
} @misc { zhao2020image ,
title = { Image Augmentations for GAN Training } ,
author = { Zhengli Zhao and Zizhao Zhang and Ting Chen and Sameer Singh and Han Zhang } ,
year = { 2020 } ,
eprint = { 2006.02595 } ,
archivePrefix = { arXiv }
} @misc { karras2020training ,
title = { Training Generative Adversarial Networks with Limited Data } ,
author = { Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila } ,
year = { 2020 } ,
eprint = { 2006.06676 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
} @misc { jolicoeurmartineau2018relativistic ,
title = { The relativistic discriminator: a key element missing from standard GAN } ,
author = { Alexia Jolicoeur-Martineau } ,
year = { 2018 } ,
eprint = { 1807.00734 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
} @misc { sinha2020topk ,
title = { Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples } ,
author = { Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena } ,
year = { 2020 } ,
eprint = { 2002.06224 } ,
archivePrefix = { arXiv } ,
primaryClass = { stat.ML }
} @misc { yu2021dual ,
title = { Dual Contrastive Loss and Attention for GANs } ,
author = { Ning Yu and Guilin Liu and Aysegul Dundar and Andrew Tao and Bryan Catanzaro and Larry Davis and Mario Fritz } ,
year = { 2021 } ,
eprint = { 2103.16748 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}