Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático, así como redes neuronales de aprendizaje profundo en el entorno del navegador o del Node.
El método .setWeights() se usa para establecer los pesos de la capa indicada, a partir de los tensores dados.
Sintaxis:
setWeights(weights)
Parámetros:
- pesos: Es la lista indicada de tensores de entrada. Es de tipo tf.Tensor[]. Donde, el conteo de arreglos así como su forma debe ser equivalente al conteo de las dimensiones de los pesos indicados de la capa utilizada. En otras palabras, debe ser igual al resultado del método getWeights() .
Valor devuelto: Devuelve nulo.
Ejemplo 1:
Javascript
// Importing the tensorflow.js library import * as tf from "@tensorflow/tfjs" // Creating a model const model = tf.sequential(); // Adding a layer model.add(tf.layers.dense({units: 2, inputShape: [11]})); // Calling setWeights() method model.layers[0].setWeights([tf.truncatedNormal([11, 2]), tf.zeros([2])]); // Compiling the model model.compile({loss: 'categoricalCrossentropy', optimizer: 'sgd'}); // Printing output using getWeights() method model.layers[0].getWeights()[0].print();
Producción:
Tensor [[-0.5969906, -0.1883931], [0.8569255 , -0.49416 ], [0.1157023 , 0.1150239 ], [-0.4052143, 1.9936075 ], [0.3090054 , 0.7212474 ], [0.4626641 , -0.7287846], [0.4352857 , -0.5195332], [0.4626429 , 0.0216295 ], [-0.1110666, -0.5997615], [-0.5083916, -0.3582681], [-0.2847465, 1.184485 ]]
Aquí, el método truncatedNormal() se usa para crear un tf.Tensor junto con valores que se muestrean de una distribución normal truncada, el método zeros() se usa para crear un tf.Tensor junto con todos los elementos que se establecen en 0 y getWeights() se usa para imprimir los pesos que se establecieron usando el método setWeights() .
Ejemplo 2:
Javascript
// Importing the tensorflow.js library import * as tf from "@tensorflow/tfjs" // Creating a model const model = tf.sequential(); // Adding layers model.add(tf.layers.dense({units: 1, inputShape: [5], batchSize: 1, dtype: 'int32'})); model.add(tf.layers.dense({units: 2, inputShape: [6], batchSize: 5})); model.add(tf.layers.dense({units: 3, inputShape: [7], batchSize: 8})); model.add(tf.layers.dense({units: 4, inputShape: [8], batchSize: 12})); // Calling setWeights() method model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]); model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]); // Printing output using getWeights() method model.layers[0].getWeights()[0].print(); model.layers[0].getWeights()[1].print(); model.layers[1].getWeights()[0].print(); model.layers[1].getWeights()[1].print();
Producción:
Tensor [[1], [1], [1], [1], [1]] Tensor [0] Tensor [[1, 1],] Tensor [0, 0]
Aquí, el método ones() se usa para crear un tf.Tensor junto con todos los elementos que se establecen en 1.
Referencia: https://js.tensorflow.org/api/latest/#tf.layers.Layer.setWeights
Publicación traducida automáticamente
Artículo escrito por nidhi1352singh y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA