¿Por qué Functorch? | Guía de instalación | Transformaciones | Documentación | Planes futuros
Esta biblioteca se encuentra actualmente en un gran desarrollo: si tiene sugerencias sobre la API o los casos de uso que desea cubrir, abra un problema de GitHub o comuníquese. Nos encantaría saber cómo estás usando la biblioteca.
functorch es transformaciones de función compuesta como Jax para Pytorch.
Su objetivo es proporcionar transformaciones compuestas de vmap y grad que funcionen con módulos Pytorch y Pytorch Autograd con un buen rendimiento de modo ansioso.
Además, existe una funcionalidad experimental para rastrear estas transformaciones utilizando FX para capturar los resultados de estas transformaciones con anticipación. Esto nos permitiría compilar los resultados de VMAP o Grad para mejorar el rendimiento.
Hay una serie de casos de uso que son difíciles de hacer en Pytorch hoy:
La composición de transformaciones vmap , grad , vjp y jvp nos permite expresar lo anterior sin diseñar un subsistema separado para cada uno. Esta idea de transformaciones de funciones compuestas proviene del marco JAX.
Hay dos formas de instalar Functorch:
Recomendamos probar primero la beta Functorch.
Siga las instrucciones en este cuaderno de colab.
A partir del 21/09/2022, functorch viene instalado junto con un binario nocturno de Pytorch. Instale un binario Pytorch de vista previa (nocturna); Consulte https://pytorch.org/ para obtener instrucciones.
Una vez que hayas hecho eso, ejecute un check de cordura rápido en Python:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ()) A partir del 21/09/2022, functorch viene instalado junto a Pytorch y está en el árbol de origen de Pytorch. Instale Pytorch desde la fuente, entonces, podrá import functorch .
Intente ejecutar algunas pruebas para asegurarse de que todo esté bien:
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -vAotautograd tiene algunos requisitos opcionales adicionales. Puede instalarlos a través de:
pip install networkx Para ejecutar pruebas de Functorch, instale nuestras dependencias de prueba ( expecttest , pyyaml ).
Sigue las instrucciones aquí
Prerrequisito: instalar pytorch
pip install functorchFinalmente, ejecute un check de cordura rápido en Python:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ())En este momento, apoyamos las siguientes transformaciones:
grad , vjp , jvp ,jacrev , jacfwd , hessianvmapAdemás, tenemos algunas utilidades para trabajar con módulos Pytorch.
make_functional(model)make_functional_with_buffers(model) Nota: vmap impone restricciones en el código en el que se puede usar. Para más detalles, lea su documento.
vmap(func)(*inputs) es una transformación que agrega una dimensión a todas las operaciones tensoras en func . vmap(func) Devuelve una nueva función que mapea func sobre alguna dimensión (predeterminada: 0) de cada tensor en inputs .
vmap es útil para ocultar las dimensiones de lotes: uno puede escribir una función func que se ejecuta en ejemplos y luego llevarla a una función que puede tomar lotes de ejemplos con vmap(func) , lo que lleva a una experiencia de modelado más simple:
from functorch import vmap
batch_size , feature_size = 3 , 5
weights = torch . randn ( feature_size , requires_grad = True )
def model ( feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
examples = torch . randn ( batch_size , feature_size )
result = vmap ( model )( examples ) grad(func)(*inputs) Asume que func devuelve un tensor de elemento único. Calcula los gradientes de la salida de funct a inputs[0] .
from functorch import grad
x = torch . randn ([])
cos_x = grad ( lambda x : torch . sin ( x ))( x )
assert torch . allclose ( cos_x , x . cos ())
# Second-order gradients
neg_sin_x = grad ( grad ( lambda x : torch . sin ( x )))( x )
assert torch . allclose ( neg_sin_x , - x . sin ()) Cuando se compone con vmap , se puede utilizar grad para calcular los gradientes por muestra:
from functorch import vmap
batch_size , feature_size = 3 , 5
def model ( weights , feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
def compute_loss ( weights , example , target ):
y = model ( weights , example )
return (( y - target ) ** 2 ). mean () # MSELoss
weights = torch . randn ( feature_size , requires_grad = True )
examples = torch . randn ( batch_size , feature_size )
targets = torch . randn ( batch_size )
inputs = ( weights , examples , targets )
grad_weight_per_example = vmap ( grad ( compute_loss ), in_dims = ( None , 0 , 0 ))( * inputs ) La transformación vjp aplica func a inputs y devuelve una nueva función que calcula VJP dados algunos tensores cotangents .
from functorch import vjp
outputs , vjp_fn = vjp ( func , inputs ); vjps = vjp_fn ( * cotangents ) La transformación de jvp calcula los productos de vectores jacobianos y también se conoce como "AD de modo avanzado". No es una función de orden superior a diferencia de la mayoría de las otras transformaciones, pero devuelve las salidas de func(inputs) , así como los jvp S.
from functorch import jvp
x = torch . randn ( 5 )
y = torch . randn ( 5 )
f = lambda x , y : ( x * y )
_ , output = jvp ( f , ( x , y ), ( torch . ones ( 5 ), torch . ones ( 5 )))
assert torch . allclose ( output , x + y ) La transformación jacrev devuelve una nueva función que toma en x y devuelve el jacobiano de torch.sin con respecto a x usando AD de modo inverso.
from functorch import jacrev
x = torch . randn ( 5 )
jacobian = jacrev ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) Use jacrev para calcular el jacobiano. Esto puede componerse con VMAP para producir jacobianos por lotes:
x = torch . randn ( 64 , 5 )
jacobian = vmap ( jacrev ( torch . sin ))( x )
assert jacobian . shape == ( 64 , 5 , 5 ) jacfwd es un reemplazo de jacrev que calcula a los jacobianos utilizando AD de modo avanzado:
from functorch import jacfwd
x = torch . randn ( 5 )
jacobian = jacfwd ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) Componer jacrev consigo mismo o jacfwd puede producir hessians:
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hessian0 = jacrev ( jacrev ( f ))( x )
hessian1 = jacfwd ( jacrev ( f ))( x ) El hessian es una función de conveniencia que combina jacfwd y jacrev :
from functorch import hessian
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hess = hessian ( f )( x ) También podemos rastrear estas transformaciones para capturar los resultados como un nuevo código usando make_fx . También hay integración experimental con el compilador NNC (¡solo funciona en CPU por ahora!).
from functorch import make_fx , grad
def f ( x ):
return torch . sin ( x ). sum ()
x = torch . randn ( 100 )
grad_f = make_fx ( grad ( f ))( x )
print ( grad_f . code )
def forward ( self , x_1 ):
sin = torch . ops . aten . sin ( x_1 )
sum_1 = torch . ops . aten . sum ( sin , None ); sin = None
cos = torch . ops . aten . cos ( x_1 ); x_1 = None
_tensor_constant0 = self . _tensor_constant0
mul = torch . ops . aten . mul ( _tensor_constant0 , cos ); _tensor_constant0 = cos = None
return mulA veces es posible que desee realizar una transformación con respecto a los parámetros y/o buffers de un módulo nn. Esto puede suceder, por ejemplo, en:
Nuestra solución a esto en este momento es una API que, dada un módulo nn., crea una versión sin estado que puede llamarse como una función.
make_functional(model) Devuelve una versión funcional del model y el model.parameters()make_functional_with_buffers(model) Devuelve una versión funcional del model y el model.parameters() y model.buffers() .Aquí hay un ejemplo en el que calculamos los gradientes por muestra usando una capa lineal nn:
import torch
from functorch import make_functional , vmap , grad
model = torch . nn . Linear ( 3 , 3 )
data = torch . randn ( 64 , 3 )
targets = torch . randn ( 64 , 3 )
func_model , params = make_functional ( model )
def compute_loss ( params , data , targets ):
preds = func_model ( params , data )
return torch . mean (( preds - targets ) ** 2 )
per_sample_grads = vmap ( grad ( compute_loss ), ( None , 0 , 0 ))( params , data , targets ) Si está haciendo un conjunto de modelos, puede encontrar útil combine_state_for_ensemble .
Para obtener más documentación, consulte nuestro sitio web de documentos.
torch._C._functorch.dump_tensor : descarga las teclas de despacho en Stack torch._C._functorch._set_vmap_fallback_warning_enabled(False) Si el spam de advertencia VMAP lo molesta.
En el estado final, nos gustaría pasar por la transmisión en Pytorch una vez que solucionamos los detalles del diseño. Para determinar los detalles, necesitamos su ayuda; envíenos sus casos de uso comenzando una conversación en el rastreador de problemas o probando nuestro proyecto.
Functorch tiene una licencia de estilo BSD, como se encuentra en el archivo de licencia.
Si usa Functorch en su publicación, cíquelo utilizando la siguiente entrada de Bibtex.
@Misc { functorch2021 ,
author = { Horace He, Richard Zou } ,
title = { functorch: JAX-like composable function transforms for PyTorch } ,
howpublished = { url{https://github.com/pytorch/functorch} } ,
year = { 2021 }
}