為什麼要功能棒? |安裝指南|轉換|文檔|未來計劃
該圖書館目前正在繁重的開發中 - 如果您對要涵蓋的API或用例有建議,請打開GitHub問題或伸出援手。我們很想听聽您如何使用圖書館。
functorch是Pytorch類似JAX的合併函數轉換。
它的目的是提供可組合的vmap和grad變換,可與Pytorch模塊和Pytorch Autograd一起使用,並具有良好的急切模式性能。
此外,還可以使用FX追踪這些轉換,以便提前捕獲這些變換的結果。這將使我們能夠編譯VMAP或GRAD的結果以提高性能。
如今,在Pytorch中有許多用例很棘手:
組成vmap , grad , vjp和jvp變換使我們能夠表達上述情況,而無需為每個系統設計一個單獨的子系統。合併函數轉換的想法來自JAX框架。
有兩種安裝函數的方法:
我們建議先嘗試使用功能棒。
按照本Colab筆記本中的說明進行操作
截至9/21/2022, functorch與每晚的Pytorch二進制旁安裝。請安裝預覽(夜間)pytorch二進制;有關說明,請參見https://pytorch.org/。
完成此操作後,請在Python中快速進行理智檢查:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ())從9/21/2022開始, functorch與Pytorch一起安裝,並在Pytorch源樹中。請從源安裝pytorch,然後,您將能夠import functorch 。
嘗試進行一些測試以確保一切都可以:
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -vAotautograd還有一些其他可選要求。您可以通過:
pip install networkx要運行函數測試,請安裝我們的測試依賴項( expecttest , pyyaml )。
在此處遵循說明
先決條件:安裝Pytorch
pip install functorch最後,在Python中進行快速的理智檢查:
import torch
from functorch import vmap
x = torch . randn ( 3 )
y = vmap ( torch . sin )( x )
assert torch . allclose ( y , x . sin ())現在,我們支持以下轉換:
grad , vjp , jvp ,jacrev , jacfwd , hessianvmap此外,我們還有一些用於使用Pytorch模塊的實用程序。
make_functional(model)make_functional_with_buffers(model)注意: vmap對可以使用的代碼施加限制。有關更多詳細信息,請閱讀其docstring。
vmap(func)(*inputs)是一個轉換,可為func中的所有張量操作增加一個尺寸。 vmap(func)返回一個新功能,該功能在inputs中每個張量的某個維度(默認值:0)上映射func 。
vmap對於隱藏批處理尺寸很有用:一個人可以編寫一個在示例上運行的函數func ,然後將其提升到可以使用vmap(func)批量示例的函數,從而帶來更簡單的建模體驗:
from functorch import vmap
batch_size , feature_size = 3 , 5
weights = torch . randn ( feature_size , requires_grad = True )
def model ( feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
examples = torch . randn ( batch_size , feature_size )
result = vmap ( model )( examples )grad(func)(*inputs)假設func返回單元素張量。它計算出輸入到inputs[0]的輸出的梯度。
from functorch import grad
x = torch . randn ([])
cos_x = grad ( lambda x : torch . sin ( x ))( x )
assert torch . allclose ( cos_x , x . cos ())
# Second-order gradients
neg_sin_x = grad ( grad ( lambda x : torch . sin ( x )))( x )
assert torch . allclose ( neg_sin_x , - x . sin ())當由vmap組成時,可用於計算每樣本grad :
from functorch import vmap
batch_size , feature_size = 3 , 5
def model ( weights , feature_vec ):
# Very simple linear model with activation
assert feature_vec . dim () == 1
return feature_vec . dot ( weights ). relu ()
def compute_loss ( weights , example , target ):
y = model ( weights , example )
return (( y - target ) ** 2 ). mean () # MSELoss
weights = torch . randn ( feature_size , requires_grad = True )
examples = torch . randn ( batch_size , feature_size )
targets = torch . randn ( batch_size )
inputs = ( weights , examples , targets )
grad_weight_per_example = vmap ( grad ( compute_loss ), in_dims = ( None , 0 , 0 ))( * inputs ) vjp變換將func應用於inputs ,並返回一個新功能,該功能在給定某些cotangents張量的情況下計算VJP。
from functorch import vjp
outputs , vjp_fn = vjp ( func , inputs ); vjps = vjp_fn ( * cotangents ) jvp轉換計算Jacobian-vector-Marotucts,也稱為“前向模式AD”。與大多數其他變換不同,它不是高階函數,但它返回了func(inputs)的輸出以及jvp s。
from functorch import jvp
x = torch . randn ( 5 )
y = torch . randn ( 5 )
f = lambda x , y : ( x * y )
_ , output = jvp ( f , ( x , y ), ( torch . ones ( 5 ), torch . ones ( 5 )))
assert torch . allclose ( output , x + y ) jacrev變換返回了一個新功能,該功能使用x ,並使用反向模式AD返回x的torch.sin 。
from functorch import jacrev
x = torch . randn ( 5 )
jacobian = jacrev ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected )使用jacrev計算Jacobian。這可以與VMAP組成以生產批處理的雅各佈人:
x = torch . randn ( 64 , 5 )
jacobian = vmap ( jacrev ( torch . sin ))( x )
assert jacobian . shape == ( 64 , 5 , 5 ) jacfwd是使用前向模式AD計算Jacobians的jacrev的替代品:
from functorch import jacfwd
x = torch . randn ( 5 )
jacobian = jacfwd ( torch . sin )( x )
expected = torch . diag ( torch . cos ( x ))
assert torch . allclose ( jacobian , expected )用自己或jacfwd組成的jacrev可以生產黑森州:
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hessian0 = jacrev ( jacrev ( f ))( x )
hessian1 = jacfwd ( jacrev ( f ))( x ) hessian是一個便利功能,結合了jacfwd和jacrev :
from functorch import hessian
def f ( x ):
return x . sin (). sum ()
x = torch . randn ( 5 )
hess = hessian ( f )( x )我們還可以追踪這些轉換,以便使用make_fx作為新代碼捕獲結果。 NNC編譯器也有實驗集成(目前僅在CPU上起作用!)。
from functorch import make_fx , grad
def f ( x ):
return torch . sin ( x ). sum ()
x = torch . randn ( 100 )
grad_f = make_fx ( grad ( f ))( x )
print ( grad_f . code )
def forward ( self , x_1 ):
sin = torch . ops . aten . sin ( x_1 )
sum_1 = torch . ops . aten . sum ( sin , None ); sin = None
cos = torch . ops . aten . cos ( x_1 ); x_1 = None
_tensor_constant0 = self . _tensor_constant0
mul = torch . ops . aten . mul ( _tensor_constant0 , cos ); _tensor_constant0 = cos = None
return mul有時,您可能需要對NN.模塊的參數和/或緩衝區進行轉換。例如,這可能發生在:
我們現在對此的解決方案是一個API,鑑於nn.module,它創建了它的無狀態版本,可以稱為函數。
make_functional(model)返回model的功能版本和model.parameters()make_functional_with_buffers(model)返回model的功能版本和model.parameters()和model.buffers() 。這是一個示例,我們使用nn.linear層來計算每樣本級別:
import torch
from functorch import make_functional , vmap , grad
model = torch . nn . Linear ( 3 , 3 )
data = torch . randn ( 64 , 3 )
targets = torch . randn ( 64 , 3 )
func_model , params = make_functional ( model )
def compute_loss ( params , data , targets ):
preds = func_model ( params , data )
return torch . mean (( preds - targets ) ** 2 )
per_sample_grads = vmap ( grad ( compute_loss ), ( None , 0 , 0 ))( params , data , targets )如果您要製作模型的合奏,則可能會發現combine_state_for_ensemble有用。
有關更多文檔,請參見我們的文檔網站。
torch._C._functorch.dump_tensor :在Stack torch._C._functorch._set_vmap_fallback_warning_enabled(False)上轉儲調度鍵。
最終狀態,一旦我們淘汰了設計細節,我們希望將其上游進入Pytorch。為了弄清楚詳細信息,我們需要您的幫助 - 請通過在問題跟踪器中啟動對話或嘗試我們的項目來向我們發送您的用例。
如許可證文件中所示,FOUNDORCH具有BSD式許可。
如果您在出版物中使用Foundorch,請使用以下Bibtex條目引用它。
@Misc { functorch2021 ,
author = { Horace He, Richard Zou } ,
title = { functorch: JAX-like composable function transforms for PyTorch } ,
howpublished = { url{https://github.com/pytorch/functorch} } ,
year = { 2021 }
}