¿Cómo crear un modelo personalizado para Android usando TensorFlow?

Tensorflow es una biblioteca de código abierto para el aprendizaje automático. En Android, tenemos un poder de cómputo y recursos limitados. Por lo tanto, usamos TensorFlow light, que está diseñado específicamente para funcionar en dispositivos con energía limitada. En esta publicación, veremos un ejemplo de clasificación llamado conjunto de datos iris. El conjunto de datos contiene 3 clases de 50 instancias cada una, donde cada clase se refiere al tipo de planta de iris.

Información de atributos:

  1. longitud del sépalo en cm
  2. anchura del sépalo en cm
  3. longitud del pétalo en cm
  4. ancho de pétalo en cm

Según la información proporcionada en la entrada, predeciremos si la planta es Iris Setosa , Iris Versicolour o Iris Virginica . Puede consultar este enlace para obtener más información.

Implementación paso a paso

Paso 1:

Descargue el conjunto de datos de iris ( nombre de archivo: iris.data ) desde este enlace ( https://archive.ics.uci.edu/ml/machine-learning-databases/iris/ ).

Paso 2:

Cree un nuevo archivo de python con un nombre iris en el cuaderno de Jupyter. Coloque el archivo iris.data en el mismo directorio donde reside iris.ipynb. Copie el código siguiente en el archivo del cuaderno de Jupyter.

iris.ipynb

Python

import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder
from keras.utils import to_categorical
 
# reading the csb into data frame
df = pd.read_csv('iris.data')
 
# specifying the columns values into x and y variable
# iloc range based selecting 0 to 4 (4) values
X = df.iloc[:, :4].values
y = df.iloc[:, 4].values
 
# normalizing labels
le = LabelEncoder()
 
# performing fit and transform data on y
y = le.fit_transform(y)
 
y = to_categorical(y)
 
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
 
model = Sequential()
 
# input layer
# passing number neurons =64
# relu activation
# shape of neuron 4
model.add(Dense(64, activation='relu', input_shape=[4]))
 
# processing layer
# adding another denser layer of size 64
model.add(Dense(64))
 
# creating 3 output neuron
model.add(Dense(3, activation='softmax'))
 
 
# compiling model
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['acc'])
 
# training the model for fixed number of iterations (epoches)
model.fit(X, y, epochs=200)
 
from tensorflow import lite
converter = lite.TFLiteConverter.from_keras_model(model)
 
tfmodel = converter.convert()
 
open('iris.tflite', 'wb').write(tfmodel)

Paso 3:

Después de ejecutar la línea open(‘iris.tflite’,’wb’).write(tfmodel) se creará un nuevo archivo llamado iris.tflite en el mismo directorio donde reside iris.data. 

A) Abra Android Studio. Cree un nuevo proyecto kotlin-android. (Puede consultar aquí para crear un proyecto). 

B) Haga clic con el botón derecho en la aplicación > Nuevo > Otro > Modelo TensorFlow Lite 

C) Haga clic en el icono de la carpeta. 

D) Navegue hasta el archivo iris.tflite 

E) Haga clic en Aceptar

F) Su modelo se verá así después de hacer clic en el acabado. (Puede tardar un poco en cargar). 

Copie el código y péguelo en el detector de clics de un botón en MainActivity.kt. (Se muestra a continuación).

Paso 5: Cree un diseño XML para la predicción

Vaya a la aplicación > res > diseño > actividad_principal.xml y agregue el siguiente código a ese archivo. A continuación se muestra el código para el archivo   activity_main.xml .

XML

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">
 
    <ScrollView
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:layout_marginBottom="50dp">
 
        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="match_parent"
            android:orientation="vertical">
           
            <!-- creating  edittexts for input-->
            <EditText
                android:id="@+id/tf1"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="70dp"
                android:ems="10"
                android:inputType="numberDecimal" />
 
            <EditText
                android:id="@+id/tf2"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:inputType="numberDecimal" />
 
            <EditText
                android:id="@+id/tf3"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:inputType="numberDecimal" />
 
            <EditText
                android:id="@+id/tf4"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:inputType="numberDecimal" />
 
            <!-- creating  Button for input-->
            <!-- after clicking on button we will see prediction-->
            <Button
                android:id="@+id/button"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="100dp"
                android:text="Button"
                app:layout_constraintBottom_toTopOf="@+id/textView"
                app:layout_constraintEnd_toEndOf="parent"
                app:layout_constraintHorizontal_bias="0.0"
                app:layout_constraintStart_toStartOf="parent" />
 
            <!-- creating  textview on which we will see prediction-->
            <TextView
                android:id="@+id/textView"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="50dp"
                android:text="TextView"
                android:textSize="20dp"
                app:layout_constraintEnd_toEndOf="parent" />
        </LinearLayout>
    </ScrollView>
</androidx.constraintlayout.widget.ConstraintLayout>

 
Paso 6: trabajar con el archivo MainActivity.kt

Vaya al archivo MainActivity.kt y consulte el siguiente código. A continuación se muestra el código del archivo MainActivity.kt . Se agregan comentarios dentro del código para comprender el código con más detalle. 

Kotlin

import androidx.appcompat.app.AppCompatActivity
import android.os.Bundle
import android.view.View
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import com.example.gfgtfdemo.ml.Iris
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.ByteBuffer
 
class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
 
        // getting the object edit texts
        var ed1: EditText = findViewById(R.id.tf1);
        var ed2: EditText = findViewById(R.id.tf2);
        var ed3: EditText = findViewById(R.id.tf3);
        var ed4: EditText = findViewById(R.id.tf4);
       
        // getting the object of result textview
        var txtView: TextView = findViewById(R.id.textView);
        var b: Button = findViewById<Button>(R.id.button);
 
        // registering listener
        b.setOnClickListener(View.OnClickListener {
           
            val model = Iris.newInstance(this)
 
            // getting values from edit text and converting to float
            var v1: Float = ed1.text.toString().toFloat();
            var v2: Float = ed2.text.toString().toFloat();
            var v3: Float = ed3.text.toString().toFloat();
            var v4: Float = ed4.text.toString().toFloat();
 
            /*************************ML MODEL CODE STARTS HERE******************/
             
              // creating byte buffer which will act as input for model
            var byte_buffer: ByteBuffer = ByteBuffer.allocateDirect(4 * 4)
            byte_buffer.putFloat(v1)
            byte_buffer.putFloat(v2)
            byte_buffer.putFloat(v3)
            byte_buffer.putFloat(v4)
 
            // Creates inputs for reference.
            val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 4), DataType.FLOAT32)
            inputFeature0.loadBuffer(byte_buffer)
 
            // Runs model inference and gets result.
            val outputs = model.process(inputFeature0)
            val outputFeature0 = outputs.outputFeature0AsTensorBuffer.floatArray
 
             
            // setting the result to the output textview
            txtView.setText(
                "Iris-setosa : =" + outputFeature0[0].toString() + "\n" +
                "Iris-versicolor : =" + outputFeature0[1].toString() + "\n" +
                "Iris-virginica: =" +  outputFeature0[2].toString()
            )
 
            // Releases model resources if no longer used.
            model.close()
        })
    }
}

 
Producción: 

Puedes descargar este proyecto desde aquí .

Publicación traducida automáticamente

Artículo escrito por aniketdalal126 y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *