
Converting A PyTorch Model to Tensorflow pb using ONNX


A word up front, to help those who come later avoid the pitfalls:

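ONNX was started by Facebook together with Microsoft and AWS as a counterweight to TensorFlow, so ONNX-TF is bound to be an illicit affair between ONNX and TF that neither platform will vouch for. PyTorch and TensorFlow are evolving independently, and the rivalry between dynamic-graph and static-graph optimization is not going to end. If you are trying to convert models, consider whether you should instead: 1. make your serving/deployment platform support PyTorch; 2. move your training platform to TF; 3. treat the conversion as a one-off job: do it once and don't keep tinkering with it afterwards.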

The model used in this demo comes from: https://github.com/cinastanbean/Pytorch-Multi-Task-Multi-class-Classification


1. Pre-installation

Version Info

pytorch                   0.4.0           py27_cuda0.0_cudnn0.0_1    pytorch
torchvision               0.2.1                    py27_1    pytorch
tensorflow                1.8.0                     <pip>
onnx                      1.2.2                     <pip>
onnx-tf                   1.1.2                     <pip> 

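Notes:

  1. ONNX 1.1.2 is too old and triggers a BatchNormalization error; pip already provides version 1.3.0. You can also install from source: pip install -U git+https://github.com/onnx/onnx.git@master
  2. This experiment was verified to run correctly with ONNX 1.2.2.
  3. onnx-tf was installed from source; it requires TensorFlow >= 1.5.0.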
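Before converting, it can help to confirm that the installed packages meet the minimum versions from the notes above. The following is a minimal sketch, not part of the original project: `parse_version`, `meets_minimum` and `check_environment` are hypothetical helpers, the lower bounds (ONNX 1.2.2, TensorFlow 1.5.0) are assumptions drawn from these notes, and it only relies on each package exposing `__version__`. It avoids Python-3-only APIs so it should also run under the Python 2.7 environment listed above.

```python
import importlib
import re

# Assumed lower bounds, taken from the notes above: ONNX 1.2.2 was verified
# to work (1.1.2 was too old) and onnx-tf requires TensorFlow >= 1.5.0.
REQUIREMENTS = {
    'onnx': '1.2.2',
    'tensorflow': '1.5.0',
}


def parse_version(text):
    """'1.8.0' -> (1, 8, 0); '1.3.0rc1' -> (1, 3, 0). Keeps only the numeric prefix of each part."""
    parts = []
    for piece in text.split('.')[:3]:
        match = re.match(r'\d+', piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare versions numerically, so '1.10.0' counts as newer than '1.5.0'."""
    return parse_version(installed) >= parse_version(minimum)


def check_environment(requirements=REQUIREMENTS):
    """Return {module: 'ok' | 'too old' | 'missing'} for each required module."""
    report = {}
    for module_name, minimum in requirements.items():
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            report[module_name] = 'missing'
            continue
        installed = getattr(module, '__version__', '0')
        report[module_name] = 'ok' if meets_minimum(installed, minimum) else 'too old'
    return report
```

Calling `check_environment()` returns a dict mapping each package to `'ok'`, `'too old'` or `'missing'`. Pre-release suffixes such as `rc1` are ignored by this simple parser, which is good enough for a sanity check but not a full version-spec implementation.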

2. Conversion Process

2.1 Steps 1, 2, 3

Pipeline: PyTorch model --> ONNX model --> TensorFlow graph (.pb).

import os

import onnx
import tensorflow as tf
import torch
from onnx_tf.backend import prepare
from PIL import Image

# DIY_Model and class_numbers are provided by the demo project (not shown here).

# step 1, load the PyTorch model and export ONNX while running it once.
modelname = 'resnet18'
weightfile = 'models/model_best_checkpoint_resnet18.pth.tar'
modelhandle = DIY_Model(modelname, weightfile, class_numbers)
model = modelhandle.model
model.eval()  # inference behaviour for BatchNorm/Dropout; harmless if already set
dummy_input = torch.randn(1, 3, 224, 224)  # NCHW
onnx_filename = os.path.split(weightfile)[-1] + ".onnx"
torch.onnx.export(model, dummy_input, onnx_filename, verbose=True)

# step 2, load the ONNX model with TensorFlow as the backend, check it, and export the graph.
onnx_model = onnx.load(onnx_filename)
# With onnx-tensorflow installed from GitHub, strict=False is needed:
# prepare(onnx_model) without it fails with KeyError: 'pyfunc_0'.
# Reference: https://github.com/onnx/onnx-tensorflow/issues/167
tf_rep = prepare(onnx_model, strict=False)
image = Image.open('pants.jpg')
# debug: feed the same input to PyTorch and ONNX-TF and compare the results.
output_pytorch, img_np = modelhandle.process(image)
print('output_pytorch = {}'.format(output_pytorch))
output_onnx_tf = tf_rep.run(img_np)
print('output_onnx_tf = {}'.format(output_onnx_tf))
# onnx --> TF graph .pb
tf_pb_path = onnx_filename + '_graph.pb'
tf_rep.export_graph(tf_pb_path)

# step 3, check that the exported .pb is correct.
with tf.Graph().as_default():
    graph_def = tf.GraphDef()
    with open(tf_pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")
    with tf.Session() as sess:
        # The exported graph stores the weights as constants,
        # so no variable initializer needs to run here.

        # If you don't know the input/output tensor names, uncomment to print all ops:
        # for op in sess.graph.get_operations():
        #     print(op.values())

        input_x = sess.graph.get_tensor_by_name("0:0")  # input
        outputs1 = sess.graph.get_tensor_by_name('add_1:0')  # task 1: 5 outputs
        outputs2 = sess.graph.get_tensor_by_name('add_3:0')  # task 2: 10 outputs
        output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np})
        # Alternative smoke test with random input:
        # output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: np.random.randn(1, 3, 224, 224)})
        print('output_tf_pb = {}'.format(output_tf_pb))

2.2 Verification

Make sure the outputs agree:

output_pytorch = [array([ 2.5359073 , -1.4261041 , -5.2394    , -0.62402934,  4.7426634 ], dtype=float32), array([ 7.6249304,  5.1203837,  1.8118637,  1.5143847, -4.9409146, 1.1695148, -6.2375665, -1.6033885, -1.4286405, -2.964429 ], dtype=float32)]
      
output_onnx_tf = Outputs(_0=array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269,  4.7426634]], dtype=float32), _1=array([[ 7.6249285,  5.12038  ,  1.811865 ,  1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32))
      
output_tf_pb = [array([[ 2.5359051, -1.4261056, -5.239397 , -0.6240269,  4.7426634]], dtype=float32), array([[ 7.6249285,  5.12038  ,  1.811865 ,  1.5143874, -4.940915 , 1.1695154, -6.237564 , -1.6033876, -1.4286422, -2.964428 ]], dtype=float32)]
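Rather than comparing the printed numbers by eye, the check can be done numerically. This is a minimal sketch using NumPy's `np.allclose`; `outputs_match` and `max_abs_diff` are hypothetical helpers, and the tolerances `rtol=1e-4`, `atol=1e-5` are assumptions suited to float32 results, not values from the original project. The arrays are copied from the outputs printed above. Both outputs are flattened, so the extra batch dimension of 1 on the TF side does not matter.

```python
import numpy as np

# Values copied from the printed outputs above.
output_pytorch = [
    np.array([2.5359073, -1.4261041, -5.2394, -0.62402934, 4.7426634], dtype=np.float32),
    np.array([7.6249304, 5.1203837, 1.8118637, 1.5143847, -4.9409146,
              1.1695148, -6.2375665, -1.6033885, -1.4286405, -2.964429], dtype=np.float32),
]
output_tf_pb = [
    np.array([[2.5359051, -1.4261056, -5.239397, -0.6240269, 4.7426634]], dtype=np.float32),
    np.array([[7.6249285, 5.12038, 1.811865, 1.5143874, -4.940915,
               1.1695154, -6.237564, -1.6033876, -1.4286422, -2.964428]], dtype=np.float32),
]


def max_abs_diff(reference, candidate):
    """Largest element-wise |difference| over all paired outputs (batch dim flattened away)."""
    return max(
        float(np.max(np.abs(np.ravel(ref) - np.ravel(cand))))
        for ref, cand in zip(reference, candidate)
    )


def outputs_match(reference, candidate, rtol=1e-4, atol=1e-5):
    """True if both lists have the same number of outputs, same sizes, and close values."""
    if len(reference) != len(candidate):
        return False
    for ref, cand in zip(reference, candidate):
        ref, cand = np.ravel(ref), np.ravel(cand)
        if ref.shape != cand.shape or not np.allclose(ref, cand, rtol=rtol, atol=atol):
            return False
    return True


print(outputs_match(output_pytorch, output_tf_pb))  # True
print(max_abs_diff(output_pytorch, output_tf_pb))
```

The same helpers accept `output_onnx_tf` directly, since the `Outputs` object returned by `tf_rep.run` is a namedtuple and can be iterated like a list.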

Standalone TF verification program

import numpy as np
import tensorflow as tf
from PIL import Image


def get_img_np_nchw(filename):
    """Load an image and return a normalized (1, 3, 224, 224) float32 array, or None on failure."""
    try:
        image = Image.open(filename).convert('RGB').resize((224, 224))
    except (IOError, OSError) as e:
        print("RuntimeError: get_img_np_nchw({}): {}".format(filename, e))
        return None
    # ImageNet mean/std; the commented values are an alternative normalization.
    miu = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    #miu = np.array([0.5, 0.5, 0.5])
    #std = np.array([0.22, 0.22, 0.22])
    img_np = np.array(image, dtype=float) / 255.  # shape (224, 224, 3), HWC
    r = (img_np[:, :, 0] - miu[0]) / std[0]
    g = (img_np[:, :, 1] - miu[1]) / std[1]
    b = (img_np[:, :, 2] - miu[2]) / std[2]
    img_np_t = np.array([r, g, b])                  # CHW
    img_np_nchw = np.expand_dims(img_np_t, axis=0)  # NCHW
    return img_np_nchw.astype(np.float32)


if __name__ == '__main__':

    tf_pb_path = 'model_best_checkpoint_resnet18.pth.tar.onnx_graph.pb'

    filename = 'pants.jpg'
    img_np_nchw = get_img_np_nchw(filename)
    if img_np_nchw is None:
        raise SystemExit('could not load {}'.format(filename))

    # step 3, check that the exported .pb is correct.
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(tf_pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")
        with tf.Session() as sess:
            # The exported graph stores the weights as constants,
            # so no variable initializer needs to run here.

            # If you don't know the input/output tensor names, uncomment to print all ops:
            # for op in sess.graph.get_operations():
            #     print(op.values())

            input_x = sess.graph.get_tensor_by_name("0:0")  # input
            outputs1 = sess.graph.get_tensor_by_name('add_1:0')  # task 1: 5 outputs
            outputs2 = sess.graph.get_tensor_by_name('add_3:0')  # task 2: 10 outputs
            output_tf_pb = sess.run([outputs1, outputs2], feed_dict={input_x: img_np_nchw})
            print('output_tf_pb = {}'.format(output_tf_pb))
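The exported graph takes NCHW input (the PyTorch layout used for `dummy_input`), so preprocessing has to turn an HWC image into a (1, 3, H, W) array. The channel-by-channel normalization in `get_img_np_nchw` can also be written with NumPy broadcasting plus a transpose. The sketch below is illustrative: both function names are hypothetical, and it checks the two formulations against each other on a random image instead of a real photo.

```python
import numpy as np

# ImageNet mean/std, as used by get_img_np_nchw above.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize_hwc_to_nchw(img_hwc_uint8):
    """(H, W, 3) uint8 image -> (1, 3, H, W) float32 normalized batch."""
    img = img_hwc_uint8.astype(np.float64) / 255.0
    img = (img - MEAN) / STD            # broadcasts over the last (channel) axis
    chw = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(chw, axis=0).astype(np.float32)


def normalize_per_channel(img_hwc_uint8):
    """Channel-by-channel version mirroring the r/g/b code in get_img_np_nchw."""
    img = img_hwc_uint8.astype(np.float64) / 255.0
    channels = [(img[:, :, c] - MEAN[c]) / STD[c] for c in range(3)]
    return np.expand_dims(np.array(channels), axis=0).astype(np.float32)


rng = np.random.RandomState(0)
fake_image = rng.randint(0, 256, size=(4, 5, 3)).astype(np.uint8)
a = normalize_hwc_to_nchw(fake_image)
b = normalize_per_channel(fake_image)
print(a.shape, np.allclose(a, b))  # (1, 3, 4, 5) True
```

The broadcasting version avoids indexing each channel by hand; either way, the key point is the final layout, since feeding an HWC array would not match the `0:0` placeholder's NCHW shape.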

3. Related Info

3.1 ONNX

Open Neural Network Exchange https://github.com/onnx https://onnx.ai/

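The ONNX exporter is a ==trace-based== exporter, which means that it operates by executing your model once and exporting the operators that were actually run during this run. As a consequence, data-dependent control flow (for example, Python if statements or loops whose behaviour depends on the input) is not captured: only the path taken for the example input is exported. See the Limitations section of the PyTorch ONNX documentation.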
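To make "trace-based" concrete, here is a toy illustration in plain Python with no PyTorch involved. A hypothetical `TracedValue` records every arithmetic op executed on it, while the `if` decision itself produces an ordinary bool and is not recorded. Replaying the recorded ops on a different input then silently follows the branch taken during tracing. This only mimics the idea; it is not how `torch.onnx` is implemented.

```python
class Tracer:
    """Collects (op_name, operand) pairs in execution order."""

    def __init__(self):
        self.ops = []


class TracedValue:
    """Wraps a number and records arithmetic performed on it (toy, hypothetical)."""

    def __init__(self, value, tracer):
        self.value = value
        self.tracer = tracer

    def _record(self, op_name, operand, result):
        self.tracer.ops.append((op_name, operand))
        return TracedValue(result, self.tracer)

    def __add__(self, other):
        return self._record('add', other, self.value + other)

    def __mul__(self, other):
        return self._record('mul', other, self.value * other)

    def __gt__(self, other):
        # The comparison yields a plain bool, so the branch decision is not recorded.
        return self.value > other


def model(x):
    if x > 0:  # data-dependent control flow
        return x * 2
    return x + 100


def trace(fn, example_input):
    """Run fn once on a traced value and return the ops that actually executed."""
    tracer = Tracer()
    fn(TracedValue(example_input, tracer))
    return tracer.ops


def replay(ops, x):
    """Re-run the recorded ops on a new input, ignoring the original branches."""
    for op_name, operand in ops:
        x = x * operand if op_name == 'mul' else x + operand
    return x


ops = trace(model, 3.0)
print(ops)                # [('mul', 2)]
print(replay(ops, -1.0))  # -2.0, although model(-1.0) == 99.0
```

This is why the dummy input used for export should exercise the same code path as real inputs, and why models with input-dependent branches need extra care before conversion.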

https://github.com/onnx/tensorflow-onnx (TensorFlow to ONNX) https://github.com/onnx/onnx-tensorflow (ONNX to TensorFlow; the onnx_tf backend used above)

3.2 Microsoft/MMdnn

I have not yet gotten the conversion to work for the current network: https://github.com/Microsoft/MMdnn/blob/master/mmdnn/conversion/pytorch/README.md



pytorch-onnx-tensorflow-pb's Issues

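ONNX and PyTorch give different outputs for the same input?

Hello! After converting my PyTorch model to ONNX, the two models give different results for the same input. What should I do? The model argument passed to torch.onnx.export should be fine; could the other arguments be the cause?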

RuntimeError: Resize coordinate_transformation_mode=pytorch_half_pixel is not supported in Tensorflow.

When converting the ONNX model to a .pb model, I get this error:
Traceback (most recent call last):
  File "onnx2pb.py", line 45, in <module>
    onnx2pb_2(onnx_input_path)
  File "onnx2pb.py", line 14, in onnx2pb_2
    tf_rep = prepare(model)
  File "/home/fffan/下载/onnx-tensorflow-rel-1.6.0/onnx_tf/backend.py", line 66, in prepare
    return cls.onnx_model_to_tensorflow_rep(model, strict)
  File "/home/fffan/下载/onnx-tensorflow-rel-1.6.0/onnx_tf/backend.py", line 86, in onnx_model_to_tensorflow_rep
    return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
  File "/home/fffan/下载/onnx-tensorflow-rel-1.6.0/onnx_tf/backend.py", line 147, in _onnx_graph_to_tensorflow_rep
    strict=strict)
  File "/home/fffan/下载/onnx-tensorflow-rel-1.6.0/onnx_tf/backend.py", line 252, in _onnx_node_to_tensorflow_op
    return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
  File "/home/fffan/下载/onnx-tensorflow-rel-1.6.0/onnx_tf/handlers/handler.py", line 59, in handle
    cls.args_check(node, **kwargs)
  File "/home/fffan/下载/onnx-tensorflow-rel-1.6.0/onnx_tf/handlers/backend/resize.py", line 89, in args_check
    "Tensorflow")
  File "/home/fffan/下载/onnx-tensorflow-rel-1.6.0/onnx_tf/common/exception.py", line 49, in __call__
    raise self._func(self.get_message(op, framework))
RuntimeError: Resize coordinate_transformation_mode=pytorch_half_pixel is not supported in Tensorflow.
How can I solve this problem?
