Tensorflow.js tf.LayersModel clase .trainOnBatch() Método

Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático, así como redes neuronales de aprendizaje profundo en el entorno del navegador o del Node.

La función .trainOnBatch() se usa para ejecutar una actualización de gradiente separada en un lote de datos en particular.

Nota: este método varía de fit() y fitDataset() de las siguientes maneras:

  • Este método funciona en absolutamente un lote de datos.
  • Este método simplemente devuelve los valores de pérdida y métrica, en lugar de devolver la pérdida de lote por lote, así como los valores de métrica.
  • Este método no favorece las opciones detalladas como la verbosidad y las devoluciones de llamada.

Sintaxis:

trainOnBatch(x, y)

Parámetros:

  • x: Los datos de entrada indicados. Puede ser del tipo tf.Tensor, tf.Tensor[] o {[inputName: string]: tf.Tensor}. Puede ser cualquiera de los siguientes:
    1. Un tf.Tensor declarado, o bien una array de tf.Tensors si el modelo indicado posee múltiples entradas.
    2. Un objeto que traza nombres de entrada en el tf.Tensor coincidente en caso de que el modelo indicado posea entradas con nombre.
  • y: Los datos de destino indicados. Puede ser del tipo tf.Tensor, tf.Tensor[] o {[inputName: string]: tf.Tensor}. Debe ser constante con respecto a x .

Valor devuelto: Devuelve promesa de número o número[].

Ejemplo 1:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Training Model
const mymodel = tf.sequential(
     {layers: [tf.layers.dense({units: 2, inputShape: [2]})]});
  
// Compiling our model
const config = {optimizer:'sgd',
            loss:'meanSquaredError'};
mymodel.compile(config);
      
// Test tensor and target tensor
const xs = tf.ones([3,2]);
const ys = tf.ones([3,2]);
      
// Calling trainOneBatch() method
const result = await mymodel.trainOnBatch(xs, ys);
  
// Printing output
console.log(result);

Producción:

2.0696773529052734

Ejemplo 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
async function run() {
  
  // Training Model
  const mymodel = tf.sequential(
     {layers: [tf.layers.dense({units: 2, inputShape: [2], 
                                activation: 'sigmoid'})]});
  
  // Compiling our model
  const config = {optimizer:'sgd',
            loss:'meanSquaredError'};
  mymodel.compile(config);
      
  // Test tensor and target tensor
  const xs = tf.truncatedNormal([3,2]);
  const ys = tf.randomNormal([3,2]);
      
  // Calling trainOneBatch() method
  const result = await mymodel.trainOnBatch(xs, ys);
  
  // Printing output
  console.log(JSON.stringify(+result));
}
    
// Function call
await run();

Producción:

0.5935208797454834

Referencia: https://js.tensorflow.org/api/latest/#tf.LayersModel.trainOnBatch

Publicación traducida automáticamente

Artículo escrito por nidhi1352singh y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *