Accueillir! Pour les nouveaux projets, je recommande maintenant fortement d'utiliser mon nouveau projet JaxTyping à la place. Il prend en charge Pytorch, ne dépend pas réellement de Jax, et contrairement à la torche, il est compatible avec les vérificateurs de type statique. :)
Tournez ceci:
def batch_outer_product ( x : torch . Tensor , y : torch . Tensor ) -> torch . Tensor :
# x has shape (batch, x_channels)
# y has shape (batch, y_channels)
# return has shape (batch, x_channels, y_channels)
return x . unsqueeze ( - 1 ) * y . unsqueeze ( - 2 )dans ceci:
def batch_outer_product ( x : TensorType [ "batch" , "x_channels" ],
y : TensorType [ "batch" , "y_channels" ]
) -> TensorType [ "batch" , "x_channels" , "y_channels" ]:
return x . unsqueeze ( - 1 ) * y . unsqueeze ( - 2 )avec programmatique vérifiant que la spécification de forme (dtype, ...) est respectée.
Bugs byes bogs! Dites bonjour à la documentation claire et claire de votre code.
Si (comme moi) vous vous retrouvez à joncher votre code avec des commentaires comme # x has shape (batch, hidden_state) ou des instructions comme assert x.shape == y.shape , juste pour garder une trace de la forme que tout est, alors c'est pour vous.
pip install torchtypingNécessite Python> = 3,7 et Pytorch> = 1,7.0.
Si vous utilisez typeguard , ce doit être une version <3.0.0.
torchtyping permet l'annotation de type:
... ;torchtyping sont très extensibles. Si typeguard est (éventuellement) installé, alors au moment de l'exécution, les types peuvent être vérifiés pour s'assurer que les tenseurs sont vraiment de la forme annoncée, DType, etc.
# EXAMPLE
from torch import rand
from torchtyping import TensorType , patch_typeguard
from typeguard import typechecked
patch_typeguard () # use before @typechecked
@ typechecked
def func ( x : TensorType [ "batch" ],
y : TensorType [ "batch" ]) -> TensorType [ "batch" ]:
return x + y
func ( rand ( 3 ), rand ( 3 )) # works
func ( rand ( 3 ), rand ( 1 ))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3. typeguard a également un crochet d'importation qui peut être utilisé pour tester automatiquement un module entier, sans avoir besoin d'ajouter manuellement @typeguard.typechecked Decorators.
Si vous n'utilisez pas typeguard , torchtyping.patch_typeguard() peut être omis complètement, et torchtyping vient d'être utilisé à des fins de documentation. Si vous n'utilisez pas déjà typeguard pour votre programmation Python ordinaire, envisagez fortement de l'utiliser. C'est un excellent moyen d'écraser les insectes. typeguard et torchtyping s'intègrent également à pytest , donc si vous êtes préoccupé par une pénalité de performance, ils peuvent être activés lors des tests uniquement.
torchtyping . TensorType [ shape , dtype , layout , details ]Le cœur de la bibliothèque.
Chacun de shape , dtype , layout , details sont facultatifs.
shape peut être l'un des:int : La dimension doit être exactement de cette taille. S'il est -1 , toute taille est autorisée.str : La taille de la dimension passée au moment de l'exécution sera liée à ce nom, et tous les tenseurs ont vérifié que les tailles sont cohérentes.... : Un nombre arbitraire de dimensions de toutes tailles.str: int (techniquement c'est une tranche), combinant le comportement str et int . (Juste un str seul équivaut à str: -1 .)str: str , auquel cas la taille de la dimension passée à l'exécution sera liée aux deux noms, et toutes les dimensions avec l'un ou l'autre nom doivent avoir la même taille. (Certaines personnes aiment l'utiliser comme un moyen d'associer plusieurs noms à une dimension, à des fins de documentation supplémentaires.)str: ... auquel cas les multiples dimensions correspondant à ... seront liées au nom spécifié par str , et à nouveau vérifié la cohérence entre les arguments.None , qui, lorsqu'il est utilisé en conjonction avec is_named , indique une dimension qui ne doit pas avoir de nom au sens des tenseurs nommés.None: int , combinant à la fois None et un comportement int . (Juste un None n'est équivalent à None: -1 .)None: str Pair, combinant le comportement None et str . (C'est-à-dire qu'il ne doit pas avoir de dimension nommée, mais doit être d'une taille cohérente avec d'autres utilisations de la chaîne.)-1 typing.Any .TensorType["batch": ..., "length": 10, "channels", -1] . Si vous souhaitez simplement spécifier le nombre de dimensions, utilisez par exemple TensorType[-1, -1, -1] pour un tenseur tridimensionnel.dtype peut être l'un des:torch.float32 , torch.float64 etc.int , bool , float , qui sont convertis en leurs types de pytorch correspondants. float est spécifiquement interprété comme torch.get_default_dtype() , qui est généralement float32 .layout peut être torch.strided ou torch.sparse_coo , pour les tenseurs denses et clairsemés respectivement.details offre un moyen de passer un nombre arbitraire de drapeaux supplémentaires qui personnalisent et étendent torchtyping . Deux drapeaux sont intégrés par défaut. torchtyping.is_named provoque la vérification des noms des dimensions du tenseur, et torchtyping.is_float peut être utilisé pour vérifier que les types de points flottants arbitraires sont passés. (Plutôt que de simplement une discussion spécifique comme pour personnaliser TensorType[torch.float32] .) Pour une discussion sur la façon de personnaliser torchtyping avec vos propres details , voir le document supplémentaire.[] . Par exemple, TensorType["batch": ..., "length", "channels", float, is_named] . torchtyping . patch_typeguard () torchtyping s'intègre à typeguard pour effectuer la vérification du type d'exécution. torchtyping.patch_typeguard() doit être appelé au niveau global, et patchera typeguard pour vérifier les TensorType .
Cette fonction est sûre à exécuter plusieurs fois. (Il ne fait rien après la première manche).
@typeguard.typechecked , alors torchtyping.patch_typeguard() doit être appelé à tout moment avant d'utiliser @typeguard.typechecked . Par exemple, vous pouvez l'appeler au début de chaque fichier en utilisant torchtyping .typeguard.importhook.install_import_hook , alors torchtyping.patch_typeguard() doit être appelé à tout moment avant de définir les fonctions que vous souhaitez vérifier. Par exemple, vous pouvez appeler torchtyping.patch_typeguard() une seule fois, en même temps que le crochet d'importation typeguard . (L'ordre du crochet et le patch n'ont pas d'importance.)typeguard , torchtyping.patch_typeguard() peut être omis complètement, et torchtyping vient d'être utilisé à des fins de documentation. pytest --torchtyping-patch-typeguard torchtyping propose un plugin pytest pour exécuter automatiquement torchtyping.patch_typeguard() avant vos tests. pytest découvrira automatiquement le plugin, il vous suffit de passer l'indicateur --torchtyping-patch-typeguard pour l'activer. Les packages peuvent ensuite être transmis à typeguard comme d'habitude, soit en utilisant @typeguard.typechecked , le crochet d'importation de typeguard , soit le pytest Flag --typeguard-packages="your_package_here" .
Voir la documentation supplémentaire pour:
flake8 et mypy Compatibilité;torchtyping ;