harryzhang0415 / keras_to_tensorflow

This project forked from amir-abdi/keras_to_tensorflow

General code to convert a trained keras model into an inference tensorflow model

License: MIT License

Jupyter Notebook 65.53% Python 34.47%

keras_to_tensorflow

The notebook keras_to_tensorflow is sample code that loads a trained Keras model, freezes the nodes (converts all TensorFlow variables to TensorFlow constants), and saves the inference graph and weights into a protobuf file (.pb). This file can then be used to deploy the trained model for inference. During freezing, nodes of the network that do not contribute to the tensor containing the output predictions are pruned. This results in a smaller, optimized network.
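To make the freeze-and-prune step concrete, here is a toy, pure-Python sketch over a hypothetical dict-based graph (this is not TensorFlow's GraphDef format, and `freeze_and_prune` is an illustrative name, not part of this project). It walks backwards from the output nodes, keeps only their ancestors, and replaces each surviving variable node with a constant that holds its current value.

```python
from collections import deque

# Hypothetical toy graph format (NOT TensorFlow's GraphDef):
# node name -> {"op": op type, "inputs": [names of input nodes]}.
def freeze_and_prune(graph, variable_values, output_names):
    """Keep only the nodes the outputs depend on and turn variables into constants."""
    # 1. Prune: walk backwards from the outputs and collect every ancestor.
    keep = set()
    queue = deque(output_names)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        queue.extend(graph[name]["inputs"])

    # 2. Freeze: replace each surviving variable with a constant holding its value.
    frozen = {}
    for name in keep:
        node = graph[name]
        if node["op"] == "Variable":
            frozen[name] = {"op": "Const", "inputs": [], "value": variable_values[name]}
        else:
            frozen[name] = {"op": node["op"], "inputs": list(node["inputs"])}
    return frozen


# Example: a tiny linear model with a training-only loss branch.
graph = {
    "x": {"op": "Placeholder", "inputs": []},
    "labels": {"op": "Placeholder", "inputs": []},
    "w": {"op": "Variable", "inputs": []},
    "b": {"op": "Variable", "inputs": []},
    "matmul": {"op": "MatMul", "inputs": ["x", "w"]},
    "logits": {"op": "Add", "inputs": ["matmul", "b"]},
    "output": {"op": "Identity", "inputs": ["logits"]},
    "loss": {"op": "Mean", "inputs": ["logits", "labels"]},
}
frozen = freeze_and_prune(graph, {"w": [[0.5]], "b": [0.1]}, ["output"])
print(sorted(frozen))  # → ['b', 'logits', 'matmul', 'output', 'w', 'x']
```

The loss branch and its label placeholder disappear because nothing on the path to the prediction tensor consumes them, which is why the frozen inference graph ends up smaller than the training graph.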

Code for freezing and saving Keras models with previous versions of TensorFlow is also available. Back then, the freeze_graph tool (/tensorflow/python/tools/freeze_graph.py) was used to convert the variables into constants. This functionality is now handled by graph_util.convert_variables_to_constants.

How to use

The Keras model can be saved using model.save('file_name.h5') (check the Keras API documentation for details).

You can either use the IPython notebook (keras_to_tensorflow.ipynb) or simply run the Python script as below in the folder where your Keras model is located:

python3 keras_to_tensorflow.py -input_model_file model.h5
python keras_to_tensorflow.py -input_model_file model.h5 

Try python3 keras_to_tensorflow.py --help for other input arguments.

Input arguments

  • num_output: this value has nothing to do with the number of classes, batch_size, etc., and in most cases it equals 1. If the network is a multi-stream network (a forked network with multiple outputs), set it to the number of outputs.

  • quantize: if set to True, uses TensorFlow's quantization feature (https://www.tensorflow.org/performance/quantization) [default: False]

  • use_theano: Theano and TensorFlow implement convolution differently. When Keras uses the Theano backend, the image data format is set to 'channels_first'. This feature is not fully tested and does not work with quantization [default: False]

  • input_fld: directory holding the keras weights file [default: .]

  • output_fld: destination directory to save the tensorflow files [default: .]

  • input_model_file: name of the input weight file [default: 'model.h5']

  • output_model_file: name of the output weight file [default: args.input_model_file + '.pb']

  • graph_def: if set to True, writes the graph definition as an ASCII file [default: False]

  • output_graphdef_file: if graph_def is set to True, the file name of the graph definition [default: model.ascii]

  • output_node_prefix: the prefix to use for output nodes. [default: output_node]
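A minimal sketch of how the options above could be parsed with argparse and pathlib (both listed under Dependencies). The parser layout, the `str2bool` helper, and the "prefix + index" naming of output nodes are assumptions for illustration only; the actual script may define its arguments differently.

```python
import argparse
from pathlib import Path


def str2bool(value):
    # argparse's type=bool would treat any non-empty string, even "False", as True.
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical re-creation of the documented options; the real script may differ.
    parser = argparse.ArgumentParser(description="Convert a Keras .h5 model to a TensorFlow .pb file")
    parser.add_argument("-input_fld", default=".")
    parser.add_argument("-output_fld", default=".")
    parser.add_argument("-input_model_file", default="model.h5")
    parser.add_argument("-output_model_file", default=None)  # None -> input name + ".pb"
    parser.add_argument("-output_graphdef_file", default="model.ascii")
    parser.add_argument("-graph_def", type=str2bool, default=False)
    parser.add_argument("-num_output", type=int, default=1)
    parser.add_argument("-output_node_prefix", default="output_node")
    parser.add_argument("-quantize", type=str2bool, default=False)
    parser.add_argument("-use_theano", type=str2bool, default=False)
    return parser


def resolve_paths(args):
    """Build the input and output file paths from the folder/file options."""
    output_name = args.output_model_file or args.input_model_file + ".pb"
    return Path(args.input_fld) / args.input_model_file, Path(args.output_fld) / output_name


def output_node_names(args):
    """Hypothetical naming scheme: one output node per model output, prefix + index."""
    return [f"{args.output_node_prefix}{i}" for i in range(args.num_output)]


args = build_parser().parse_args(["-input_model_file", "model.h5", "-num_output", "2"])
input_path, output_path = resolve_paths(args)
print(input_path, output_path, output_node_names(args))
# → model.h5 model.h5.pb ['output_node0', 'output_node1']
```

Single-dash long options such as `-input_model_file` are accepted by argparse, so the sketch matches the command line shown under How to use.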
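The quantize option relies on TensorFlow's own tooling. As background only, the general idea behind 8-bit linear weight quantization can be sketched in pure Python: each float is mapped to an integer code using the tensor's min/max range and later mapped back. This is not TensorFlow's implementation, and `linear_quantize` and `dequantize` are hypothetical names.

```python
def linear_quantize(weights, num_bits=8):
    """Map float weights to integer codes in [0, 2**num_bits - 1] using their min/max range."""
    lo, hi = min(weights), max(weights)
    levels = 2 ** num_bits - 1
    # Guard against a constant tensor, where hi == lo would give a zero step.
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((w - lo) / scale) for w in weights]
    return codes, lo, scale


def dequantize(codes, lo, scale):
    """Recover approximate float weights from the integer codes."""
    return [lo + c * scale for c in codes]


weights = [-1.0, -0.3, 0.1, 0.7, 1.0]
codes, lo, scale = linear_quantize(weights)
restored = dequantize(codes, lo, scale)
# Each restored value should lie within about half a quantization step of the original.
max_error = max(abs(a - b) for a, b in zip(weights, restored))
```

Storing one byte per weight instead of four shrinks the file, at the cost of a bounded rounding error per weight.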

Untested features:

Theano support is not yet fully tested. I do not expect that loading a Theano model in Keras and saving it as a TensorFlow model would work without a proper conversion between the two. The feature is committed for Theano enthusiasts to test.
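One difference such a conversion has to handle (stated here as general background, not taken from this project): Theano's conv2d flips the filters by default, computing a true convolution, while TensorFlow's conv2d computes a cross-correlation. A kernel moved from one backend to the other therefore generally needs its spatial axes flipped. The pure-Python sketch below, with hypothetical helper names, shows that true convolution equals cross-correlation with a 180-degree-rotated kernel.

```python
def flip_kernel(kernel):
    """Rotate a 2-D kernel by 180 degrees (reverse both spatial axes)."""
    return [row[::-1] for row in kernel[::-1]]


def correlate2d_valid(image, kernel):
    """Cross-correlation with 'valid' padding: the kernel is applied as-is."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[i + a][j + b] * kernel[a][b] for a in range(kh) for b in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


def convolve2d_valid(image, kernel):
    """True convolution: cross-correlation with the flipped kernel."""
    return correlate2d_valid(image, flip_kernel(kernel))


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
kernel = [[1, 0], [0, 0]]  # asymmetric, so flipping changes the result
print(correlate2d_valid(image, kernel))  # → [[1, 2], [4, 5]]
print(convolve2d_valid(image, kernel))   # → [[5, 6], [8, 9]]
```

With a symmetric kernel both results coincide, which is why a missing flip can go unnoticed on simple tests yet silently break a trained network.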

Dependencies

  • Keras
  • Tensorflow
  • argparse
  • pathlib

Contributors

amir-abdi, aswathkiruba
