종이 | 프로젝트
이것은 Pytorch 의 Palette : Image-to-Image 확산 모델 의 비공식적 구현이며, 주로 초고상 버전 Image-Super-Super-Super-via-etertative refinement에서 상속됩니다. 코드 템플릿은 다른 시드 프로젝트 인 Distributed-Pytorch-Template에서 나온 것입니다.
종이 설명과 관련된 몇 가지 구현 세부 사항이 있습니다.
Guided-Diffusion 에 사용 된 U-Net 아키텍처를 적응하여 샘플 품질을 상당히 향상시킵니다.DDPM 과 같은 저해상도 기능 (16 × 16)에서주의 메커니즘을 사용했습니다.Palette 에서 아핀 변환으로 포함시킵니다.Palette 에 설명 된 바와 같이 추론 중에 상수. 순서대로 다음 작업을 끝내려고 노력합니다.
시간 부족과 GPU 자원으로 인해 후속 실험은 불확실합니다.
DDPM 모델에는 중요한 계산 리소스가 필요 하며이 백서의 아이디어를 검증하기 위해 몇 가지 예제 모델 만 구축했습니다.
200 개의 에포크 및 930k 반복과 중심 마스크 및 불규칙한 마스크의 첫 100 개의 샘플이 포함 된 결과.
![]() | ![]() |
|---|
16 개의 에포크와 660k 반복과 중앙 마스크에서 몇 가지 선택된 샘플을 사용한 결과.
![]() | ![]() | ![]() | ![]() |
|---|---|---|---|
![]() | ![]() | ![]() | ![]() |
8 개의 에포크와 330k 반복의 결과와 몇몇 선택된 샘플은 끊임없이 샘플을 선택했습니다.
![]() | ![]() |
|---|
| 작업 | 데이터 세트 | EMA | 버팀대(-) | IS (+) |
|---|---|---|---|---|
| 센터링 마스크를 입력합니다 | Celeba-hq | 거짓 | 5.7873 | 3.0705 |
| 불규칙한 마스크로 입학 | Celeba-hq | 거짓 | 5.4026 | 3.1221 |
pip install - r requirements . txt| 데이터 세트 | 일 | 반복 | gpus × 일 × bs | URL |
|---|---|---|---|---|
| Celeba-hq | 입학 | 930K | 2 × 5 × 3 | 구글 드라이브 |
| 장소 2 | 입학 | 660K | 4 × 8 × 10 | 구글 드라이브 |
BS는 GPU 당 샘플 크기를 나타냅니다.
공식 버전과 약간 다를 수있는 Kaggle에서 대부분의 사람들을 얻을 수 있으며 공식 웹 사이트에서 다운로드 할 수도 있습니다.
우리는 교육 및 평가를 위해이 데이터 세트의 기본 부서를 사용합니다. 우리가 사용하는 파일 목록은 Celeba-HQ, Places2에서 찾을 수 있습니다.
자체 데이터를 준비한 후에는 데이터를 가리 키도록 해당 구성 파일을 수정해야합니다. 다음을 예로 들어보십시오.
" which_dataset " : { // import designated dataset using arguments
" name " : ["data.dataset", "InpaintDataset"], // import Dataset() class
" args " :{ // arguments to initialize dataset
" data_root " : " your data path " ,
" data_len " : -1,
" mask_mode " : " hybrid "
}
}, DataLoader 및 유효성 검사 분할 에 대한 더 많은 선택은 datasets 구성 파일의 일부에서 찾을 수 있습니다.
resume_state 의 구성 파일을 이전 체크 포인트의 디렉토리로 설정하십시오. 다음을 예로 들어보십시오.이 디렉토리에는 교육 상태와 저장된 모델이 포함되어 있습니다. " path " : { //set every part file path
" resume_state " : "experiments/inpainting_celebahq_220426_150122/checkpoint/100"
},model.py 의 load_everything 함수로 네트워크 레이블을 설정하십시오. 기본값은 네트워크 입니다. 튜토리얼 설정을 따르십시오. 최적화 및 모델은 각각 100.state 및 100_network.pth에서로드됩니다. netG_label = self . netG . __class__ . __name__
self . load_network ( network = self . netG , network_label = netG_label , strict = False ) python run . py - p train - c config / inpainting_celebahq . json 우리는 SR3 및 Guided Diffusion 에 사용 된 U-Net 백본을 테스트하고 Guided Diffusion 1은 현재 실험에서보다 강력한 성능을 가지고 있습니다. 백본 , 손실 및 메트릭 에 대한 더 많은 선택은 which_networks 구성 파일의 일부에서 찾을 수 있습니다.
python run . py - p test - c config / inpainting_celebahq . json지상 진실 이미지와 샘플 이미지를 저장하는 두 개의 폴더를 만들고 파일 이름은 서로 대응해야합니다.
스크립트 실행 :
python eval . py - s [ ground image path ] - d [ sample image path ]우리의 작업은 다음과 같은 이론적 인 작품을 기반으로합니다.
그리고 우리는 다음 프로젝트에서 많은 혜택을 받고 있습니다.