Arxiv | Bibtex
Desenvolvemos uma nova abordagem para a imagem que faz um trabalho melhor em reproduzir regiões cheias exibindo detalhes finos inspirados em nossa compreensão de como os artistas funcionam: linhas primeiro, colorir a seguir . Propomos um modelo adversário de dois estágios EdGeconnect que compreende um gerador de borda seguido por uma rede de conclusão de imagem. O gerador de borda alucina as bordas da região ausente (regular e irregular) da imagem, e a rede de conclusão da imagem preenche as regiões ausentes usando bordas alucinadas como a priori. A descrição detalhada do sistema pode ser encontrada em nosso artigo.

git clone https://github.com/knazeri/edge-connect.git
cd edge-connectpip install -r requirements.txtUtilizamos conjuntos de dados de visualização de rua de Ploces2, Celeba e Paris. Para treinar um modelo no conjunto de dados completo, faça o download dos conjuntos de dados de sites oficiais.
Após o download, execute scripts/flist.py para gerar listas de arquivos de trem, teste e validação. Por exemplo, para gerar a lista de arquivos do conjunto de treinamento no conjunto de dados do Ploces2 Run:
mkdir datasets
python ./scripts/flist.py --path path_to_places2_train_set --output ./datasets/places_train.flistNosso modelo é treinado no conjunto de dados de máscara irregular fornecida por Liu et al. Você pode baixar o conjunto de dados de máscara irregular disponível publicamente em seu site.
Como alternativa, você pode baixar o conjunto de dados de máscara irregular de desenho rápido de Karim Iskakov, que é uma combinação de 50 milhões de tacadas desenhadas pela mão humana.
Use scripts/flist.py para gerar listas de arquivos de máscaras de trem, teste e validação, conforme explicado acima.
Faça o download dos modelos pré-treinados usando os seguintes links e copie-os no diretório ./checkpoints .
Places2 | Celeba | Paris-Streetview
Como alternativa, você pode executar o seguinte script para baixar automaticamente os modelos pré-treinados:
bash ./scripts/download_model.sh Para treinar o modelo, crie um arquivo config.yaml semelhante ao arquivo de configuração de exemplo e copie -o no seu diretório de pontos de verificação. Leia o guia de configuração para obter mais informações sobre a configuração do modelo.
O edgeconnect é treinado em três estágios: 1) Treinando o modelo de borda, 2) Treinando o modelo de pintura e 3) treinando o modelo conjunto. Para treinar o modelo:
python train.py --model [stage] --checkpoints [path to checkpoints] Por exemplo, para treinar o modelo de borda no conjunto de dados do Place2 em ./checkpoints/places2 diretório:
python train.py --model 1 --checkpoints ./checkpoints/places2 A convergência do modelo difere do conjunto de dados para o conjunto de dados. Por exemplo, o conjunto de dados Places2 converge em uma das duas épocas, enquanto conjuntos de dados menores como o Celeba exigem quase 40 épocas para convergir. Você pode definir o número de iterações de treinamento alterando o valor MAX_ITERS no arquivo de configuração.
Para testar o modelo, crie um arquivo config.yaml semelhante ao arquivo de configuração de exemplo e copie -o no seu diretório de pontos de verificação. Leia o guia de configuração para obter mais informações sobre a configuração do modelo.
Você pode testar o modelo em todos os três estágios: 1) Modelo de borda, 2) modelo de pintura e 3) modelo de articulação. Em cada caso, você precisa fornecer uma imagem de entrada (imagem com uma máscara) e um arquivo de máscara em escala de cinza. Certifique -se de que o arquivo de máscara cubra toda a região da máscara na imagem de entrada. Para testar o modelo:
python test.py
--model [stage]
--checkpoints [path to checkpoints]
--input [path to input directory or file]
--mask [path to masks directory or mask file]
--output [path to the output directory] Fornecemos alguns exemplos de teste no diretório ./examples . Faça o download dos modelos pré-treinados e execute:
python test.py
--checkpoints ./checkpoints/places2
--input ./examples/places2/images
--mask ./examples/places2/masks
--output ./checkpoints/results Este script incluirá todas as imagens em ./examples/places2/images usando suas máscaras correspondentes no diretório ./examples/places2/mask e salva os resultados no diretório ./checkpoints/results . Por padrão, o script test.py é executado no estágio 3 ( --model=3 ).
Para avaliar o modelo, você precisa primeiro executar o modelo no modo de teste no seu conjunto de validação e salvar os resultados no disco. Fornecemos um utilitário ./scripts/metrics.py para avaliar o modelo usando o PSNR, SSIM e Erro A absoluto médio:
python ./scripts/metrics.py --data-path [path to validation set] --output-path [path to model output] Para medir a distância de Frécchet Inception (pontuação fid) ./scripts/fid_score.py . Utilizamos a implementação de Pytorch do FID daqui, que usa os pesos pré -tenhados do modelo de início de Pytorch.
python ./scripts/fid_score.py --path [path to validation, path to model output] --gpu [GPU id to use] Por padrão, usamos o detector de borda astutas para extrair informações de borda das imagens de entrada. Se você deseja treinar o modelo com uma detecção externa de arestas (detecção de borda aninhada holisticamente, por exemplo), é necessário gerar mapas de borda para todos os conjuntos de treinamento/teste como pré-processamento e suas listas de arquivos correspondentes usando scripts/flist.py conforme explicado acima. Verifique se os nomes dos arquivos e a estrutura do diretório correspondem aos seus conjuntos de treinamento/teste. Você pode mudar para a detecção de borda externa especificando EDGE=2 no arquivo de configuração.
A configuração do modelo é armazenada em um arquivo config.yaml no seu diretório de pontos de verificação. As tabelas a seguir fornecem a documentação para todas as opções disponíveis no arquivo de configuração:
| Opção | Descrição |
|---|---|
| MODO | 1: trem, 2: teste, 3: avaliação |
| MODELO | 1: Modelo de borda, 2: Modelo de Paint, 3: Modelo de Inchanha de Edge, 4: Modelo de Junto |
| MÁSCARA | 1: bloco aleatório, 2: metade, 3: externo, 4: externo + bloco aleatório, 5: externo + bloco aleatório + metade |
| BORDA | 1: Canny, 2: Externo |
| NMS | 0: sem suposição não max, 1: não-supressão de max nas bordas externas |
| SEMENTE | semente de gerador de números aleatórios |
| GPU | Lista de IDs de GPU, lista separada por vírgula, por exemplo, [0,1] |
| DEPURAR | 0: Sem depuração, 1: Modo de depuração |
| Detalhado | 0: sem verbose, 1: Saída estatística detalhada no console de saída |
| Opção | Descrição |
|---|---|
| TRIN_FLIST | Arquivo de texto contendo Lista de arquivos do conjunto de treinamento |
| Val_flist | Arquivo de texto que contém a lista de arquivos do conjunto de validação |
| Test_flist | Arquivo de texto que contém a lista de arquivos do conjunto de testes |
| TRIN_EDDE_FLIST | Arquivo de texto que contém treinamento Conjunto de arquivos externos Lista de arquivos (apenas com Edge = 2) |
| Val_edge_flist | Arquivo de texto contendo validação Defina a lista de arquivos externos (somente com Edge = 2) |
| Test_edge_flist | Arquivo de texto que contém o conjunto de testes Lista de arquivos externos (somente com Edge = 2) |
| TRIN_MASK_FLIST | Arquivo de texto que contém treinamento conjunto de arquivos de máscaras (apenas com máscara = 3, 4, 5) |
| Val_mask_flist | Arquivo de texto contendo validação Conjunto de arquivos de máscaras (somente com máscara = 3, 4, 5) |
| Test_mask_flist | Arquivo de texto que contém o conjunto de máscaras Lista de arquivos de máscaras (apenas com máscara = 3, 4, 5) |
| Opção | Padrão | Descrição |
|---|---|---|
| Lr | 0,0001 | taxa de aprendizado |
| D2g_lr | 0.1 | Taxa de aprendizado discriminador/gerador |
| Beta1 | 0,0 | Adam Optimizer beta1 |
| Beta2 | 0,9 | Adam Optimizer beta2 |
| Batch_size | 8 | Tamanho do lote de entrada |
| Input_size | 256 | Tamanho da imagem de entrada para treinamento. (0 para o tamanho original) |
| Sigma | 2 | Desvio padrão do filtro gaussiano usado no detector de arestas chuteiras (0: Random, -1: sem vantagem) |
| Max_iters | 2e6 | Número máximo de iterações para treinar o modelo |
| Edge_threshold | 0,5 | limiar de detecção de borda (0-1) |
| L1_loss_weight | 1 | L1 Perda de peso |
| Fm_loss_weight | 10 | Peso da perda de correspondência de recursos |
| Style_loss_weight | 1 | peso de perda de estilo |
| Content_loss_weight | 1 | Peso da perda perceptiva |
| Inpaint_adv_loss_weight | 0,01 | Peso da perda adversária |
| Gan_loss | Nsgan | NSGAN : GaN não saturador, Lsgan : Melhores quadrados Gan, dobradiça : Gane de perda de dobradiça |
| Gan_pool_size | 0 | Imagens falsas Tamanho do pool |
| Save_interval | 1000 | quantas iterações esperar antes de salvar o modelo (0: nunca) |
| Eval_Interval | 0 | quantas iterações esperar antes de avaliar o modelo (0: nunca) |
| Log_interval | 10 | Quantas iterações esperar antes da perda de treinamento (0: nunca) |
| Sample_interval | 1000 | quantas iterações esperar antes de salvar a amostra (0: nunca) |
| Sample_size | 12 | Número de imagens para amostrar em cada intervalo de amostragem |
Licenciado sob um Creative Commons Attribution-NonCommercial 4.0 International.
Exceto quando observado de outra forma, esse conteúdo é publicado sob uma licença CC BY-NC, o que significa que você pode copiar, remixar, transformar e construir o conteúdo, desde que não use o material para fins comerciais e forneça crédito apropriado e forneça um link para a licença.
Se você usar este código para sua pesquisa, cite nossos trabalhos EdgEConnect: Imagem generativa que inclui com o aprendizado de borda adversário ou edgeconnect: estrutura de imagem guiada por imagem usando previsão de borda:
@inproceedings{nazeri2019edgeconnect,
title={EdgeConnect: Generative Image Inpainting with Adversarial Edge Learning},
author={Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
journal={arXiv preprint},
year={2019},
}
@InProceedings{Nazeri_2019_ICCV,
title = {EdgeConnect: Structure Guided Image Inpainting using Edge Prediction},
author = {Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
booktitle = {The IEEE International Conference on Computer Vision (ICCV) Workshops},
month = {Oct},
year = {2019}
}