Tensorflow.js tf. Clase secuencial .fitDataset() Método

Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático y redes neuronales de aprendizaje profundo en el entorno del navegador o del Node. El método Tensorflow.js tf.Sequential class .fitDataset() se usa para entrenar el modelo usando un objeto de conjunto de datos.

Sintaxis:

model.fitDataset(dataset, args);

Parámetros: Este método contiene los siguientes parámetros:

  • conjunto de datos: Es un conjunto de datos de un valor de entrada. Puede ser un conjunto de datos de primitivas, una array o un objeto.
  • args: contiene los siguientes valores:
    • epochs : es el número total de pases en el conjunto de datos de entrenamiento durante el modelo de entrenamiento. Es un valor entero.
    • batchesPerEpoch : Define el número de lotes en cada época. Su valor depende del tamaño del lote a medida que aumenta el tamaño del lote, su tamaño disminuye.
    • detallado: ayuda a mostrar el progreso de cada época. Si el valor es 0, significa que no hay ningún mensaje impreso durante la llamada a fit(). Si el valor es 1, significa que en Node.js, imprime la barra de progreso. En el navegador no muestra ninguna acción. El valor 1 es el valor predeterminado. 2: el valor 2 aún no está implementado.
    • callbacks: Define una lista de callbacks a ser llamados durante el entrenamiento. La variable puede tener una o más de estas devoluciones de llamada onTrainBegin(), onTrainEnd(), onEpochBegin(), onEpochEnd(), onBatchBegin(), onBatchEnd(), onYield().
    • ValidationData: se utiliza para dar una estimación del modelo final al seleccionar entre modelos finales. Esto podría ser cualquiera de estos: una array de [xVal, yVal], un objeto de conjunto de datos con elementos de la forma {xs: xVal, ys: yVal}.
    • ValidationBatchSize: Es el número que define el tamaño del lote. Se utiliza para validar el tamaño del lote. Significa que no podemos poner todos los conjuntos de datos a la vez que excedan este valor. su valor predeterminado es 32.
    • validaciónBatches: Se utiliza para validar los lotes de muestras. Se utiliza para extraer datos de validación con fines de validación en cada final de una época.
    • classWeight: Se utiliza para ponderar la función de pérdida. Puede ser útil decirle al modelo que preste más atención a las muestras de una clase subrepresentada.
    • initialEpoch: Se utiliza para definir el valor de la época en la que comenzar a entrenar. Es útil para reanudar una carrera de entrenamiento anterior.
    • yieldEvery: Define la configuración de la frecuencia de cesión del hilo principal a otras tareas. Puede ser automático, lo que significa que el rendimiento ocurre a una determinada velocidad de fotogramas. lote, si el valor es este, produce cada lote. época, si el valor es este, rinde cada época. cualquier número, si el valor es cualquier número, produce cada número en milisegundos. nunca, si el valor es este, nunca cede.

Devoluciones: Promise<Historial>

Ejemplo 1: En este ejemplo, entrenaremos nuestro modelo utilizando un conjunto de datos de array.

Javascript

import * as tf from "@tensorflow/tfjs"
  
// Creating model
const gfg_Model = tf.sequential() ;
  
// Adding layer to model
const config = {units: 1, inputShape: [2]}
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
  
// Compiling the model
const config2 = {optimizer: 'sgd', loss: 'meanSquaredError'} 
gfg_Model.compile(config2);
  
// Creating Datasets for training
const array1 = [[1,2], [1,4], [1,3], [3,4]];
const array2 = [1, 1];
const arrData1 = tf.data.array(array1);
const arrData2 = tf.data.array(array2);
  
const config3 = {xs:arrData1, ys:arrData2}
const arrayDataset = tf.data.zip(config3)
const ArrayDataset = arrayDataset.batch(3).shuffle(6);
  
// Training the model
const Tm = await gfg_Model.fitDataset(ArrayDataset, { epochs: 3 });
  
// Printing the loss after training
console.log("Loss " + " : " + Tm.history.loss[0]);

Producción:

Loss : 0.428712397813797

Ejemplo 2: en este ejemplo, entrenaremos nuestro modelo con un conjunto de datos creado con un archivo csv.

Javascript

import * as tf from "@tensorflow/tfjs";
  
// Path for the CSV file
const gfg_CsvFile =
  "https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv";
  
// Creating model
const gfg_Model = tf.sequential();
  
// Adding layer to model
const config = { units: 1, inputShape: [12] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
  
// Compiling the model
const opt = tf.train.sgd(0.0001);
gfg_Model.compile({ optimizer: opt, loss: "meanSquaredError" });
  
// Here we want to predict column tax
const config2 = { columnConfigs: { tax: { isLabel: true } } };
const csvDataset = tf.data.csv(gfg_CsvFile, config2);
  
// Creating dataset for training
const flattenedDataset = csvDataset
  .map(({ xs, ys }) => {
    return { xs: Object.values(xs), ys: Object.values(ys) };
  })
  .batch(5);
  
// Training the model
const Tm = await gfg_Model.fitDataset(flattenedDataset, { epochs: 5 });
  
for (let i = 0; i < 5; i++) {
  console.log(Tm.history.loss[i]);
}

Producción:

21489.68359375
8750.29296875
6632.365234375
5908.6171875
5546.45654296875

Referencia: https://js.tensorflow.org/api/latest/#tf.Sequential.fitDataset

Publicación traducida automáticamente

Artículo escrito por satyam00so y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *