なぜfunctorch? |インストールガイド|変換|ドキュメント|将来の計画
このライブラリは現在激しい開発中です - カバーしたいAPIまたはユースケースについて提案がある場合は、GitHubの問題を開いたり、手を差し伸べてください。ライブラリの使用方法についてお聞きしたいと思います。
functorchは、PytorchのJaxのような構成可能な関数変換です。
PytorchモジュールとPytorch Autogradで動作するComposable vmapおよびgrad変換を、優れたモードパフォーマンスを提供することを目的としています。
さらに、これらの変換の結果を事前にキャプチャするために、FXを使用してこれらの変換を追跡する実験機能があります。これにより、VMAPまたはグレードの結果をコンパイルしてパフォーマンスを向上させることができます。
今日のPytorchで行うのが難しいユースケースがたくさんあります。
vmap 、 grad 、 vjp 、およびjvp変換を作成すると、それぞれに個別のサブシステムを設計せずに上記を表現できます。合成可能な関数変換のこのアイデアは、JAXフレームワークから来ています。
functorchをインストールするには2つの方法があります。
最初にfounctorchベータを試すことをお勧めします。
このコラブノートブックの指示に従ってください
9/21/2022の時点で、 functorch毎晩のPytorchバイナリと一緒にインストールされています。プレビュー(夜間)pytorchバイナリをインストールしてください。指示については、https://pytorch.org/を参照してください。
それを行ったら、Pythonで迅速な正気チェックを実行します。
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ())9/21/2022の時点で、 functorch Pytorchと一緒に設置され、Pytorchソースツリーにあります。 SourceからPytorchをインストールしてください。その場合、 import functorchできます。
すべてが大丈夫であることを確認するために、いくつかのテストを実行してみてください:
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -vaotautogradには、いくつかの追加のオプション要件があります。あなたはそれらを介してそれらをインストールすることができます:
pip install networkx functorchテストを実行するには、テスト依存関係( expecttest 、 pyyaml )をインストールしてください。
ここで指示に従ってください
前提条件:Pytorchをインストールします
pip install functorch最後に、Pythonで迅速な正気チェックを実行します。
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ())今、私たちは次の変換をサポートしています:
grad 、 vjp 、 jvp 、jacrev 、 jacfwd 、 hessianvmapさらに、Pytorchモジュールを使用するためのユーティリティがいくつかあります。
make_functional(model)make_functional_with_buffers(model)注: vmap 、使用できるコードに制限を課します。詳細については、そのdocstringをお読みください。
vmap(func)(*inputs)は、 funcのすべてのテンソル操作に次元を追加する変換です。 vmap(func) inputsの各テンソルの寸法(デフォルト:0)にfuncをマップする新しい関数を返します。
vmap 、バッチ寸法を隠すのに役立ちます。例で実行される関数funcを記述し、 vmap(func)を使用して例のバッチを取得できる関数に持ち上げ、よりシンプルなモデリングエクスペリエンスにつながることができます。
from functorch import vmap
batch_size , feature_size = 3 , 5
weights = torch . randn ( feature_size , requires_grad = True )
def model ( feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
examples = torch . randn ( batch_size , feature_size )
result = vmap ( model )( examples )grad(func)(*inputs) funcシングルエレメントテンソルを返すと仮定します。 FUNC WRTの出力の勾配をinputs[0]に計算します。
from functorch import grad
x = torch . randn ([])
cos_x = grad ( lambda x : torch . sin ( x ))( x )
assert torch . allclose ( cos_x , x . cos ())
# Second-order gradients
neg_sin_x = grad ( grad ( lambda x : torch . sin ( x )))( x )
assert torch . allclose ( neg_sin_x , - x . sin ()) vmapで構成されている場合、 gradを使用してサンプルごとの勾配を計算できます。
from functorch import vmap
batch_size , feature_size = 3 , 5
def model ( weights , feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
def compute_loss ( weights , example , target ):
y = model ( weights , example )
return (( y - target ) ** 2 ). mean () # MSELoss
weights = torch . randn ( feature_size , requires_grad = True )
examples = torch . randn ( batch_size , feature_size )
targets = torch . randn ( batch_size )
inputs = ( weights , examples , targets )
grad_weight_per_example = vmap ( grad ( compute_loss ), in_dims = ( None , 0 , 0 ))( * inputs ) vjp変換は、 func inputsに適用し、いくつかのcotangentsテンソルを与えられたVJPを計算する新しい関数を返します。
from functorch import vjp
outputs , vjp_fn = vjp ( func , inputs ); vjps = vjp_fn ( * cotangents ) jvp TransformsはJacobian-Vector-Productsを計算し、「フォワードモードAD」とも呼ばれます。他のほとんどの変換とは異なり、これは高次関数ではありませんが、 jvpと同様にfunc(inputs)の出力を返します。
from functorch import jvp
x = torch . randn ( 5 )
y = torch . randn ( 5 )
f = lambda x , y : ( x * y )
_ , output = jvp ( f , ( x , y ), ( torch . ones ( 5 ), torch . ones ( 5 )))
assert torch . allclose ( output , x + y ) jacrev Transformは、 x取り込んでTorch.sinのJacobianをxに対してxに対して、 xに対してjacobian of torch.sinを返す新しい関数を返します。
from functorch import jacrev
x = torch . randn ( 5 )
jacobian = jacrev ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) jacrevを使用してJacobianを計算します。これはVMAPで構成されて、バッチされたヤコビアンを生成できます。
x = torch . randn ( 64 , 5 )
jacobian = vmap ( jacrev ( torch . sin ))( x )
assert jacobian . shape == ( 64 , 5 , 5 ) jacfwd 、フォワードモードADを使用してJacobiansを計算するjacrevのドロップイン代替品です。
from functorch import jacfwd
x = torch . randn ( 5 )
jacobian = jacfwd ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected ) jacrev自分自身で作曲するか、 jacfwdヘシアンを生み出すことができます。
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hessian0 = jacrev ( jacrev ( f ))( x )
hessian1 = jacfwd ( jacrev ( f ))( x ) hessian 、 jacfwdとjacrevを組み合わせた利便性の機能です。
from functorch import hessian
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hess = hessian ( f )( x )また、 make_fxを使用して新しいコードとして結果をキャプチャするために、これらの変換をトレースすることもできます。また、NNCコンパイラとの実験的統合もあります(今のところCPUでのみ動作します!)。
from functorch import make_fx , grad
def f ( x ):
return torch . sin ( x ). sum ()
x = torch . randn ( 100 )
grad_f = make_fx ( grad ( f ))( x )
print ( grad_f . code )
def forward ( self , x_1 ):
sin = torch . ops . aten . sin ( x_1 )
sum_1 = torch . ops . aten . sum ( sin , None ); sin = None
cos = torch . ops . aten . cos ( x_1 ); x_1 = None
_tensor_constant0 = self . _tensor_constant0
mul = torch . ops . aten . mul ( _tensor_constant0 , cos ); _tensor_constant0 = cos = None
return mulnn.moduleのパラメーターおよび/またはバッファーに関して変換を実行する場合があります。これは、たとえば、次のように発生する可能性があります。
これに対する私たちの解決策は、NN.moduleを与えられたAPIで、機能と呼ばれるステートレスバージョンを作成します。
make_functional(model) modelとmodel.parameters()の機能バージョンを返しますmake_functional_with_buffers(model) modelとmodel.parameters()およびmodel.buffers()の関数版を返します。NN.Linearレイヤーを使用して、サンプルごとの勾配を計算する例は次のとおりです。
import torch
from functorch import make_functional , vmap , grad
model = torch . nn . Linear ( 3 , 3 )
data = torch . randn ( 64 , 3 )
targets = torch . randn ( 64 , 3 )
func_model , params = make_functional ( model )
def compute_loss ( params , data , targets ):
preds = func_model ( params , data )
return torch . mean (( preds - targets ) ** 2 )
per_sample_grads = vmap ( grad ( compute_loss ), ( None , 0 , 0 ))( params , data , targets )モデルのアンサンブルを作成している場合は、 combine_state_for_ensemble便利であることがわかります。
詳細については、Docs Webサイトを参照してください。
torch._C._functorch.dump_tensor :スタックトーチにディスパッチキーをダンプしますtorch._C._functorch._set_vmap_fallback_warning_enabled(False)
最後の状態では、デザインの詳細をアイロンをかけたら、これをPytorchに上流したいと思います。詳細を把握するには、あなたの助けが必要です。問題トラッカーで会話を開始するか、プロジェクトを試してみて、ユースケースを送ってください。
Funtctorchには、ライセンスファイルにあるように、BSDスタイルのライセンスがあります。
出版物でfuntctorchを使用している場合は、次のBibtexエントリを使用して引用してください。
@Misc { functorch2021 ,
author = { Horace He, Richard Zou } ,
title = { functorch: JAX-like composable function transforms for PyTorch } ,
howpublished = { url{https://github.com/pytorch/functorch} } ,
year = { 2021 }
}