O Torchvision Model Zoo fornece número de implementações de várias arquiteturas de ponta, no entanto, a maioria delas é definida e implementada para o ImageNet. Geralmente, é simples usar os modelos fornecidos em outros conjuntos de dados, mas alguns casos exigem configuração manual.
Por exemplo, muito poucos repositórios de Pytorch com Resnets no CIFAR10 fornecem a implementação conforme descrito no artigo original. Se você apenas usar os modelos da Torchvision no CIFAR10, obterá o modelo que difere em número de camadas e parâmetros . Isso é inaceitável se você deseja comparar diretamente o Resnet-S no CIFAR10 com o papel original. O objetivo deste repositório é fornecer uma implementação válida de Pytorch de Resnet-S para CIFAR10, conforme descrito no artigo original. Os seguintes modelos são fornecidos:
| Nome | # camadas | # params | Teste Err (papel) | Teste err (este impl.) |
|---|---|---|---|---|
| Resnet20 | 20 | 0,27m | 8,75% | 8,27% |
| Resnet32 | 32 | 0,46m | 7,51% | 7,37% |
| Resnet44 | 44 | 0,66m | 7,17% | 6,90% |
| Resnet56 | 56 | 0,85m | 6,97% | 6,61% |
| Resnet110 | 110 | 1,7m | 6,43% | 6,32% |
| Resnet1202 | 1202 | 19,4m | 7,93% | 6,18% |
Esta implementação corresponde à descrição do artigo original, com um erro de teste comparável ou melhor.
git clone https://github.com/akamaster/pytorch_resnet_cifar10
cd pytorch_resnet_cifar10
chmod +x run.sh && ./run.shNossa implementação segue o artigo de maneira direta com algumas advertências: primeiro , o treinamento no artigo usa divisão de 45k/5k de trem/validação nos dados do trem e seleciona o modelo de melhor desempenho com base no desempenho no conjunto de validação. Não realizamos testes de validação; Se você precisar comparar seus resultados no resnet frente a frente com o papel orginal, lembre-se disso. Segundo , se você deseja treinar Resnet1202, lembre -se de que precisa de memória de 16 GB na GPU.
Se você achar essa implementação útil e deseja citar/mencionar esta página, aqui está uma citação Bibtex:
@misc { Idelbayev18a ,
author = " Yerlan Idelbayev " ,
title = " Proper {ResNet} Implementation for {CIFAR10/CIFAR100} in {PyTorch} " ,
howpublished = " url{https://github.com/akamaster/pytorch_resnet_cifar10} " ,
note = " Accessed: 20xx-xx-xx "
}