Tensorflow.js tf.GraphModel clase .predict() Método

Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático, así como redes neuronales de aprendizaje profundo en el entorno del navegador o del Node.

La función .predict() se usa para implementar la implicación a favor de los tensores de entrada.

Sintaxis:

predict(inputs, config?)

Parámetros:  

  • entradas: Son las entradas indicadas. Es de tipo (tf.Tensor|tf.Tensor[]|{[name: string]: tf.Tensor}).
  • config: es la configuración de predicción establecida para definir el tamaño del lote, así como las designaciones de los Nodes de salida. Además, en la actualidad, la selección del tamaño del lote se pasa por alto para el modelo gráfico. Es opcional y es de tipo objeto.
    • batchSize: es la dimensión del lote indicada que es opcional y es de tipo entero. En caso de que no esté definido, el valor predeterminado será 32.
    • verbose: Es el modo de verbosidad indicado cuyo valor por defecto es falso y es opcional.

Valor devuelto: Devuelve tf.Tensor|tf.Tensor[]|{[name: string]: tf.Tensor}.

Ejemplo 1: en este ejemplo, cargamos MobileNetV2 desde una URL y mantenemos una predicción con una entrada de ceros.

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements
const model_Url =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
  
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
  
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
  
// Calling predict() method and 
// Printing output
mymodel.predict(inputs).print();

Producción:

Tensor
     [[-0.1800361, -0.4059965, 0.8190175, 
     ..., 
     -0.8953396, -1.0841646, 1.2912753],]

Ejemplo 2: en este ejemplo, cargamos MobileNetV2 desde una URL de TF Hub y mantenemos una predicción con una entrada de ceros.

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements
const model_Url =
'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
  
// Calling the loadGraphModel() method
const model = await tf.loadGraphModel(
        model_Url, {fromTFHub: true});
  
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
  
// Defining batchsize
const batchsize = 1;
  
// Defining verbose
const verbose = true;
  
// Calling predict() method and
// Printing output
model.predict(inputs, batchsize, verbose).print();

Producción:

Tensor
     [[-1.1690605, 0.0195426, 1.1962479, 
     ..., 
     -0.4825858, -0.0055641, 1.1937635],]

Referencia: https://js.tensorflow.org/api/latest/#tf.GraphModel.predict

Publicación traducida automáticamente

Artículo escrito por nidhi1352singh y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *