tflite-mnist-android's Introduction

MNIST with TensorFlow Lite on Android

This project demonstrates how to use TensorFlow Lite on Android to classify handwritten digits from the MNIST dataset.

A prebuilt APK can be downloaded from here.

How to build from scratch

Environment

  • Python 3.7
  • tensorflow 2.3.0
  • tensorflow-datasets 3.2.1

Step 1. Train and convert the model to TensorFlow Lite FlatBuffer

Run all the code cells in model.ipynb.

  • If you are running Jupyter Notebook locally, a mnist.tflite file will be saved to the project directory.
  • If you are running the notebook in Google Colab, a mnist.tflite file will be downloaded.

Step 2. Build Android app

Copy the mnist.tflite file generated in Step 1 to /android/app/src/main/assets, then build and run the app.

The Classifier reads mnist.tflite from the assets directory and loads it into an Interpreter for inference. The Interpreter provides the interface between the TensorFlow Lite model and Java code.
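A minimal sketch of the loading step, assuming the model is memory-mapped rather than copied into the Java heap (a common approach for TensorFlow Lite; this is not the project's own Classifier code). ModelLoader and mapModel are hypothetical names. On Android, the stream would typically be new FileInputStream(afd.getFileDescriptor()), with the offset and length taken from afd.getStartOffset() and afd.getDeclaredLength() of the AssetFileDescriptor returned by AssetManager.openFd("mnist.tflite"); the resulting buffer can then be passed to the Interpreter constructor, which accepts a ByteBuffer.

```java
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

/** Hypothetical helper, not part of the project: memory-maps a model file region read-only. */
final class ModelLoader {
    private ModelLoader() {}

    /**
     * Maps {@code length} bytes starting at {@code offset} of the stream's underlying file.
     * The caller owns (and closes) the stream.
     */
    static MappedByteBuffer mapModel(FileInputStream input, long offset, long length)
            throws IOException {
        FileChannel channel = input.getChannel();
        // READ_ONLY: the model bytes are never modified by inference.
        return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
    }
}
```

Once a mapping is established it does not depend on the channel, so the caller can close the stream right after mapping and keep using the buffer.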

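If you are building your own app, add the following to the android block of your module's build.gradle so that model files are stored uncompressed in the APK. A compressed asset cannot be opened with AssetManager.openFd() and memory-mapped.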

aaptOptions {
    // List both extensions in one call so neither setting overrides the other.
    noCompress "tflite", "lite"
}
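To confirm that the noCompress setting took effect, one option (an assumption, not part of the project) is to inspect the built APK, which is a ZIP archive, with java.util.zip: an uncompressed entry reports the compression method STORED. ApkCheck and isStoredUncompressed are hypothetical names; inside an APK, assets live under the assets/ prefix.

```java
import java.io.IOException;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

/** Hypothetical helper: checks how an entry is stored inside an APK (ZIP) file. */
final class ApkCheck {
    private ApkCheck() {}

    /** Returns true if the entry exists and is STORED (not compressed). */
    static boolean isStoredUncompressed(String apkPath, String entryName) throws IOException {
        try (ZipFile zip = new ZipFile(apkPath)) {
            ZipEntry entry = zip.getEntry(entryName);
            return entry != null && entry.getMethod() == ZipEntry.STORED;
        }
    }
}
```

For example, isStoredUncompressed(pathToYourApk, "assets/mnist.tflite") should return true once the model is excluded from compression.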

Credits

tflite-mnist-android's People

Contributors

nex3z

tflite-mnist-android's Issues

batch_normalization error

Thank you for sharing your code. Have you tried this:

import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.add_to_collection)
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

def batch_norm_lite(x, train=True, bn_decay=0.5, epsilon=0.001, name='bn'):
    is_training = tf.convert_to_tensor(train, dtype='bool', name='is_training')
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]

    # Normalize over every axis except the last (channel) axis.
    axis = list(range(len(x_shape) - 1))

    beta = tf.get_variable(name + '_beta', params_shape, initializer=tf.zeros_initializer())
    gamma = tf.get_variable(name + '_gamma', params_shape, initializer=tf.ones_initializer())

    moving_mean = tf.get_variable(name + '_moving_mean', params_shape, initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = tf.get_variable(name + '_moving_variance', params_shape, initializer=tf.ones_initializer(), trainable=False)

    # These ops will only be performed when training.
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, bn_decay)
    update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, bn_decay)
    tf.add_to_collection(name + '_update_moving_mean', update_moving_mean)
    tf.add_to_collection(name + '_update_moving_variance', update_moving_variance)

    # Batch statistics while training, moving statistics otherwise.
    mean, variance = control_flow_ops.cond(
        is_training, lambda: (mean, variance),
        lambda: (moving_mean, moving_variance))

    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon, name=name)

def inference(input_tensor, regularizer=None):
    with tf.variable_scope("layer_1_conv"):
        conv_1_weight = get_weight([CONV_1_SIZE, CONV_1_SIZE, IMAGE_CHANNEL_NUM, CONV_1_DEPTH])
        conv_1_bias = get_bias([CONV_1_DEPTH])
        conv_1 = conv2d(input_tensor, conv_1_weight, stride=1)
        bn = batch_norm_lite(tf.nn.bias_add(conv_1, conv_1_bias))
        conv_1_activation = tf.nn.relu(bn)

This adds a batch normalization layer before the ReLU. Batch normalization is very common, yet this code cannot be converted to TFLite. How can this be solved?
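A hedged note, not the author's reply: at inference time batch normalization uses only the stored moving statistics, so it reduces to a fixed per-channel multiply and add, with no cond on is_training. One common workaround is to build a separate inference graph that always uses moving_mean and moving_variance before converting. The sketch below (hypothetical class BatchNormInference) shows that the direct formula and the folded form y = x * scale + shift agree for one channel.

```java
/** Hypothetical illustration: inference-mode batch normalization for a single channel. */
final class BatchNormInference {
    private BatchNormInference() {}

    /** Direct formula with moving statistics: gamma * (x - mean) / sqrt(var + eps) + beta. */
    static double normalize(double x, double mean, double var,
                            double gamma, double beta, double eps) {
        return gamma * (x - mean) / Math.sqrt(var + eps) + beta;
    }

    /**
     * Folded form: returns {scale, shift} such that x * scale + shift equals
     * normalize(x, mean, var, gamma, beta, eps) up to floating-point rounding.
     */
    static double[] fold(double mean, double var, double gamma, double beta, double eps) {
        double scale = gamma / Math.sqrt(var + eps);
        double shift = beta - mean * scale;
        return new double[] {scale, shift};
    }
}
```

Because scale and shift are constants once training ends, they can also be merged into the preceding convolution's weights and bias.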

What does convertToGreyScale do?

As follows:
private float convertToGreyScale(int color) {
    return (((color >> 16) & 0xFF) + ((color >> 8) & 0xFF) + (color & 0xFF)) / 3.0f / 255.0f;
}

I need your help, thanks!
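convertToGreyScale takes an Android color int packed as ARGB (alpha in bits 24-31, then red, green and blue), averages the red, green and blue bytes, and divides by 255 so the result lies in [0, 1]; alpha is ignored. A standalone copy with comments (the class name GreyScale is hypothetical) makes the bit operations explicit:

```java
/** Hypothetical standalone copy of the issue's convertToGreyScale, with comments. */
final class GreyScale {
    private GreyScale() {}

    static float convertToGreyScale(int color) {
        int r = (color >> 16) & 0xFF; // red byte of the ARGB int
        int g = (color >> 8) & 0xFF;  // green byte
        int b = color & 0xFF;         // blue byte; alpha (bits 24-31) is never read
        // Average the three channels, then scale 0..255 down to 0..1.
        return (r + g + b) / 3.0f / 255.0f;
    }
}
```

White (0xFFFFFFFF) maps to 1.0 and black (0xFF000000) to 0.0, so the model receives one float per pixel instead of a packed color int.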

Image shape [?,28,28,1]

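@nex3z, hello. After a digit is drawn in the UI and the bitmap is obtained, Classifier.java converts it to grayscale with convertToGreyScale.
How is the image shape then converted from [28,28,1] to [1,28,28,1]?
Could you please explain? Thanks.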
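The leading 1 is the batch dimension. With the TensorFlow Lite Java API the input can be supplied either as a nested float[1][28][28][1] array or as a direct ByteBuffer in native byte order holding 1 x 28 x 28 x 1 floats. For a flat buffer, [28,28,1] and [1,28,28,1] have an identical byte layout, so nothing has to be copied to "add" the dimension. The sketch below shows both forms; InputShape and its methods are hypothetical names, and this is not necessarily how the project's Classifier.java does it.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical helper: packs 28x28 grey values as a batch of one image. */
final class InputShape {
    static final int SIZE = 28;

    private InputShape() {}

    /** Nested-array form with shape [1][28][28][1]; pixels are row-major. */
    static float[][][][] toBatch(float[] pixels) {
        if (pixels.length != SIZE * SIZE) {
            throw new IllegalArgumentException("expected " + SIZE * SIZE + " pixels");
        }
        float[][][][] batch = new float[1][SIZE][SIZE][1];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                batch[0][y][x][0] = pixels[y * SIZE + x]; // batch 0, row y, column x, channel 0
            }
        }
        return batch;
    }

    /** Flat form: a direct buffer of 784 floats (4 bytes each) in native byte order. */
    static ByteBuffer toBuffer(float[] pixels) {
        ByteBuffer buf = ByteBuffer.allocateDirect(pixels.length * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        for (float p : pixels) {
            buf.putFloat(p);
        }
        buf.rewind(); // reset position so the buffer is read from the start
        return buf;
    }
}
```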

How can the accuracy be improved further?

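Hello. Although the accuracy on the training set looks high, the actual experience is not very good: digits like 7 and 9, the way I write them, are almost never recognized correctly. How can the accuracy be improved further?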
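Not an answer from the project, but one preprocessing step that is often suggested: the MNIST digits were centered in the 28x28 field by their center of mass, so a drawn digit that sits off-center can look unlike the training data. The sketch below shifts a 28x28 image so its center of mass lands near the middle. CenterOfMass is a hypothetical name, and the code assumes ink pixels have larger values than a zero background (invert the image first if the app uses the opposite convention).

```java
/** Hypothetical helper: recenters a 28x28 row-major image by its center of mass. */
final class CenterOfMass {
    static final int SIZE = 28;

    private CenterOfMass() {}

    /** Assumes background is 0 and ink is positive; returns a shifted copy. */
    static float[] centerByMass(float[] img) {
        double total = 0, sumX = 0, sumY = 0;
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                double v = img[y * SIZE + x];
                total += v;
                sumX += v * x;
                sumY += v * y;
            }
        }
        if (total == 0) {
            return img.clone(); // blank image: nothing to center
        }
        // Integer shift moving the center of mass as close as possible to (13.5, 13.5).
        int dx = (int) Math.round((SIZE - 1) / 2.0 - sumX / total);
        int dy = (int) Math.round((SIZE - 1) / 2.0 - sumY / total);
        float[] out = new float[SIZE * SIZE];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                int srcX = x - dx;
                int srcY = y - dy;
                // Pixels that would come from outside the grid stay 0 (background).
                if (srcX >= 0 && srcX < SIZE && srcY >= 0 && srcY < SIZE) {
                    out[y * SIZE + x] = img[srcY * SIZE + srcX];
                }
            }
        }
        return out;
    }
}
```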

Checkpoint not found

If I train the model locally, and then convert it to *.tflite, it works fine. However, if I use the pre-trained checkpoint, I always get the "Checkpoint not found" error. What's wrong?

Hello

I only changed the UI slightly, but the app I built now only recognizes 3... orz
