pytorch2keras
1.0.0
PyTorch to Keras model converter.
pip install pytorch2keras
To use the converter properly, make the following changes to your ~/.keras/keras.json:
...
"backend" : " tensorflow " ,
"image_data_format" : " channels_first " ,
For the proper conversion to the tensorflow.js format, please use the new flag names='short'.
Here is a short instruction on how to get a TensorFlow.js model:
First, convert your model to Keras with the names='short' flag:
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True, names='short')
You can then convert the resulting Keras model with tensorflowjs_converter, but sometimes that doesn't work.
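If you want to try that direct path anyway, a minimal sketch looks like this (the my_model.h5 and model_tfjs names are just placeholders):
# Save the converted Keras model to HDF5 and feed it to tensorflowjs_converter
# in its regular Keras mode; this is the path that may or may not work:
#   tensorflowjs_converter --input_format=keras my_model.h5 model_tfjs
k_model.save('my_model.h5')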
As an alternative, you may get the TensorFlow graph and save it as a frozen model:
# Function below copied from here:
# https://stackoverflow.com/questions/45466020/how-to-export-keras-h5-to-tensorflow-pb
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.
    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(
            set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
from keras import backend as K
import tensorflow as tf

frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in k_model.outputs])
tf.train.write_graph(frozen_graph, ".", "my_model.pb", as_text=False)
print([i for i in k_model.outputs])
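The print above shows the Keras output tensors; if you also want to double-check the node names stored in my_model.pb (useful for the --output_node_names flag below), here is a minimal TF 1.x sketch:
# Load the frozen graph back and list its last few node names; the output node
# is usually among them (the file name matches the write_graph call above).
with tf.gfile.GFile("my_model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
print([node.name for node in graph_def.node][-5:])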
Now you can convert my_model.pb into a tfjs model:
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='TANHTObs/Tanh' \
    my_model.pb \
    model_tfjs
The saved model can then be loaded in the browser:
const MODEL_URL = `model_tfjs/tensorflowjs_model.pb`;
const WEIGHTS_URL = `model_tfjs/weights_manifest.json`;
const model = await tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL);
This is a converter of a PyTorch graph to a Keras (TensorFlow backend) model.
Firstly, we need to load (or create) a valid PyTorch model:
import torch
from torch import nn


class TestConv2d(nn.Module):
    """
    Module for Conv2d testing
    """
    def __init__(self, inp=10, out=16, kernel_size=3):
        super(TestConv2d, self).__init__()
        self.conv2d = nn.Conv2d(inp, out, stride=1, kernel_size=kernel_size, bias=True)

    def forward(self, x):
        x = self.conv2d(x)
        return x


model = TestConv2d()
# load weights here
# model.load_state_dict(torch.load(path_to_weights.pth))
The next step is to create a dummy variable with the correct shape:
import numpy as np
from torch.autograd import Variable

input_np = np.random.uniform(0, 1, (1, 10, 32, 32))
input_var = Variable(torch.FloatTensor(input_np))
We use the dummy variable to trace the model (with jit.trace):
from pytorch2keras import pytorch_to_keras

# we should specify shape of the input tensor
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True)
You can also set the H and W dimensions to None to make your model shape-agnostic (e.g. a fully convolutional network):
from pytorch2keras.converter import pytorch_to_keras

# we should specify shape of the input tensor
k_model = pytorch_to_keras(model, input_var, [(10, None, None,)], verbose=True)
That's all! If all the modules have been converted properly, the Keras model will be stored in the k_model variable.
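As a quick sanity check (a sketch that reuses model, input_var, input_np and k_model from the snippets above), you can compare the PyTorch and Keras outputs on the same input:
pytorch_output = model(input_var).data.numpy()
keras_output = k_model.predict(input_np)
# The maximum absolute difference should be close to zero.
print('Max error:', np.max(np.abs(pytorch_output - keras_output)))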
Here is the only method of the pytorch2keras module, pytorch_to_keras.
def pytorch_to_keras(
    model, args, input_shapes=None,
    change_ordering=False, verbose=False, name_policy=None,
):
Options:
model - a PyTorch model (nn.Module) to convert;
args - a list of dummy variables with proper shapes;
input_shapes - (experimental) a list with shapes of the inputs;
change_ordering - (experimental) boolean; if enabled, the converter will try to change the layout from BCHW to BHWC;
verbose - boolean, detailed log of the conversion;
name_policy - (experimental) choice from [keep, short, random]. The selector sets the target layer naming policy.
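For example, here is a sketch of a call that uses the experimental options described above (reusing model and input_var from the previous snippets):
k_model = pytorch_to_keras(
    model, input_var, [(10, 32, 32,)],
    change_ordering=True,   # try to convert from BCHW to BHWC layout
    verbose=True,           # print the conversion log
    name_policy='short',    # shorter layer names, e.g. for tensorflow.js export
)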
Supported layers:
Constants
Convolutions
Element-wise operations
Linear
Normalizations
Poolings
For usage examples, look at the tests directory.
This software is covered by the MIT License.