shubham0204 / scikit_learn_android_demo

An Android app that runs a scikit-learn model converted to the ONNX format

Home Page: https://towardsdatascience.com/deploying-scikit-learn-models-in-android-apps-with-onnx-b3adabe16bab

License: Apache License 2.0

Kotlin 100.00%
android android-application deployment kotlin-android machine-learning scikit-learn scikitlearn-machine-learning

scikit_learn_android_demo's Introduction

Shubham Panchal

  • Open to Android or ML-based freelancing projects and internship opportunities
  • I'm exploring backend development with FastAPI, cross-platform app development with Compose Multiplatform (part of KMM), PyTorch, and Rust for deploying ML models.

🙂 Reach me at

🌐 Profiles

C/C++/Rust Projects

📱 Mobile ML Projects

Android Projects

✍️ Stories On Medium

Demystifying Mathematics - Shubham Panchal

Google Colab Notebooks - Tutorials

scikit_learn_android_demo's People

Contributors

shubham0204

scikit_learn_android_demo's Issues

Need to convert to .ort format

FYI: It's not necessary to convert the model to .ort format when using the 'full' ONNX Runtime package, onnxruntime-android; the .onnx model can be used directly.

Conversion to .ort format is only necessary when using the smaller 'mobile' package, onnxruntime-mobile, which supports a limited set of operators and types (chosen from popular DNN models used in mobile scenarios) to keep the binary size small. However, that package does not include the traditional ML operators that scikit-learn models tend to use, so it most likely wouldn't be able to run a model converted from scikit-learn.
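For reference, here is a minimal sketch of declaring the full package in a module-level `build.gradle.kts`, assuming the Gradle Kotlin DSL. The Maven coordinates are the ones ONNX Runtime publishes for Android. `latest.release` is a Gradle dynamic version used here only as a placeholder; a real project should pin a concrete version.

```kotlin
dependencies {
    // Full ONNX Runtime package: it includes the traditional-ML (ai.onnx.ml) operators
    // that scikit-learn models usually need, so the .onnx file can be loaded directly
    // (for example from res/raw) without converting it to .ort first.
    implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
}
```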

Running a different type of variables

I need to run a prediction with the following variables:
gender : Integer, age : Integer, weight : Integer, height : Integer, hours : Integer, years : Integer, bmi : float

How can I do that? Do I need to add more input variables?
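The following is a hedged sketch of one way to handle mixed integer and float features. It assumes the model was trained on these seven columns in exactly this order and was converted with a single float input of shape `[None, 7]`, so check the converted model's input to confirm. Under that assumption, no extra inputs are needed. The integer fields can be widened to `Float` and packed into one row. `HealthFeatures` and `toFeatureRow` are hypothetical names used only for illustration.

```kotlin
// Hypothetical container for the seven features named in the question.
data class HealthFeatures(
    val gender: Int,   // 0 = female, 1 = male (must match the encoding used in training)
    val age: Int,
    val weight: Int,   // kg
    val height: Int,   // cm
    val hours: Int,
    val years: Int,
    val bmi: Float,
)

// Packs the features into a single float row. The order is an assumption:
// it must be the same column order the scikit-learn model was trained with.
fun HealthFeatures.toFeatureRow(): FloatArray = floatArrayOf(
    gender.toFloat(),
    age.toFloat(),
    weight.toFloat(),
    height.toFloat(),
    hours.toFloat(),
    years.toFloat(),
    bmi,
)
```

You can then wrap the resulting `FloatArray` with `FloatBuffer.wrap` and pass it to `OnnxTensor.createTensor` with shape `[1, 7]`: one sample with seven features.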

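import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.os.Bundle
import android.text.InputType
import android.view.View
import android.widget.AdapterView
import android.widget.ArrayAdapter
import android.widget.Button
import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AlertDialog
import androidx.appcompat.app.AppCompatActivity
import java.nio.FloatBuffer
// The app's own R class and databinding.ActivityCalculatorBinding are imported
// from the app's package (not shown in the original post).

class CalculatorActivity : AppCompatActivity() {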

    private lateinit var binding: ActivityCalculatorBinding
    private var isGenderSelected = false
    private var previousSelectedPosition = 0
    private var selectedGenderValue = 0 // 0 for Female, 1 for Male

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        binding = ActivityCalculatorBinding.inflate(layoutInflater)
        setContentView(binding.root)

        val textInputEditTextAge = binding.tedAge
        val placeholderAge = "... Years"
        textInputEditTextAge.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextAge.hint = placeholderAge

        val textInputEditTextWeight = binding.tedWeight
        val placeholderWeight = "... KG"
        textInputEditTextWeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextWeight.hint = placeholderWeight

        val textInputEditTextHeight = binding.tedHeight
        val placeholderHeight = "... CM"
        textInputEditTextHeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHeight.hint = placeholderHeight

        val textInputEditTextHours = binding.tedHours
        val placeholderHours = "... Hours"
        textInputEditTextHours.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHours.hint = placeholderHours

        val textInputEditTextYears = binding.tedYears
        val placeholderYears = "... Years"
        textInputEditTextYears.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextYears.hint = placeholderYears

        val gender = resources.getStringArray(R.array.Gender)

        val spinner = binding.genderSpinner
        if (spinner != null) {
            val adapter = ArrayAdapter(this,
                android.R.layout.simple_spinner_dropdown_item, gender)
            spinner.adapter = adapter

            spinner.onItemSelectedListener = object :
                AdapterView.OnItemSelectedListener {
                // view can be null (e.g. after a configuration change), so it is declared nullable
                override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) {
                    if (isGenderSelected) {
                        // A gender other than "-- Select your Gender --" has already been selected
                        if (position == 0) {
                            spinner.setSelection(previousSelectedPosition) // Set the spinner to the previous selected position
                        } else {
                            previousSelectedPosition = position // Update the previous selected position
                            selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                        }
                    } else {
                        isGenderSelected = true
                        previousSelectedPosition = position // Set the initial selected position
                        selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                    }
                }

                override fun onNothingSelected(parent: AdapterView<*>) {
                    // write code to perform some action
                }
            }
        }

        binding.btnPredict.setOnClickListener {
            val gender = selectedGenderValue.toFloat()
            // toFloatOrNull() returns null for empty or non-numeric fields instead of throwing
            val age = textInputEditTextAge.text.toString().toFloatOrNull()
            val weight = textInputEditTextWeight.text.toString().toFloatOrNull()
            val height = textInputEditTextHeight.text.toString().toFloatOrNull()
            val hours = textInputEditTextHours.text.toString().toFloatOrNull()
            val years = textInputEditTextYears.text.toString().toFloatOrNull()
            if (age != null && weight != null && height != null && height > 0f && hours != null && years != null) {
                // BMI = weight (kg) / height (m)^2; height is entered in cm
                val heightM = height / 100f
                val bmi = weight / (heightM * heightM)
                // The feature order must match the column order the model was trained on
                val inputs = floatArrayOf(gender, age, weight, height, hours, years, bmi)
                val ortEnvironment = OrtEnvironment.getEnvironment()
                // Close the session when done to release its native resources
                createORTSession(ortEnvironment).use { ortSession ->
                    val output = runPrediction(inputs, ortSession, ortEnvironment)
                    showOutputPopup(output)
                }
            } else {
                Toast.makeText(this, "Please fill in all the inputs", Toast.LENGTH_LONG).show()
            }
        }
    }

    private fun createORTSession(ortEnvironment: OrtEnvironment): OrtSession {
        // Read the model bundled in res/raw and close the stream afterwards
        val modelBytes = resources.openRawResource(R.raw.model1).use { it.readBytes() }
        return ortEnvironment.createSession(modelBytes)
    }

    private fun runPrediction(input: FloatArray, ortSession: OrtSession, ortEnvironment: OrtEnvironment): Long {
        // Get the name of the (single) input node
        val inputName = ortSession.inputNames.iterator().next()
        // Make a FloatBuffer of the inputs
        val floatBufferInputs = FloatBuffer.wrap(input)
        // Create an input tensor of shape (1, input.size): one sample with seven features
        OnnxTensor.createTensor(ortEnvironment, floatBufferInputs, longArrayOf(1, input.size.toLong())).use { inputTensor ->
            // Run the model; closing the results frees the native output tensors
            ortSession.run(mapOf(inputName to inputTensor)).use { results ->
                // For a classifier, the first output is usually the int64 label tensor
                val output = results[0].value as LongArray
                return output[0]
            }
        }
    }

    fun showOutputPopup(output: Long) {
        // Inflate the custom layout for the popup
        val inflater = layoutInflater
        val popupView = inflater.inflate(R.layout.popup_output, null)

        // Find views within the custom layout
        val tvOutput = popupView.findViewById<TextView>(R.id.tvOutput)
        val btnClose = popupView.findViewById<Button>(R.id.btnClose)

        // Set the output text
        tvOutput.text = "Output is $output"

        // Create the dialog builder
        val builder = AlertDialog.Builder(this)
        builder.setView(popupView)

        // Create and show the dialog
        val dialog = builder.create()
        dialog.show()

        // Handle button click
        btnClose.setOnClickListener {
            dialog.dismiss() // Close the dialog when the button is clicked
        }
    }
}
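Reading the numeric fields safely and computing BMI can be factored into small, testable helpers. The following is a minimal sketch that uses plain strings in place of the `EditText` contents. `parseInputs` and `bmiOf` are hypothetical helpers, and the BMI formula assumes weight in kilograms and height in centimetres, matching the field hints.

```kotlin
// Hypothetical helper: parses every field, returning null if any field is blank
// or not a valid number, so the caller can show a "fill in all inputs" message.
fun parseInputs(fields: List<String>): FloatArray? {
    // `return null` inside the inline map lambda returns from parseInputs itself
    val values = fields.map { it.trim().toFloatOrNull() ?: return null }
    return values.toFloatArray()
}

// BMI = weight (kg) / height (m)^2, with height entered in centimetres.
// Returns null for a non-positive height to avoid dividing by zero.
fun bmiOf(weightKg: Float, heightCm: Float): Float? {
    if (heightCm <= 0f) return null
    val heightM = heightCm / 100f
    return weightKg / (heightM * heightM)
}
```

Returning `null` instead of throwing means the caller's `if (... != null)` check actually guards against missing input. By contrast, a non-nullable `FloatArray` can never be null, so checking it for null guards nothing.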

More than one input

I don't know whether this is the correct way to predict with multiple inputs. I think something is wrong in "val inputTensor = OnnxTensor.createTensor( ortEnvironment , floatBufferInputs , longArrayOf( 1, 1 ) )"
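The shape passed to `OnnxTensor.createTensor` tells ONNX Runtime how to view the buffer, so its dimensions must multiply out to exactly the number of floats supplied. Seven features for one sample need shape `[1, 7]`, while `[1, 1]` describes a single value. The following is a minimal sketch of deriving the shape from the data instead of hard-coding it. It assumes the model's input is a 2-D float tensor `[batch, features]`, and `batchShape`, `flatten`, and `shapeMatches` are hypothetical helper names.

```kotlin
// Hypothetical helper: shape for a batch of rows that all have the same number of features.
// For a single sample this is [1, numFeatures], e.g. [1, 7] for the seven inputs above.
fun batchShape(rows: List<FloatArray>): LongArray {
    require(rows.isNotEmpty()) { "need at least one row" }
    val numFeatures = rows.first().size
    require(rows.all { it.size == numFeatures }) { "all rows must have the same length" }
    return longArrayOf(rows.size.toLong(), numFeatures.toLong())
}

// Flattens rows in row-major order, the layout a [batch, features] tensor expects.
fun flatten(rows: List<FloatArray>): FloatArray =
    rows.flatMap { it.asList() }.toFloatArray()

// The product of the shape's dimensions must equal the number of values in the buffer.
fun shapeMatches(values: FloatArray, shape: LongArray): Boolean =
    shape.fold(1L) { acc, d -> acc * d } == values.size.toLong()
```

With these helpers, the tensor could be created as `OnnxTensor.createTensor(ortEnvironment, FloatBuffer.wrap(flatten(rows)), batchShape(rows))`, where `rows` is, for example, `listOf(floatArrayOf(gender, age, weight, height, hours, years, bmi))`.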

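import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.os.Bundle
import android.text.InputType
import android.view.View
import android.widget.AdapterView
import android.widget.ArrayAdapter
import android.widget.Button
import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AlertDialog
import androidx.appcompat.app.AppCompatActivity
import java.nio.FloatBuffer
// The app's own R class and databinding.ActivityCalculatorBinding are imported
// from the app's package (not shown in the original post).

class CalculatorActivity : AppCompatActivity() {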
    private lateinit var binding: ActivityCalculatorBinding
    private var isGenderSelected = false
    private var previousSelectedPosition = 0
    private var selectedGenderValue = 0 // 0 for Female, 1 for Male

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        binding = ActivityCalculatorBinding.inflate(layoutInflater)
        setContentView(binding.root)

        val textInputEditTextAge = binding.tedAge
        val placeholderAge = "... Years"
        textInputEditTextAge.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextAge.hint = placeholderAge

        val textInputEditTextWeight = binding.tedWeight
        val placeholderWeight = "... KG"
        textInputEditTextWeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextWeight.hint = placeholderWeight

        val textInputEditTextHeight = binding.tedHeight
        val placeholderHeight = "... CM"
        textInputEditTextHeight.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHeight.hint = placeholderHeight

        val textInputEditTextHours = binding.tedHours
        val placeholderHours = "... Hours"
        textInputEditTextHours.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextHours.hint = placeholderHours

        val textInputEditTextYears = binding.tedYears
        val placeholderYears = "... Years"
        textInputEditTextYears.inputType = InputType.TYPE_CLASS_NUMBER
        textInputEditTextYears.hint = placeholderYears

        val gender = resources.getStringArray(R.array.Gender)

        val spinner = binding.genderSpinner
        if (spinner != null) {
            val adapter = ArrayAdapter(this,
                android.R.layout.simple_spinner_dropdown_item, gender)
            spinner.adapter = adapter

            spinner.onItemSelectedListener = object :
                AdapterView.OnItemSelectedListener {
                // view can be null (e.g. after a configuration change), so it is declared nullable
                override fun onItemSelected(parent: AdapterView<*>?, view: View?, position: Int, id: Long) {
                    if (isGenderSelected) {
                        // A gender other than "-- Select your Gender --" has already been selected
                        if (position == 0) {
                            spinner.setSelection(previousSelectedPosition) // Set the spinner to the previous selected position
                        } else {
                            previousSelectedPosition = position // Update the previous selected position
                            selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                        }
                    } else {
                        isGenderSelected = true
                        previousSelectedPosition = position // Set the initial selected position
                        selectedGenderValue = if (position == 1) 1 else 0 // Convert position to numeric value (0 or 1)
                    }
                }

                override fun onNothingSelected(parent: AdapterView<*>) {
                    // write code to perform some action
                }
            }
        }

        binding.btnPredict.setOnClickListener {
            val gender = selectedGenderValue.toFloat()
            // toFloatOrNull() returns null for empty or non-numeric fields instead of throwing
            val age = textInputEditTextAge.text.toString().toFloatOrNull()
            val weight = textInputEditTextWeight.text.toString().toFloatOrNull()
            val height = textInputEditTextHeight.text.toString().toFloatOrNull()
            val hours = textInputEditTextHours.text.toString().toFloatOrNull()
            val years = textInputEditTextYears.text.toString().toFloatOrNull()
            // BMI = weight (kg) / height (m)^2, computed only when both values are valid
            val bmi = if (weight != null && height != null && height > 0f) {
                val heightM = height / 100f
                weight / (heightM * heightM)
            } else null
            if (age != null && weight != null && height != null && hours != null && years != null && bmi != null) {
                val ortEnvironment = OrtEnvironment.getEnvironment()
                // Close the session when done to release its native resources
                createORTSession(ortEnvironment).use { ortSession ->
                    val output = runPrediction(
                        gender, age, weight, height, hours, years, bmi,
                        ortSession, ortEnvironment
                    )
                    showOutputPopup(output)
                }
            } else {
                Toast.makeText(this, "Please fill in all the inputs", Toast.LENGTH_LONG).show()
            }
        }
    }

    private fun createORTSession(ortEnvironment: OrtEnvironment): OrtSession {
        // Read the model bundled in res/raw and close the stream afterwards
        val modelBytes = resources.openRawResource(R.raw.model).use { it.readBytes() }
        return ortEnvironment.createSession(modelBytes)
    }

    private fun runPrediction(
        gender: Float, age: Float, weight: Float, height: Float, hours: Float, years: Float, bmi: Float,
        ortSession: OrtSession, ortEnvironment: OrtEnvironment
    ): Float {
        // Get the name of the (single) input node
        val inputName = ortSession.inputNames.iterator().next()
        // Pack the seven features into one row, in the column order used for training
        val features = floatArrayOf(gender, age, weight, height, hours, years, bmi)
        val floatBufferInputs = FloatBuffer.wrap(features)
        // Shape (1, 7): one sample with seven features. The dimensions must
        // multiply out to the number of floats in the buffer.
        OnnxTensor.createTensor(ortEnvironment, floatBufferInputs, longArrayOf(1, features.size.toLong())).use { inputTensor ->
            // Run the model; closing the results frees the native output tensors
            ortSession.run(mapOf(inputName to inputTensor)).use { results ->
                // For a regressor, the first output is usually a float tensor of shape (N, 1)
                @Suppress("UNCHECKED_CAST")
                val output = results[0].value as Array<FloatArray>
                return output[0][0]
            }
        }
    }

    fun showOutputPopup(output: Float) {
        // Inflate the custom layout for the popup
        val inflater = layoutInflater
        val popupView = inflater.inflate(R.layout.popup_output, null)

        // Find views within the custom layout
        val tvOutput = popupView.findViewById<TextView>(R.id.tvOutput)
        val btnClose = popupView.findViewById<Button>(R.id.btnClose)

        // Set the output text
        tvOutput.text = "Output is $output"

        // Create the dialog builder
        val builder = AlertDialog.Builder(this)
        builder.setView(popupView)

        // Create and show the dialog
        val dialog = builder.create()
        dialog.show()

        // Handle button click
        btnClose.setOnClickListener {
            dialog.dismiss() // Close the dialog when the button is clicked
        }
    }
}
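The two snippets read the first output differently (`LongArray` in one, `Array<FloatArray>` in the other). Which cast is right depends on the converted model. As a rule of thumb that may vary with converter options, a scikit-learn classifier converted with skl2onnx usually exposes an int64 label tensor of shape `[N]`, which the ONNX Runtime Java API returns as a `LongArray`. A regressor usually exposes a float tensor of shape `[N, 1]`, returned as `Array<FloatArray>`. Below is a minimal sketch of a hypothetical helper, `firstPrediction`, that accepts either form of `results[0].value`. The value is typed as `Any?` so the sketch runs without ONNX Runtime.

```kotlin
// Hypothetical helper: extracts the first prediction from the object returned by
// results[0].value, whether the model produces labels or regression values.
fun firstPrediction(value: Any?): Double = when (value) {
    is LongArray -> value.first().toDouble()            // classifier labels, shape [N]
    is FloatArray -> value.first().toDouble()           // 1-D float output, shape [N]
    is Array<*> -> when (val row = value.firstOrNull()) {
        is FloatArray -> row.first().toDouble()         // regressor output, shape [N, 1]
        is LongArray -> row.first().toDouble()          // 2-D int64 output, shape [N, 1]
        else -> error("unexpected row type in model output")
    }
    else -> error("unexpected model output type")
}
```

Failing loudly on an unexpected type makes a mismatch between the code and the model's actual output easy to spot. A silent wrong cast would hide it until a `ClassCastException` at runtime.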