Tensorflow.js es un conjunto de herramientas de código abierto desarrollado por Google para ejecutar modelos de aprendizaje automático y redes neuronales de aprendizaje profundo en el navegador o en la plataforma del Node. También permite a los desarrolladores crear modelos de aprendizaje automático en JavaScript y utilizarlos directamente en el navegador o con Node.js.
La función tf.basicLSTMCell() calcula el siguiente estado y la salida de un BasicLSTMCell.
Sintaxis:
tf.basicLSTMCell (forgetBias, lstmKernel, lstmBias, data, c, h)
Parámetros:
- ForgetBias: el sesgo de olvido de la celda.
- lstmKernel: los pesos de la celda.
- lstmBias: El sesgo de la celda.
- data: La entrada a la celda.
- c: array de estados de celda anteriores.
- h: Array de salidas de celdas anteriores.
Devuelve: [tf.Tensor2D, tf.Tensor2D]
Ejemplo 1:
Javascript
import * as tf from "@tensorflow/tfjs"; const data = tf.tensor2d([7, 51, 50, 54, 24, 1, 48, 75], [4, 2]); const kernel = tf.tensor2d([49, 62, 47, 93, 12, 80, 24, 89, 34, 8, 96, 74, 56, 42, 32, 53, 7, 87, 35, 54], [5, 4]); const state = tf.tensor2d([97, 56, 32, 29, 57, 6, 8, 75, 26, 20, 1, 17], [4, 3]); const output = tf.tensor2d([27, 77, 90, 72, 9, 8, 94, 41, 89, 51, 18, 60], [4, 3]); const basicLSTMCell = tf.basicLSTMCell(0.8, kernel, 2.2, data, state, output); console.log(basicLSTMCella)
Producción:
[ Tensor { kept: false, isDisposedInternal: false, shape: [ 4, 3 ], dtype: 'float32', size: 12, strides: [ 3 ], dataId: { id: 19 }, id: 19, rankType: '2', scopeId: 0 }, Tensor { kept: false, isDisposedInternal: false, shape: [ 4, 3 ], dtype: 'float32', size: 12, strides: [ 3 ], dataId: { id: 22 }, id: 22, rankType: '2', scopeId: 0 } ]
Ejemplo 2:
Javascript
import * as tf from "@tensorflow/tfjs"; const data = tf.tensor2d([70, 10, 62, 55, 74, 85, 66, 9], [4, 2]); const kernel = tf.tensor2d([10, 82, 93, 83, 49, 73, 45, 77, 56, 29, 32, 2, 24, 39, 34, 91, 95, 61, 76, 69], [5, 4]); const state = tf.tensor2d([29, 40, 79, 61, 5, 34, 78, 47, 86, 74, 46, 28], [4, 3]); const output = tf.tensor2d([25, 55, 33, 85, 82, 65, 20, 75, 54, 59, 50, 3], [4, 3]); const basicLSTMCell = tf.basicLSTMCell(1.0, kernel, 2.0, data, state, output); const input = tf.input({ shape: [4, 2] }); const simpleRNNLayer = tf.layers.simpleRNN({ units: 4, returnSequences: true, returnState: true, cell: basicLSTMCell }); let outputs, finalState; [outputs, finalState] = simpleRNNLayer.apply(input); const model = tf.model({ inputs: input, outputs: outputs }); const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [1, 4, 2]); model.predict(x).print();
Producción:
Tensor [[[0.8135326, -0.8665518, 0.946215 , 0.8714994], [0.9547493, -0.9747651, 0.9873405, 0.9995403], [0.9983249, -0.9986398, 0.9996439, 0.9999973], [0.9999447, -0.9999344, 0.9999925, 1 ]]]
Referencia: https://js.tensorflow.org/api/latest/#basicLSTMCell
Publicación traducida automáticamente
Artículo escrito por aayushmohansinha y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA