Bem-vindo! Para novos projetos, agora recomendo fortemente o uso do meu projeto Jaxtyping mais recente. Ele suporta Pytorch, na verdade não depende do JAX e, diferentemente da Torchtyping, é compatível com verificadores do tipo estático. :)
Vire isso:
def batch_outer_product ( x : torch . Tensor , y : torch . Tensor ) -> torch . Tensor :
# x has shape (batch, x_channels)
# y has shape (batch, y_channels)
# return has shape (batch, x_channels, y_channels)
return x . unsqueeze ( - 1 ) * y . unsqueeze ( - 2 )nisso:
def batch_outer_product ( x : TensorType [ "batch" , "x_channels" ],
y : TensorType [ "batch" , "y_channels" ]
) -> TensorType [ "batch" , "x_channels" , "y_channels" ]:
return x . unsqueeze ( - 1 ) * y . unsqueeze ( - 2 )Com a verificação programática de que a forma (dtype, ...) é atendida.
Tchau bugs! Diga olá para documentação limpa e forçada do seu código.
Se (como eu) você se encontrar espalhando seu código com comentários como # x has shape (batch, hidden_state) ou declarações como assert x.shape == y.shape , apenas para acompanhar o que é tudo, então isso é para você.
pip install torchtypingRequer python> = 3.7 e pytorch> = 1.7.0.
Se estiver usando typeguard , deve ser uma versão <3.0.0.
torchtyping permite a anotação do tipo:
... ;torchtyping é altamente extensível. Se typeguard estiver (opcionalmente) instalado, em tempo de execução, os tipos poderão ser verificados para garantir que os tensores realmente sejam da forma anunciada, dtype, etc.
# EXAMPLE
from torch import rand
from torchtyping import TensorType , patch_typeguard
from typeguard import typechecked
patch_typeguard () # use before @typechecked
@ typechecked
def func ( x : TensorType [ "batch" ],
y : TensorType [ "batch" ]) -> TensorType [ "batch" ]:
return x + y
func ( rand ( 3 ), rand ( 3 )) # works
func ( rand ( 3 ), rand ( 1 ))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3. typeguard também possui um gancho de importação que pode ser usado para testar automaticamente um módulo inteiro, sem precisar adicionar manualmente @typeguard.typechecked Decorators.
Se você não estiver usando typeguard , torchtyping.patch_typeguard() pode ser omitida por completo, e torchtyping usada apenas para fins de documentação. Se você ainda não está usando typeguard para sua programação regular do Python, considere -o usá -lo fortemente. É uma ótima maneira de esmagar bugs. typeguard e torchtyping também se integram ao pytest ; portanto, se você estiver preocupado com qualquer penalidade de desempenho, eles poderão ser ativados apenas durante os testes.
torchtyping . TensorType [ shape , dtype , layout , details ]O núcleo da biblioteca.
Cada uma de shape , dtype , layout e details são opcionais.
shape pode ser de:int : a dimensão deve ser exatamente desse tamanho. Se for -1 , qualquer tamanho será permitido.str : O tamanho da dimensão aprovado em tempo de execução estará vinculado a esse nome, e todos os tensores verificaram se os tamanhos são consistentes.... : Um número arbitrário de dimensões de qualquer tamanho.str: int (tecnicamente é uma fatia), combinando o comportamento str e int . (Apenas um str por conta própria é equivalente a str: -1 .)str: str , caso em que o tamanho da dimensão passada em tempo de execução estará vinculado aos dois nomes, e todas as dimensões com qualquer um dos nomes devem ter o mesmo tamanho. (Algumas pessoas gostam de usar isso como uma maneira de associar vários nomes a uma dimensão, para fins extras de documentação.)str: ... par, nesse caso, as múltiplas dimensões correspondentes a ... estarão vinculadas ao nome especificado por str e novamente verificadas quanto à consistência entre os argumentos.None , que quando usado em conjunto com is_named abaixo, indica uma dimensão que não deve ter um nome no sentido de tensores nomeados.None: int , combinando o comportamento None e int . (Apenas um None por si só é equivalente a None: -1 .)None: str combinando o comportamento None e str . (Ou seja, não deve ter uma dimensão nomeada, mas deve ter um tamanho consistente com outros usos da string.)typing.Any : qualquer tamanho é permitido para esta dimensão (equivalente a -1 ).TensorType["batch": ..., "length": 10, "channels", -1] . Se você deseja apenas especificar o número de dimensões, use, por exemplo TensorType[-1, -1, -1] para um tensor tridimensional.dtype pode ser qualquer um dos:torch.float32 , torch.float64 etc.int , bool , float , que são convertidos em seus tipos de pytorch correspondentes. float é especificamente interpretado como torch.get_default_dtype() , que geralmente é float32 .layout pode ser torch.strided torch.sparse_coodetails oferece uma maneira de passar um número arbitrário de sinalizadores adicionais que personalizam e estendem torchtyping . Dois sinalizadores são embutidos por padrão. torchtyping.is_named faz com TensorType[torch.float32] os nomes das dimensões do tensor sejam verificados e torchtyping.is_float pode ser usada para verificar se os tipos arbitrários de details torchtyping são transmitidos.[] . Por exemplo TensorType["batch": ..., "length", "channels", float, is_named] . torchtyping . patch_typeguard () torchtyping se integra ao typeguard para executar a verificação do tipo de tempo de execução. torchtyping.patch_typeguard() deve ser chamado em nível global e typeguard para verificar TensorType s.
Esta função é segura para executar várias vezes. (Não faz nada após a primeira corrida).
@typeguard.typechecked , então torchtyping.patch_typeguard() deve ser chamado a qualquer momento antes de usar @typeguard.typechecked . Por exemplo, você pode chamá -lo no início de cada arquivo usando torchtyping .typeguard.importhook.install_import_hook , então torchtyping.patch_typeguard() deve ser chamado a qualquer momento antes de definir as funções que você deseja verificadas. Por exemplo, você pode ligar para torchtyping.patch_typeguard() apenas uma vez, ao mesmo tempo que o gancho de importação typeguard . (A ordem do gancho e o patch não importa.)typeguard , torchtyping.patch_typeguard() pode ser omitida por completo, e torchtyping usada apenas para fins de documentação. pytest --torchtyping-patch-typeguard torchtyping oferece um plug -in pytest para executar automaticamente torchtyping.patch_typeguard() antes dos testes. pytest descobrirá automaticamente o plug-in, você só precisa passar pelo --torchtyping-patch-typeguard para ativá-lo. Os pacotes podem ser transmitidos para typeguard normalmente, usando @typeguard.typechecked , o gancho de importação do typeguard ou o sinalizador pytest --typeguard-packages="your_package_here" .
Veja a documentação adicional para:
flake8 e mypy Compatibilidade;torchtyping ;