该存储库提供了用于微调稳定扩散的代码。它是通过拥抱脸来从这个脚本改编的。用于微调的预培训模型来自Kerascv。要了解原始型号,请查看此文档。
此存储库中提供的代码仅用于研究目的。请查看本节,以了解有关潜在用例和限制的更多信息。
通过加载此模型,您可以在https://raw.githubusercontent.com/compvis/stable-diffusion/main/main/license上接受CreatiVeml Open Rail-M许可证。
如果您只是在寻找此存储库的随附资源,则是以下链接:
目录:
该存储库有一个姐妹存储库(Keras-SD服务),涵盖了稳定扩散的各种部署模式。
2023年1月13日更新:该项目在Google组织的有史以来的首次Keras社区奖比赛中获得了第二名。
按照拥抱面的原始脚本,此存储库还使用Pokemon数据集。但是它被重生以与tf.data更好地工作。数据集的再生版本在此处托管。查看该链接以获取更多详细信息。
finetune.py提供了微调代码。在进行培训之前,请确保您安装了依赖关系(请参阅requirements.txt )。
您可以通过运行python finetune.py来启动默认参数培训。运行python finetune.py -h以了解支持的命令行参数。您可以通过传递--mp标志来启用混合精确培训。
当您启动训练时,只有当当前损失低于前面的损失时,才会生成扩散模型检查点。
为了避免OOM和更快的训练,建议至少使用V100 GPU。我们使用了A100。
需要注意的一些重要细节:
培训详细信息:
我们对两种不同的分辨率进行了微调:256x256和512x512。我们仅通过这两种不同的分辨率改变了批处理大小和时代的数量。由于我们没有使用梯度积累,因此我们使用此代码段来得出时期的数量。
python finetune.py --batch_size 4 --num_epochs 577python finetune.py --img_height 512 --img_width 512 --batch_size 1 --num_epochs 72 --mp对于256x256分辨率,我们故意减少了时期的数量以节省计算时间。
微调重量:
您可以在此处找到微调的扩散模型权重。
此存储库中使用的默认宠物小精灵数据集附带以下结构:
pokemon_dataset/
data.csv
image_24.png
image_3.png
image_550.png
image_700.png
... data.csv看起来像:
只要您的自定义数据集遵循此结构,您就无需更改当前代码库中的任何内容,除了dataset_archive 。
如果您的数据集每个图像具有多个字幕,则可以在训练期间从每个图像的标题池中随机选择一个字幕。
根据数据集,您可能必须调整超参数。
import keras_cv
import matplotlib . pyplot as plt
from tensorflow import keras
IMG_HEIGHT = IMG_WIDTH = 512
def plot_images ( images , title ):
plt . figure ( figsize = ( 20 , 20 ))
for i in range ( len ( images )):
ax = plt . subplot ( 1 , len ( images ), i + 1 )
plt . title ( title )
plt . imshow ( images [ i ])
plt . axis ( "off" )
# We just have to load the fine-tuned weights into the diffusion model.
weights_path = keras . utils . get_file (
origin = "https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)
pokemon_model = keras_cv . models . StableDiffusion (
img_height = IMG_HEIGHT , img_width = IMG_WIDTH
)
pokemon_model . diffusion_model . load_weights ( weights_path )
# Generate images.
generated_images = pokemon_model . text_to_image ( "Yoda" , batch_size = 3 )
plot_images ( generated_images , "Fine-tuned on the Pokemon dataset" )您可以带上weights_path (应该与diffusion_model兼容)并重复使用代码片段。
查看此COLAB笔记本以播放推理代码。
最初,我们以256x256的分辨率微调了该模型。以下是一些结果以及与原始模型的结果进行比较。
| 图像 | 提示 |
|---|---|
| 尤达 | |
| 机器人猫与翅膀 | |
| Hello Kitty |
我们可以看到,微调模型比原始模型具有更稳定的输出。即使在美学上可以改善结果,但可以看到微调效果。另外,我们从拥抱面孔的脚本进行256x256分辨率(除了时期和批处理大小)中遵循了相同的超参数。使用更好的超参数,结果可能会有所改善。
对于512x512分辨率,我们观察到类似的东西。因此,我们尝试了unconditional_guidance_scale参数,并注意到将其设置为40(在确定其他参数时)时,结果会更好。
| 图像 | 提示 |
|---|---|
| 尤达 | |
| 机器人猫与翅膀 | |
| Hello Kitty |
注意:在512x512上进行微调仍在进行中。但是,在没有分布式训练和梯度积累的情况下,完成一个时代需要大量时间。以上结果来自60个时期后得出的检查点。
Lambda Labs有了类似的食谱(但经过培训以获得更优化的步骤),展示了令人惊叹的结果。