pytorch2keras
1.0.0
PyTorch to Keras model converter.
```
pip install pytorch2keras
```
To use the converter properly, please make the following changes in your `~/.keras/keras.json`:

```
...
"backend": "tensorflow",
"image_data_format": "channels_first",
...
```

For the proper conversion to a tensorflow.js format, please use the new flag `names='short'`.
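If you prefer, these two keys can also be set programmatically. A minimal stdlib-only sketch; the `patch_keras_config` helper is hypothetical, not part of pytorch2keras:

```python
import json
import os

def patch_keras_config(config_dir=None):
    """Set the backend and image_data_format keys in keras.json."""
    # Default to the standard Keras config location.
    config_dir = config_dir or os.path.expanduser("~/.keras")
    os.makedirs(config_dir, exist_ok=True)
    config_path = os.path.join(config_dir, "keras.json")
    # Load the existing config if present, otherwise start from an empty one.
    config = {}
    if os.path.exists(config_path):
        with open(config_path) as f:
            config = json.load(f)
    # The two settings the converter expects.
    config["backend"] = "tensorflow"
    config["image_data_format"] = "channels_first"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)
    return config_path
```

Passing an explicit `config_dir` lets you try the function without touching your real `~/.keras` directory.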
Here is a short instruction on how to get a TensorFlow.js model:
First of all, convert your model to Keras with this converter:

```python
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True, names='short')
```

Now you have a Keras model. You can save it as an h5 file and then convert it with `tensorflowjs_converter`, but sometimes that doesn't work. As an alternative, you may get the TensorFlow graph and save it as a frozen model:

```python
# Function below copied from here:
# https://stackoverflow.com/questions/45466020/how-to-export-keras-h5-to-tensorflow-pb
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(
            set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
```
```python
from keras import backend as K
import tensorflow as tf

frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in k_model.outputs])
tf.train.write_graph(frozen_graph, ".", "my_model.pb", as_text=False)
```
You can print the output layer names with:

```python
print([i for i in k_model.outputs])
```

You will now see `my_model.pb` in your working directory and can convert it to a tfjs model:

```
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='TANHTObs/Tanh' \
    my_model.pb \
    model_tfjs
```

That's all! You can now load the model in your JavaScript code:

```js
const MODEL_URL = `model_tfjs/tensorflowjs_model.pb`;
const WEIGHTS_URL = `model_tfjs/weights_manifest.json`;
const model = await tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL);
```

This is a converter of a PyTorch graph to a Keras (TensorFlow backend) model.
First, we need to load (or create) a valid PyTorch model:
```python
import torch.nn as nn

class TestConv2d(nn.Module):
    """
    Module for Conv2d testing
    """
    def __init__(self, inp=10, out=16, kernel_size=3):
        super(TestConv2d, self).__init__()
        self.conv2d = nn.Conv2d(inp, out, stride=1, kernel_size=kernel_size, bias=True)

    def forward(self, x):
        x = self.conv2d(x)
        return x

model = TestConv2d()
# load weights here
# model.load_state_dict(torch.load(path_to_weights.pth))
```

The next step is to create a dummy variable with the correct shape:
```python
import numpy as np
import torch
from torch.autograd import Variable

input_np = np.random.uniform(0, 1, (1, 10, 32, 32))
input_var = Variable(torch.FloatTensor(input_np))
```

We use the dummy variable to trace the model (with `jit.trace`):
```python
from pytorch2keras import pytorch_to_keras

# we should specify shape of the input tensor
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True)
```

You can also set the H and W dimensions to None to make your model shape-agnostic (e.g. a fully convolutional network):
```python
from pytorch2keras.converter import pytorch_to_keras

# we should specify shape of the input tensor
k_model = pytorch_to_keras(model, input_var, [(10, None, None,)], verbose=True)
```

That's all! If all the modules have been converted properly, the Keras model will be stored in the `k_model` variable.
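As a quick sanity check on the conversion above, the expected output shape of `TestConv2d` can be worked out by hand with standard convolution arithmetic. The `conv2d_output_shape` helper below is only an illustration, not part of pytorch2keras:

```python
def conv2d_output_shape(in_shape, out_channels, kernel_size, stride=1, padding=0):
    """Compute the (C, H, W) output shape of a 2D convolution."""
    c, h, w = in_shape
    # Standard convolution arithmetic: floor((in + 2*pad - kernel) / stride) + 1
    h_out = (h + 2 * padding - kernel_size) // stride + 1
    w_out = (w + 2 * padding - kernel_size) // stride + 1
    return (out_channels, h_out, w_out)

# TestConv2d above: 10 input channels, 16 filters, 3x3 kernel, stride 1, no padding.
print(conv2d_output_shape((10, 32, 32), out_channels=16, kernel_size=3))  # (16, 30, 30)
```

So for the (1, 10, 32, 32) dummy input, both the PyTorch and the converted Keras model should produce a (1, 16, 30, 30) output.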
Here is the only method of the pytorch2keras module, `pytorch_to_keras`:
```python
def pytorch_to_keras(
    model, args, input_shapes=None,
    change_ordering=False, verbose=False, name_policy=None,
):
```

Options:
- `model` - a PyTorch model (nn.Module) to convert;
- `args` - a list of dummy variables with proper shapes;
- `input_shapes` - (experimental) a list with shapes of the inputs;
- `change_ordering` - (experimental) boolean, if enabled the converter will try to change BCHW to BHWC;
- `verbose` - boolean, a detailed log of the conversion;
- `name_policy` - (experimental) choice from [`keep`, `short`, `random`]. The selector sets the target layer naming policy.

Supported layers:

- Activations:
- Constants
- Convolutions:
- Element-wise:
- Linear
- Normalizations:
- Poolings:
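The `change_ordering` flag above amounts to an axis permutation from the BCHW layout (PyTorch) to BHWC (TensorFlow's default). A minimal sketch of that mapping; the `bchw_to_bhwc` helper is hypothetical, for illustration only:

```python
def bchw_to_bhwc(shape):
    """Permute a (batch, channels, height, width) shape to (batch, height, width, channels)."""
    b, c, h, w = shape
    return (b, h, w, c)

# The dummy input from the example above, reordered for a channels_last backend.
print(bchw_to_bhwc((1, 10, 32, 32)))  # (1, 32, 32, 10)
```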
See the `tests` directory.
This software is covered by the MIT License.