Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático y redes neuronales de aprendizaje profundo en el entorno del navegador o del Node.
El método .fit( ) de la clase tf.LayersModel se usa para entrenar el modelo para el número fijo de épocas (iteraciones en un conjunto de datos).
Sintaxis:
fit(x, y, args?)
Parámetros: este método acepta los siguientes parámetros.
- x: Es tf.Tensor que contiene todos los datos de entrada.
- y: Es tf.Tensor que contiene todos los datos de salida.
- args: Es de tipo objeto, sus variables son las siguientes:
- batchSize: Define el número de muestras que se propagarán a través del entrenamiento.
- epochs: Define la iteración sobre los arreglos de datos de entrenamiento.
- detallado: ayuda a mostrar el progreso de cada época. Si el valor es 0, significa que no hay ningún mensaje impreso durante la llamada a fit(). Si el valor es 1, significa que en Node-js, imprime la barra de progreso. En el navegador, no muestra ninguna acción. El valor 1 es el valor predeterminado. 2: el valor 2 aún no está implementado.
- callbacks: Define una lista de callbacks a ser llamados durante el entrenamiento. La variable puede tener una o más de estas devoluciones de llamada onTrainBegin(), onTrainEnd(), onEpochBegin(), onEpochEnd(), onBatchBegin(), onBatchEnd(), onYield().
- ValidationSplit: facilita al usuario dividir el conjunto de datos de entrenamiento en entrenamiento y validación. Por ejemplo: si el valor es validación-Dividir = 0.5, significa usar el último 50% de los datos antes de barajar para la validación.
- ValidationData: se utiliza para dar una estimación del modelo final al seleccionar entre modelos finales.
- barajar: este valor define el barajado de los datos antes de cada época. No tiene efecto cuando stepsPerEpoch no es nulo.
- classWeight: Se utiliza para ponderar la función de pérdida. Puede ser útil decirle al modelo que preste más atención a las muestras de una clase subrepresentada.
- sampleWeight: Es una array de pesos para aplicar a la pérdida del modelo para cada muestra.
- initialEpoch: Es el valor de definir la época en la que comenzar a entrenar. Es útil para reanudar una carrera de entrenamiento anterior.
- stepsPerEpoch: Define un número de lotes de muestras antes de declarar finalizada una época e iniciar la siguiente. Es igual a 1 si no se determina.
- ValidationSteps: es relevante si se especifica stepsPerEpoch . El número total de pasos para validar antes de detenerse.
- yieldEvery: Define la configuración de la frecuencia de cesión del hilo principal a otras tareas. Puede ser automático, lo que significa que el rendimiento ocurre a una determinada velocidad de fotogramas. lote, si el valor es este, produce cada lote. época, si el valor es este, rinde cada época. cualquier número, si el valor es cualquier número, produce cada número en milisegundos . nunca , si el valor es este, nunca cede.
Vuelve: Vuelve la promesa de la historia.
Ejemplo 1:
Javascript
// Importing the tensorflow.js library import * as tf from "@tensorflow/tfjs" // Defining model const mymodel = tf.sequential({ layers: [tf.layers.dense({units: 2, inputShape: [6]})] }); // Compiling the above model mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); // Using for loop for (let i = 0; i < 4; i++) { // Calling fit() method const his = await mymodel.fit(tf.zeros([6, 6]), tf.ones([6, 2]), { batchSize: 5, epochs: 4 }); // Printing output console.log(his.history.loss[1]); }
Producción:
0.9574100375175476 0.8151942491531372 0.694103479385376 0.5909997820854187
Ejemplo 2:
Javascript
// Importing the tensorflow.js library import * as tf from "@tensorflow/tfjs" // Defining model const mymodel = tf.sequential({ layers: [tf.layers.dense({units: 2, inputShape: [6], activation : "sigmoid"})]}); // Compiling the above model mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); // Calling fit() method const his = await mymodel.fit(tf.truncatedNormal([6, 6]), tf.randomNormal([6, 2]), { batchSize: 5, epochs: 4, validationSplit: 0.2, shuffle: true, initialEpoch: 2, stepsPerEpoch: 1, validationSteps: 2}); // Printing output console.log(JSON.stringify(his.history));
Producción:
{"val_loss":[0.35800713300704956,0.35819053649902344], "loss":[0.633269190788269,0.632409930229187]}
Referencia: https://js.tensorflow.org/api/latest/#tf.LayersModel.fit
Publicación traducida automáticamente
Artículo escrito por nidhi1352singh y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA