論文|プロジェクト
これは、 Pytorchによる画像間拡散モデルのパレットの非公式の実装であり、主にその超解像度バージョンの画像スーパー解像度-Via-oterative-Repitingから継承されています。コードテンプレートは、私の別のシードプロジェクトである分散-Pytorch-Templateのものです。
紙の説明には、いくつかの実装の詳細があります。
Guided-Diffusionに使用されるU-Netアーキテクチャを適応させました。これにより、サンプル品質が大幅に向上しました。DDPMのような低解像度の特徴(16×16)で注意メカニズムを使用しました。Paletteで、アフィン変換で埋め込みました。Paletteで説明されているように、推論中に定数に。 私は次のようなタスクに従うようにしています:
時間の不足とGPUリソースのために、フォローアップ実験は不確かです。
DDPMモデルには重要な計算リソースが必要であり、このペーパーのアイデアを検証するためのモデルの例をいくつか構築しました。
結果200エポックと930kの反復、およびセンタリングマスクと不規則なマスクの最初の100個のサンプルがあります。
![]() | ![]() |
|---|
結果16個のエポックと660kの反復、およびいくつかの選択されたサンプルがセンタリングマスクで選ばれました。
![]() | ![]() | ![]() | ![]() |
|---|---|---|---|
![]() | ![]() | ![]() | ![]() |
8つのエポックと330kの反復、およびいくつかの選択されたサンプルが登場した結果。
![]() | ![]() |
|---|
| タスク | データセット | エマ | fid( - ) | IS(+) |
|---|---|---|---|---|
| センターリングマスクでの開始 | Celeba-hq | 間違い | 5.7873 | 3.0705 |
| 不規則なマスクで開始します | Celeba-hq | 間違い | 5.4026 | 3.1221 |
pip install - r requirements . txt| データセット | タスク | 反復 | gpus×days×bs | URL |
|---|---|---|---|---|
| Celeba-hq | インパインティング | 930K | 2×5×3 | Googleドライブ |
| 場所2 | インパインティング | 660K | 4×8×10 | Googleドライブ |
BSは、GPUあたりのサンプルサイズを示します。
それらのほとんどはKaggleから入手できます。これは公式バージョンとはわずかに異なる場合があり、公式Webサイトからダウンロードすることもできます。
トレーニングと評価には、これらのデータセットのデフォルト分割を使用します。使用するファイルリストは、celeba-hq、places2にあります。
独自のデータを準備した後、対応する構成ファイルを変更してデータを指す必要があります。例として以下を取り上げてください。
" which_dataset " : { // import designated dataset using arguments
" name " : ["data.dataset", "InpaintDataset"], // import Dataset() class
" args " :{ // arguments to initialize dataset
" data_root " : " your data path " ,
" data_len " : -1,
" mask_mode " : " hybrid "
}
}, Dataloaderと検証の分割に関するより多くの選択肢は、構成ファイルのdatasetsの一部にもあります。
resume_stateの設定ファイルを以前のチェックポイントのディレクトリに設定します。例として、このディレクトリにはトレーニング状態と保存されたモデルが含まれています。 " path " : { //set every part file path
" resume_state " : "experiments/inpainting_celebahq_220426_150122/checkpoint/100"
},load_everything function of model.pyで設定します。デフォルトはネットワークです。チュートリアル設定に従って、オプティマイザーとモデルはそれぞれ100.Stateと100_Network.pthからロードされます。 netG_label = self . netG . __class__ . __name__
self . load_network ( network = self . netG , network_label = netG_label , strict = False ) python run . py - p train - c config / inpainting_celebahq . json SR3で使用されるU-NETバックボーンとGuided Diffusionをテストし、 Guided Diffusionは、現在の実験でより堅牢なパフォーマンスを持っています。バックボーン、損失、およびメトリックに関するより多くの選択肢は、 which_networks configureファイルの一部で見つけることができます。
python run . py - p test - c config / inpainting_celebahq . jsonグラウンドトゥルース画像とサンプル画像を保存する2つのフォルダーを作成すると、ファイル名が互いに対応する必要があります。
スクリプトを実行します:
python eval . py - s [ ground image path ] - d [ sample image path ]私たちの仕事は、次の理論的作品に基づいています。
そして、私たちは次のプロジェクトから多くの利益を得ています。