Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático, así como redes neuronales de aprendizaje profundo en el entorno del navegador o del Node.
La función .trainOnBatch() se usa para ejecutar una actualización de gradiente separada en un lote de datos en particular.
Nota:
Este método varía de fit() y fitDataset() de las siguientes maneras:
- Este método funciona en absolutamente un lote de datos.
- Este método simplemente devuelve los valores de pérdida y métrica, en lugar de devolver la pérdida de lote por lote, así como los valores de métrica.
- Este método no favorece las opciones detalladas como la verbosidad y las devoluciones de llamada.
Sintaxis:
trainOnBatch(x, y)
Parámetros:
- x: Los datos de entrada indicados. Puede ser del tipo tf.Tensor, tf.Tensor[] o {[inputName: string]: tf.Tensor}. Puede ser cualquiera de los siguientes:
- Un tf.Tensor declarado, o bien una array de tf.Tensors si el modelo indicado posee múltiples entradas.
- Un objeto que traza nombres de entrada en el tf.Tensor coincidente en caso de que el modelo indicado posea entradas con nombre.
- y: Los datos de destino indicados. Puede ser del tipo tf.Tensor, tf.Tensor[] o {[inputName: string]: tf.Tensor}. Debe ser constante con respecto a x.
Valor devuelto: Devuelve la promesa de número o número[].
Ejemplo 1:
Javascript
// Importing the tensorflow.js library import * as tf from "@tensorflow/tfjs" // Training Model const gfg = tf.sequential(); // Adding layer to model const layer = tf.layers.dense({units:3, inputShape : [5]}); gfg.add(layer); // Compiling our model const config = {optimizer:'sgd', loss:'meanSquaredError'}; gfg.compile(config); // Test tensor and target tensor const xs = tf.ones([3, 5]); const ys = tf.ones([3, 3]); // Calling trainOneBatch() method const result = await gfg.trainOnBatch(xs, ys); // Printing output console.log(result);
Producción:
0.3589147925376892
Ejemplo 2:
Javascript
// Importing the tensorflow.js library import * as tf from "@tensorflow/tfjs" async function run() { // Training Model const gfg = tf.sequential(); // Adding layer to model const layer = tf.layers.dense({units:2, inputShape : [2]}); gfg.add(layer); // Compiling our model const config = {optimizer:'sgd', loss:'meanSquaredError'}; gfg.compile(config); // Test tensor and target tensor const xs = tf.truncatedNormal([3, 2]); const ys = tf.randomNormal([3, 2]); // Calling trainOneBatch() method const result = await gfg.trainOnBatch(xs, ys); // Printing output console.log(JSON.stringify(+result)); } // Function call await run();
Producción:
1.6889342069625854
Referencia: https://js.tensorflow.org/api/latest/#tf.Sequential.trainOnBatch
Publicación traducida automáticamente
Artículo escrito por nidhi1352singh y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA