Función Tensorflow.js tf.loadGraphModel()

Tensorflow.js es una biblioteca de código abierto desarrollada por Google para ejecutar modelos de aprendizaje automático, así como redes neuronales de aprendizaje profundo en el entorno del navegador o del Node.

La función .loadGraphModel() se usa para cargar un modelo gráfico dado un URL a la definición del modelo.

Sintaxis:

tf.loadGraphModel (modelUrl, options)

Parámetros:

  • modelUrl: la primera entrada de tensor que puede ser de tipo string o io.IOHandler. Este parámetro es la URL o un io.IOHandler que ayuda a cargar los modelos.
  • options: La segunda entrada de tensor que es opcional. Las opciones son para la solicitud HTTP, que permite enviar credenciales y encabezados personalizados. Los tipos de opciones son:
    • requestInit: RequestInit son para requests HTTP.
    • onProgress: OnProgress es para devolución de llamada de progreso.
    • fetchFunc: es una función utilizada para anular la función window.fetch.
    • estricto: Estricto es un modelo de carga: ya sea que se trate de un peso extraño o de que falten pesos, debería generar un error.
    • weightPathPrefix: el prefijo de ruta es para archivos de peso que, de forma predeterminada, se calcula a partir de la ruta del archivo JSON del modelo.
    • fromTFHub: es un valor booleano que indica si el módulo o modelo se cargará desde TF Hub.

Valor devuelto: Devuelve la Promesa <tf.GraphModel>.

Ejemplo 1: en este ejemplo, estamos cargando MobileNetV2 desde una URL y haciendo una predicción con una entrada de ceros.

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements 
const modelUrl =
'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
  
// Calling the loadGraphModel () method
const model = await tf.loadGraphModel(modelUrl);
  
// Printing the zeroes
const zeros = tf.zeros([1, 224, 224, 3]);
model.predict(zeros).print();

Producción:

Tensor
     [[-0.1412081, -0.5656458, 0.7578365, ..., 
     -1.0148169, -0.81284, 1.1898142],]

Ejemplo 2: en este ejemplo, estamos cargando MobileNetV2 desde una URL de TF Hub y haciendo una predicción con una entrada de ceros.

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input elements 
const modelUrl =
'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
  
// Calling the loadGraphModel () method
const model = await tf.loadGraphModel(
        modelUrl, {fromTFHub: true});
  
// Printing the zeores
const zeros = tf.zeros([1, 224, 224, 3]);
model.predict(zeros).print();

Producción:

Tensor
     [[-1.0764486, 0.0097444, 1.1630495, ..., 
     -0.345558, 0.035432, 0.9112286],]

Referencia: https://js.tensorflow.org/api/1.0.0/#loadGraphModel

Publicación traducida automáticamente

Artículo escrito por sweetyty y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *