Pourquoi Functorch? | Guide d'installation | Transformations | Documentation | Plans futurs
Cette bibliothèque est actuellement en cours de développement intense - si vous avez des suggestions sur l'API ou des cas d'utilisation que vous souhaitez être couverts, veuillez ouvrir un problème GitHub ou contacter. Nous aimerions savoir comment vous utilisez la bibliothèque.
functorch est des transformations composables de type JAX pour Pytorch.
Il vise à fournir des transformations composables vmap et grad qui fonctionnent avec des modules Pytorch et Pytorch Autograd avec de bonnes performances en mode désireuses.
De plus, il existe des fonctionnalités expérimentales pour retrouver ces transformations en utilisant FX afin de capturer les résultats de ces transformations à l'avance. Cela nous permettrait de compiler les résultats de la VMAP ou du gradation pour améliorer les performances.
Il y a un certain nombre de cas d'utilisation qui sont difficiles à faire à Pytorch aujourd'hui:
La composition des transformations vmap , grad , vjp et jvp nous permet d'exprimer ce qui précède sans concevoir un sous-système distinct pour chacun. Cette idée des transformations de fonction composables provient du framework JAX.
Il existe deux façons d'installer Functorch:
Nous vous recommandons d'abord d'essayer le Functorch Beta.
Suivez les instructions dans ce cahier Colab
Au 21/09/2022, functorch est installé aux côtés d'un binaire pytorch nocturne. Veuillez installer un prévisualisation (Nightly) Pytorch Binary; Voir https://pytorch.org/ pour les instructions.
Une fois que vous avez fait cela, effectuez une vérification rapide de la santé mentale dans Python:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ()) Au 21/09/2022, functorch est installé aux côtés de Pytorch et se trouve dans l'arbre source de Pytorch. Veuillez installer Pytorch à partir de la source, puis, vous pourrez import functorch .
Essayez d'exécuter quelques tests pour vous assurer que tout va bien:
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -vAotautograd a des exigences facultatives supplémentaires. Vous pouvez les installer via:
pip install networkx Pour exécuter des tests Functorch, veuillez installer nos dépendances de test ( expecttest , pyyaml ).
Suivez les instructions ici
Préalable: installer pytorch
pip install functorchEnfin, effectuez un chèque de santé mentale rapide dans Python:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ())À l'heure actuelle, nous soutenons les transformations suivantes:
grad , vjp , jvp ,jacrev , jacfwd , hessianvmapDe plus, nous avons des services publics pour travailler avec des modules Pytorch.
make_functional(model)make_functional_with_buffers(model) Remarque: vmap impose des restrictions sur le code sur lequel il peut être utilisé. Pour plus de détails, veuillez lire son docstring.
vmap(func)(*inputs) est une transformation qui ajoute une dimension à toutes les opérations de tenseur dans func . vmap(func) renvoie une nouvelle fonction qui mappe func sur une dimension (par défaut: 0) de chaque tenseur dans inputs .
vmap est utile pour cacher les dimensions des lots: on peut écrire une fonction func qui fonctionne sur des exemples, puis la soulever vers une fonction qui peut prendre des lots d'exemples avec vmap(func) , conduisant à une expérience de modélisation plus simple:
from functorch import vmap
batch_size , feature_size = 3 , 5
weights = torch . randn ( feature_size , requires_grad = True )
def model ( feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
examples = torch . randn ( batch_size , feature_size )
result = vmap ( model )( examples ) grad(func)(*inputs) Supposons que func renvoie un tenseur à élément unique. Il calcule les gradients de la sortie de Func WRT aux inputs[0] .
from functorch import grad
x = torch . randn ([])
cos_x = grad ( lambda x : torch . sin ( x ))( x )
assert torch . allclose ( cos_x , x . cos ())
# Second-order gradients
neg_sin_x = grad ( grad ( lambda x : torch . sin ( x )))( x )
assert torch . allclose ( neg_sin_x , - x . sin ()) Lorsqu'il est composé avec vmap , grad peut être utilisé pour calculer les gradiements par échantillon:
from functorch import vmap
batch_size , feature_size = 3 , 5
def model ( weights , feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
def compute_loss ( weights , example , target ):
y = model ( weights , example )
return (( y - target ) ** 2 ). mean () # MSELoss
weights = torch . randn ( feature_size , requires_grad = True )
examples = torch . randn ( batch_size , feature_size )
targets = torch . randn ( batch_size )
inputs = ( weights , examples , targets )
grad_weight_per_example = vmap ( grad ( compute_loss ), in_dims = ( None , 0 , 0 ))( * inputs ) La transformation vjp applique func aux inputs et renvoie une nouvelle fonction qui calcule les VJP avec certains tenseurs cotangents .
from functorch import vjp
outputs , vjp_fn = vjp ( func , inputs ); vjps = vjp_fn ( * cotangents ) Le jvp transforme les produits jacobiens-vector-vector et est également connu sous le nom de "AD en mode avant". Ce n'est pas une fonction d'ordre supérieur contrairement à la plupart des autres transformations, mais il renvoie les sorties de func(inputs) ainsi que les jvp .
from functorch import jvp
x = torch . randn ( 5 )
y = torch . randn ( 5 )
f = lambda x , y : ( x * y )
_ , output = jvp ( f , ( x , y ), ( torch . ones ( 5 ), torch . ones ( 5 )))
assert torch . allclose ( output , x + y ) La transformée jacrev renvoie une nouvelle fonction qui prend x et renvoie le jacobien de torch.sin par rapport à x en utilisant une annonce en mode inverse.
from functorch import jacrev
x = torch . randn ( 5 )
jacobian = jacrev ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) Utilisez jacrev pour calculer le jacobien. Cela peut être composé avec VMAP pour produire des Jacobians lotés:
x = torch . randn ( 64 , 5 )
jacobian = vmap ( jacrev ( torch . sin ))( x )
assert jacobian . shape == ( 64 , 5 , 5 ) jacfwd est un remplacement sans rendez-vous pour jacrev qui calcule les Jacobiens en utilisant la publicité en mode avant:
from functorch import jacfwd
x = torch . randn ( 5 )
jacobian = jacfwd ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) Composer jacrev avec lui-même ou jacfwd peut produire des Hessiens:
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hessian0 = jacrev ( jacrev ( f ))( x )
hessian1 = jacfwd ( jacrev ( f ))( x ) Le hessian est une fonction de commodité qui combine jacfwd et jacrev :
from functorch import hessian
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hess = hessian ( f )( x ) Nous pouvons également retracer ces transformations afin de capturer les résultats sous forme de nouveau code à l'aide de make_fx . Il existe également une intégration expérimentale avec le compilateur NNC (ne fonctionne que sur CPU pour l'instant!).
from functorch import make_fx , grad
def f ( x ):
return torch . sin ( x ). sum ()
x = torch . randn ( 100 )
grad_f = make_fx ( grad ( f ))( x )
print ( grad_f . code )
def forward ( self , x_1 ):
sin = torch . ops . aten . sin ( x_1 )
sum_1 = torch . ops . aten . sum ( sin , None ); sin = None
cos = torch . ops . aten . cos ( x_1 ); x_1 = None
_tensor_constant0 = self . _tensor_constant0
mul = torch . ops . aten . mul ( _tensor_constant0 , cos ); _tensor_constant0 = cos = None
return mulParfois, vous voudrez peut-être effectuer une transformation par rapport aux paramètres et / ou aux tampons d'un module nn. Cela peut arriver par exemple dans:
Notre solution à cela en ce moment est une API qui, étant donné un nn.module, en crée une version apatride qui peut être appelée comme une fonction.
make_functional(model) renvoie une version fonctionnelle du model et du model.parameters()make_functional_with_buffers(model) Renvoie une version fonctionnelle du model et du model.parameters() et model.buffers() .Voici un exemple où nous calculons les gradiements par échantillon à l'aide d'une couche nn.linéaire:
import torch
from functorch import make_functional , vmap , grad
model = torch . nn . Linear ( 3 , 3 )
data = torch . randn ( 64 , 3 )
targets = torch . randn ( 64 , 3 )
func_model , params = make_functional ( model )
def compute_loss ( params , data , targets ):
preds = func_model ( params , data )
return torch . mean (( preds - targets ) ** 2 )
per_sample_grads = vmap ( grad ( compute_loss ), ( None , 0 , 0 ))( params , data , targets ) Si vous créez un ensemble de modèles, vous pouvez trouver combine_state_for_ensemble utile.
Pour plus de documentation, consultez notre site Web DOCS.
torch._C._functorch.dump_tensor : vide les touches de répartition sur la pile torch._C._functorch._set_vmap_fallback_warning_enabled(False) Si le spam d'avertissement VMAP vous dérange.
Dans l'état final, nous aimerions en amont cela dans Pytorch une fois que nous avons repoussé les détails de conception. Pour déterminer les détails, nous avons besoin de votre aide - veuillez nous envoyer vos cas d'utilisation en commençant une conversation dans le tracker du numéro ou en essayant notre projet.
Functorch a une licence de style BSD, comme le trouve le fichier de licence.
Si vous utilisez Functorch dans votre publication, veuillez le citer en utilisant l'entrée Bibtex suivante.
@Misc { functorch2021 ,
author = { Horace He, Richard Zou } ,
title = { functorch: JAX-like composable function transforms for PyTorch } ,
howpublished = { url{https://github.com/pytorch/functorch} } ,
year = { 2021 }
}