Tensorflow.js tf. Clase secuencial .trainOnBatch() Método

Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático, así como redes neuronales de aprendizaje profundo en el entorno del navegador o del Node.

La función .trainOnBatch() se usa para ejecutar una actualización de gradiente separada en un lote de datos en particular.

Nota:

Este método varía de fit() y fitDataset() de las siguientes maneras:

  • Este método funciona en absolutamente un lote de datos.
  • Este método simplemente devuelve los valores de pérdida y métrica, en lugar de devolver la pérdida de lote por lote, así como los valores de métrica.
  • Este método no favorece las opciones detalladas como la verbosidad y las devoluciones de llamada.

Sintaxis:

trainOnBatch(x, y)

Parámetros:

  • x: Los datos de entrada indicados. Puede ser del tipo tf.Tensor, tf.Tensor[] o {[inputName: string]: tf.Tensor}. Puede ser cualquiera de los siguientes:
    1. Un tf.Tensor declarado, o bien una array de tf.Tensors si el modelo indicado posee múltiples entradas.
    2. Un objeto que traza nombres de entrada en el tf.Tensor coincidente en caso de que el modelo indicado posea entradas con nombre.
  • y: Los datos de destino indicados. Puede ser del tipo tf.Tensor, tf.Tensor[] o {[inputName: string]: tf.Tensor}. Debe ser constante con respecto a x.

Valor devuelto: Devuelve la promesa de número o número[].

Ejemplo 1:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Training Model 
const gfg = tf.sequential();
    
// Adding layer to model  
const layer = tf.layers.dense({units:3, 
               inputShape : [5]});
   gfg.add(layer);
      
// Compiling our model 
const config = {optimizer:'sgd', 
              loss:'meanSquaredError'};
  gfg.compile(config);
  
// Test tensor and target tensor
const xs = tf.ones([3, 5]);
const ys = tf.ones([3, 3]);
      
// Calling trainOneBatch() method
const result = await gfg.trainOnBatch(xs, ys);
  
// Printing output
console.log(result);

Producción:

0.3589147925376892

Ejemplo 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
async function run() {
  
  // Training Model 
  const gfg = tf.sequential();
    
  // Adding layer to model  
  const layer = tf.layers.dense({units:2, 
               inputShape : [2]});
  gfg.add(layer);
      
  // Compiling our model 
  const config = {optimizer:'sgd', 
              loss:'meanSquaredError'};
  gfg.compile(config);
  
  // Test tensor and target tensor
  const xs = tf.truncatedNormal([3, 2]);
  const ys = tf.randomNormal([3, 2]);
      
  // Calling trainOneBatch() method
  const result = await gfg.trainOnBatch(xs, ys);
  
  // Printing output
  console.log(JSON.stringify(+result));
}
    
// Function call
await run();

Producción:

1.6889342069625854

Referencia: https://js.tensorflow.org/api/latest/#tf.Sequential.trainOnBatch

Publicación traducida automáticamente

Artículo escrito por nidhi1352singh y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *