Почему Functorch? | Руководство по установке | Преобразования | Документация | Планы на будущее
Эта библиотека в настоящее время находится под тяжелой разработкой - если у вас есть предложения по API или вариантам использования, которые вы хотели бы покрыть, пожалуйста, откройте проблему GitHub или обратитесь. Мы хотели бы услышать, как вы используете библиотеку.
functorch -это JAX-подобные композиционные преобразования функции для Pytorch.
Он направлен на обеспечение композиционного vmap и grad преобразования, которые работают с модулями Pytorch и автоградом Pytorch с хорошей производительности с нетерпением.
Кроме того, существует экспериментальная функциональность, чтобы проследить эти преобразования с использованием FX, чтобы заранее захватить результаты этих преобразований. Это позволило бы нам составить результаты VMAP или GRAD для повышения производительности.
Есть ряд вариантов использования, которые сегодня сложно сделать в Pytorch:
Создание преобразований vmap , grad , vjp и jvp позволяет нам выражать вышеперечисленное, не создавая отдельную подсистему для каждой. Эта идея композиционных преобразований функции поступает из фреймворка JAX.
Есть два способа установить Functorch:
Мы рекомендуем сначала попробовать бета -версию Functorch.
Следуйте инструкциям в этой ноутбуке Colab
По состоянию на 21.09.2022 functorch установлена вместе с ночным двоичным из бинарного питания. Пожалуйста, установите предварительный просмотр (ночной) двоичный файл Pytorch; Смотрите https://pytorch.org/ для инструкций.
После того, как вы это сделаете, запустите быструю проверку здравомыслия в Python:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ()) По состоянию на 21.09.2022 functorch установлен вместе с Pytorch и находится в дереве источника Pytorch. Пожалуйста, установите Pytorch из Source, тогда вы сможете import functorch .
Попробуйте запустить некоторые тесты, чтобы убедиться, что все в порядке:
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -vAotautograd имеет некоторые дополнительные дополнительные требования. Вы можете установить их через:
pip install networkx Чтобы запустить тесты Functorch, пожалуйста, установите наши тестовые зависимости ( expecttest , pyyaml ).
Следуйте инструкциям здесь
Обязательное условие: установить Pytorch
pip install functorchНаконец, запустите быструю проверку здравомыслия в Python:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ())Прямо сейчас мы поддерживаем следующие преобразования:
grad , vjp , jvp ,jacrev , jacfwd , hessianvmapКроме того, у нас есть утилиты для работы с модулями Pytorch.
make_functional(model)make_functional_with_buffers(model) Примечание: vmap налагает ограничения на код, на котором он может быть использован. Для получения более подробной информации, пожалуйста, прочитайте его Docstring.
vmap(func)(*inputs) - это преобразование, которое добавляет измерение ко всем операциям по тензору в func . vmap(func) возвращает новую функцию, которая отображает func по некоторому измерению (по умолчанию: 0) каждого тензора на inputs .
vmap полезен для сокрытия размеров пакетов: можно написать функцию func , которая работает на примерах, а затем поднять его на функцию, которая может взять партии примеров с vmap(func) , что приводит к более простому моделированию:
from functorch import vmap
batch_size , feature_size = 3 , 5
weights = torch . randn ( feature_size , requires_grad = True )
def model ( feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
examples = torch . randn ( batch_size , feature_size )
result = vmap ( model )( examples ) grad(func)(*inputs) предполагает, что func возвращает тенсор с одним элементом. Он вычисляет градиенты выходного сигнала func wrt к inputs[0] .
from functorch import grad
x = torch . randn ([])
cos_x = grad ( lambda x : torch . sin ( x ))( x )
assert torch . allclose ( cos_x , x . cos ())
# Second-order gradients
neg_sin_x = grad ( grad ( lambda x : torch . sin ( x )))( x )
assert torch . allclose ( neg_sin_x , - x . sin ()) При составлении с vmap grad может использоваться для вычисления градиентов для выборки:
from functorch import vmap
batch_size , feature_size = 3 , 5
def model ( weights , feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
def compute_loss ( weights , example , target ):
y = model ( weights , example )
return (( y - target ) ** 2 ). mean () # MSELoss
weights = torch . randn ( feature_size , requires_grad = True )
examples = torch . randn ( batch_size , feature_size )
targets = torch . randn ( batch_size )
inputs = ( weights , examples , targets )
grad_weight_per_example = vmap ( grad ( compute_loss ), in_dims = ( None , 0 , 0 ))( * inputs ) Преобразование vjp применяет func к inputs и возвращает новую функцию, которая вычисляет VJP, данные некоторые тензоры cotangents .
from functorch import vjp
outputs , vjp_fn = vjp ( func , inputs ); vjps = vjp_fn ( * cotangents ) Преобразования jvp вычисляют Jacobian-Vector-Products, а также известен как «прямого режима AD». Это не функция высшего порядка, в отличие от большинства других преобразований, но возвращает выходы func(inputs) , а также jvp S.
from functorch import jvp
x = torch . randn ( 5 )
y = torch . randn ( 5 )
f = lambda x , y : ( x * y )
_ , output = jvp ( f , ( x , y ), ( torch . ones ( 5 ), torch . ones ( 5 )))
assert torch . allclose ( output , x + y ) Преобразование jacrev возвращает новую функцию, которая принимает x и возвращает Jacobian of torch.sin по отношению к x с использованием AD с обратным режимом.
from functorch import jacrev
x = torch . randn ( 5 )
jacobian = jacrev ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) Используйте jacrev , чтобы вычислить якобиан. Это может быть составлено с VMAP для производства пакетных якобианцев:
x = torch . randn ( 64 , 5 )
jacobian = vmap ( jacrev ( torch . sin ))( x )
assert jacobian . shape == ( 64 , 5 , 5 ) jacfwd -это замена для jacrev , которая вычисляет якобианцев с использованием AD прямого режима:
from functorch import jacfwd
x = torch . randn ( 5 )
jacobian = jacfwd ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) Создание jacrev с самим собой или jacfwd может производить гессесцев:
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hessian0 = jacrev ( jacrev ( f ))( x )
hessian1 = jacfwd ( jacrev ( f ))( x ) hessian - удобная функция, которая объединяет jacfwd и jacrev :
from functorch import hessian
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hess = hessian ( f )( x ) Мы также можем проследить эти преобразования, чтобы захватить результаты в качестве нового кода, используя make_fx . Существует также экспериментальная интеграция с компилятором NNC (пока работает только на процессоре!).
from functorch import make_fx , grad
def f ( x ):
return torch . sin ( x ). sum ()
x = torch . randn ( 100 )
grad_f = make_fx ( grad ( f ))( x )
print ( grad_f . code )
def forward ( self , x_1 ):
sin = torch . ops . aten . sin ( x_1 )
sum_1 = torch . ops . aten . sum ( sin , None ); sin = None
cos = torch . ops . aten . cos ( x_1 ); x_1 = None
_tensor_constant0 = self . _tensor_constant0
mul = torch . ops . aten . mul ( _tensor_constant0 , cos ); _tensor_constant0 = cos = None
return mulИногда вы можете выполнить преобразование в отношении параметров и/или буферов NN.Module. Это может произойти, например, в:
Наше решение этого прямо сейчас - API, который, учитывая NN.Module, создает его версию без состояния, которую можно назвать как функция.
make_functional(model) возвращает функциональную версию model и model.parameters()make_functional_with_buffers(model) возвращает функциональную версию model и model.parameters() и model.buffers() .Вот пример, в котором мы вычисляем градиенты для выборки, используя NN.Linear Layer:
import torch
from functorch import make_functional , vmap , grad
model = torch . nn . Linear ( 3 , 3 )
data = torch . randn ( 64 , 3 )
targets = torch . randn ( 64 , 3 )
func_model , params = make_functional ( model )
def compute_loss ( params , data , targets ):
preds = func_model ( params , data )
return torch . mean (( preds - targets ) ** 2 )
per_sample_grads = vmap ( grad ( compute_loss ), ( None , 0 , 0 ))( params , data , targets ) Если вы делаете ансамбль моделей, вы можете найти combine_state_for_ensemble полезным.
Для получения дополнительной документации см. Наш сайт DOCS.
torch._C._functorch.dump_tensor : дамп клавиш для отправки на стеке torch._C._functorch._set_vmap_fallback_warning_enabled(False) , если вас беспокоит спам WMAP.
В конце концов, мы хотели бы вверх по течению в Pytorch, как только мы разгоняем детали дизайна. Чтобы выяснить детали, нам нужна ваша помощь - пожалуйста, пришлите нам ваши варианты использования, начав разговор в трекере выпуска или попробовав наш проект.
Functorch имеет лицензию в стиле BSD, как найдено в файле лицензии.
Если вы используете Functorch в своей публикации, укажите его, используя следующую запись Bibtex.
@Misc { functorch2021 ,
author = { Horace He, Richard Zou } ,
title = { functorch: JAX-like composable function transforms for PyTorch } ,
howpublished = { url{https://github.com/pytorch/functorch} } ,
year = { 2021 }
}