Variational Autoencoders

The variational autoencoder was proposed in 2013 by Kingma and Welling. A variational autoencoder (VAE) provides a probabilistic way of describing an observation in latent space. Rather than building an encoder that outputs a single value for each latent attribute, we formulate the encoder so that it describes a probability distribution for each latent attribute.

It has many applications, such as data compression and synthetic data generation.

Architecture:

Autoencoders are a type of neural network that learns encodings of a dataset without supervision. An autoencoder has two parts. The first is an encoder, which resembles a convolutional neural network except for its last layer. The encoder's goal is to learn an efficient encoding of the data and pass it to a bottleneck layer. The second part is a decoder, which uses the latent representation in the bottleneck layer to regenerate images similar to those in the dataset. The reconstruction error is then backpropagated through the network as the loss.

The variational autoencoder differs from the plain autoencoder in that it provides a statistical way of describing the dataset samples in latent space. In a variational autoencoder, the encoder therefore outputs a probability distribution at the bottleneck layer instead of a single output value.

Mathematics behind the variational autoencoder:

The variational autoencoder uses KL divergence in its loss function. The goal is to minimize the difference between an assumed distribution and the original distribution of the dataset.
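
As a small illustration of what KL divergence measures, the sketch below computes it for two discrete distributions. The probability values `p` and `q` are made up for the example and are not taken from the article.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical distributions over three outcomes (illustrative values only).
p = [0.5, 0.4, 0.1]
q = [0.3, 0.3, 0.4]

print(kl_divergence(p, p))  # identical distributions: zero divergence
print(kl_divergence(p, q))  # different distributions: positive divergence
print(kl_divergence(q, p))  # a different value again: KL is not symmetric
```

Because KL divergence is not symmetric, the order of the arguments matters, which is why the formulas in this section always place q first.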

Suppose we have a latent variable z and want to generate the observation x from it. In other words, we want to compute the posterior

p\left( {z|x} \right)
 

We can do this with Bayes' rule:

p\left( {z|x} \right) = \frac{{p\left( {x|z} \right)p\left( z \right)}}{{p\left( x \right)}}
 

However, computing p(x) can be quite difficult:

p\left( x \right) = \int {p\left( {x|z} \right)p\left(z\right)dz}
 

This usually makes the posterior intractable. We therefore approximate p(z|x) with a tractable distribution q(z|x). To make q(z|x) as close as possible to p(z|x), we minimize the KL divergence, which measures how different two distributions are:

\min KL\left( {q\left( {z|x} \right)||p\left( {z|x} \right)} \right)
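
To get a feel for the integral that defines p(x), here is a minimal sketch that estimates it by sampling. It assumes a toy one-dimensional model chosen only because the exact answer is known: z follows N(0, 1) and x given z follows N(z, 1), so that p(x) is N(0, 2). The function names are hypothetical.

```python
import math
import random


def normal_pdf(value, mean, var):
    """Density of a one-dimensional normal distribution."""
    return math.exp(-((value - mean) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


def marginal_mc(x, num_samples=100_000, seed=0):
    """Monte Carlo estimate of p(x) = integral of p(x|z) p(z) dz.

    Toy assumption: z ~ N(0, 1) and x | z ~ N(z, 1).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)          # draw z from the prior p(z)
        total += normal_pdf(x, z, 1.0)   # evaluate the likelihood p(x|z)
    return total / num_samples


x = 1.0
estimate = marginal_mc(x)
exact = normal_pdf(x, 0.0, 2.0)  # closed form, available only for this toy model
print(estimate, exact)
```

In this toy case a large number of samples is enough for one scalar x. With a neural-network decoder and a multi-dimensional z there is no closed form and naive sampling becomes expensive, which is the motivation for the approximation q(z|x).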
 

After simplification, the minimization problem above is equivalent to the following maximization problem:

{E_{q\left( {z|x} \right)}}\log p\left( {x|z} \right) - KL\left( {q\left( {z|x} \right)||p\left( z \right)} \right)
 

The first term is the reconstruction likelihood, and the second term keeps the learned distribution q close to the prior distribution p(z).
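
The equivalence between the minimization and the maximization problem can be seen by expanding the KL divergence with Bayes' rule. The following is a sketch of the standard derivation, using the same symbols as the formulas in this section:

```latex
\begin{align*}
KL\left( q(z|x) \,||\, p(z|x) \right)
  &= E_{q(z|x)}\left[ \log q(z|x) - \log p(z|x) \right] \\
  &= E_{q(z|x)}\left[ \log q(z|x) - \log p(x|z) - \log p(z) \right] + \log p(x) \\
  &= -E_{q(z|x)} \log p(x|z) + KL\left( q(z|x) \,||\, p(z) \right) + \log p(x)
\end{align*}

\log p(x) - KL\left( q(z|x) \,||\, p(z|x) \right)
  = E_{q(z|x)} \log p(x|z) - KL\left( q(z|x) \,||\, p(z) \right)
```

Since log p(x) does not depend on q, making the KL divergence on the left smaller is the same as making the right-hand side larger. The right-hand side is commonly called the evidence lower bound (ELBO).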

The total loss therefore has two terms: the reconstruction error and the KL divergence loss:

Loss = L\left( {x, \hat x} \right) + \sum\limits_j {KL\left( {{q_j}\left( {z|x} \right)||p\left( z \right)} \right)}
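
When q(z|x) is a diagonal Gaussian and p(z) is a standard normal, each KL term in the sum has a closed form. The sketch below assumes exactly that setting, with the encoder producing a mean and a log-variance per latent dimension; `gaussian_kl` is a hypothetical helper name.

```python
import math


def gaussian_kl(z_mean, z_log_var):
    """KL( N(mean, exp(log_var)) || N(0, 1) ), summed over latent dimensions.

    Assumes a diagonal Gaussian q and a standard normal prior.
    """
    return -0.5 * sum(
        1.0 + log_var - mean ** 2 - math.exp(log_var)
        for mean, log_var in zip(z_mean, z_log_var)
    )


# q equal to the prior: no penalty.
print(gaussian_kl([0.0, 0.0], [0.0, 0.0]))
# Shifting one mean away from zero adds a penalty.
print(gaussian_kl([1.0, 0.0], [0.0, 0.0]))
```

The expression inside the sum is the same one that appears in the `kl_loss` line of the training code in the implementation section, although that code averages over dimensions and batch instead of summing.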
 

Implementation:

This implementation uses the Fashion-MNIST dataset. It is available through the keras.datasets API, so there is no need to download or load it manually.

  • First, we import the required packages into our Python environment. We use Keras with TensorFlow as the backend.

Code: 

python3

# code
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Layer, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose
import matplotlib.pyplot as plt
  • A variational autoencoder needs an encoder and a decoder, but first we define the bottleneck of the architecture: the sampling layer.

Code: 

python3

# this sampling layer is the bottleneck layer of variational autoencoder,
# it uses the output from two dense layers z_mean and z_log_var as input,
# convert them into normal distribution and pass them to the decoder layer
class Sampling(Layer):
 
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape =(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
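
The sampling layer uses what is usually called the reparameterization trick: the randomness is isolated in epsilon, so z remains a differentiable function of z_mean and z_log_var. The following dependency-free sketch mirrors the same formula for a single latent dimension. The numeric values and the helper name `sample_latent` are illustrative assumptions.

```python
import math
import random


def sample_latent(z_mean, z_log_var, rng):
    """Draw z = mean + std * epsilon with epsilon ~ N(0, 1)."""
    epsilon = rng.gauss(0.0, 1.0)
    return z_mean + math.exp(0.5 * z_log_var) * epsilon


rng = random.Random(0)
z_mean = 1.5
z_log_var = math.log(0.25)  # variance 0.25, so standard deviation 0.5

samples = [sample_latent(z_mean, z_log_var, rng) for _ in range(50_000)]
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))

# The empirical statistics should be close to the requested mean and std.
print(sample_mean, sample_std)
```

Predicting the log-variance rather than the variance keeps the standard deviation positive without any constraint on the layer output.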
  • Next, we define the encoder part of our autoencoder. It takes images as input and encodes their representation through the sampling layer.

Code: 

python3

# Define Encoder Model
latent_dim = 2
 
encoder_inputs = Input(shape =(28, 28, 1))
x = Conv2D(32, 3, activation ="relu", strides = 2, padding ="same")(encoder_inputs)
x = Conv2D(64, 3, activation ="relu", strides = 2, padding ="same")(x)
x = Flatten()(x)
x = Dense(16, activation ="relu")(x)
z_mean = Dense(latent_dim, name ="z_mean")(x)
z_log_var = Dense(latent_dim, name ="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = Model(encoder_inputs, [z_mean, z_log_var, z], name ="encoder")
encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 32)   320         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 7, 7, 64)     18496       conv2d_2[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 3136)         0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 16)           50192       flatten_1[0][0]                  
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            34          dense_2[0][0]                    
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            34          dense_2[0][0]                    
__________________________________________________________________________________________________
sampling_1 (Sampling)           (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________
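
The shapes and parameter counts in the encoder summary can be checked by hand. The sketch below assumes the usual rules for a convolution with padding "same" (output size is the input size divided by the stride, rounded up) and counts weights plus biases. The helper names are hypothetical.

```python
import math


def conv2d_same(size, in_channels, filters, kernel=3, stride=2):
    """Output size and parameter count of a Conv2D layer with padding='same'."""
    out_size = math.ceil(size / stride)
    params = kernel * kernel * in_channels * filters + filters  # weights + biases
    return out_size, params


def dense_params(in_units, out_units):
    """Parameter count of a Dense layer: weights + biases."""
    return in_units * out_units + out_units


size1, p_conv1 = conv2d_same(28, 1, 32)      # 14 x 14 x 32, 320 params
size2, p_conv2 = conv2d_same(size1, 32, 64)  # 7 x 7 x 64, 18496 params
flat = size2 * size2 * 64                    # 3136 units after Flatten
p_dense = dense_params(flat, 16)             # 50192 params
p_mean = dense_params(16, 2)                 # z_mean: 34 params
p_log_var = dense_params(16, 2)              # z_log_var: 34 params

total = p_conv1 + p_conv2 + p_dense + p_mean + p_log_var
print(total)  # 69076, matching "Total params" in the summary
```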
  • Next, we define the decoder part of our autoencoder. It takes the output of the sampling layer as input and generates an image of size (28, 28, 1).

Code: 

python3

# Define Decoder Architecture
latent_inputs = keras.Input(shape =(latent_dim, ))
x = Dense(7 * 7 * 64, activation ="relu")(latent_inputs)
x = Reshape((7, 7, 64))(x)
x = Conv2DTranspose(64, 3, activation ="relu", strides = 2, padding ="same")(x)
x = Conv2DTranspose(32, 3, activation ="relu", strides = 2, padding ="same")(x)
decoder_outputs = Conv2DTranspose(1, 3, activation ="sigmoid", padding ="same")(x)
decoder = Model(latent_inputs, decoder_outputs, name ="decoder")
decoder.summary()
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense_3 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 28, 28, 1)         289       
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________
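
The decoder summary can be verified in the same spirit. This sketch assumes that a transposed convolution with padding "same" multiplies the spatial size by the stride; the helper name is hypothetical.

```python
def conv2d_transpose_same(size, in_channels, filters, kernel=3, stride=1):
    """Output size and parameter count of a Conv2DTranspose layer with padding='same'."""
    out_size = size * stride
    params = kernel * kernel * in_channels * filters + filters  # weights + biases
    return out_size, params


latent_dim = 2
p_dense = latent_dim * 7 * 7 * 64 + 7 * 7 * 64                  # 9408 params

size1, p_up1 = conv2d_transpose_same(7, 64, 64, stride=2)       # 14 x 14 x 64
size2, p_up2 = conv2d_transpose_same(size1, 64, 32, stride=2)   # 28 x 28 x 32
size3, p_out = conv2d_transpose_same(size2, 32, 1)              # 28 x 28 x 1

total = p_dense + p_up1 + p_up2 + p_out
print(size3, total)  # 28 and 65089, matching the summary
```

The two stride-2 layers undo the two stride-2 convolutions of the encoder, which is how the 7 x 7 feature map grows back to 28 x 28.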
  • In this step, we combine the encoder and decoder into one model and define the training procedure with its loss functions.

Code: 

python3

# this class takes encoder and decoder models and
# define the complete variational autoencoder architecture
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
 
    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # use the models stored on this instance, not the global variables
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            reconstruction_loss *= 28 * 28
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
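
The factor 28 * 28 in the reconstruction loss turns a per-pixel average into a per-image sum, so the reconstruction term is not dwarfed by the KL term. The sketch below shows the arithmetic on a made-up four-pixel image. The pixel values and the clipping constant `eps` are assumptions for illustration.

```python
import math


def binary_crossentropy(target, prediction, eps=1e-7):
    """Binary cross-entropy for one pixel; predictions are clipped to avoid log(0)."""
    p = min(max(prediction, eps), 1.0 - eps)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# Hypothetical 2 x 2 image, flattened, with its reconstruction.
target = [0.0, 1.0, 1.0, 0.0]
reconstruction = [0.1, 0.8, 0.6, 0.3]

per_pixel = [binary_crossentropy(t, r) for t, r in zip(target, reconstruction)]
mean_loss = sum(per_pixel) / len(per_pixel)  # what an average over pixels yields
scaled_loss = mean_loss * len(per_pixel)     # same role as multiplying by 28 * 28

print(mean_loss, scaled_loss)
```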
  • Now we can train the variational autoencoder. We train it for 100 epochs, but first we need to load the Fashion-MNIST dataset.

Code: 

python3

# load fashion mnist dataset  from  keras.dataset API
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
fmnist_images = np.concatenate([x_train, x_test], axis = 0)
# expand dimension to add  a color map dimension
fmnist_images = np.expand_dims(fmnist_images, -1).astype("float32") / 255
 
# compile and train the model
vae = VAE(encoder, decoder)
vae.compile(optimizer ='rmsprop')
vae.fit(fmnist_images, epochs = 100, batch_size = 64)
Epoch 1/100
1094/1094 [==============================] - 7s 6ms/step - loss: 301.9441 - reconstruction_loss: 298.3138 - kl_loss: 3.6303
Epoch 2/100
1094/1094 [==============================] - 7s 6ms/step - loss: 273.5940 - reconstruction_loss: 270.0484 - kl_loss: 3.5456
Epoch 3/100
1094/1094 [==============================] - 7s 6ms/step - loss: 269.3337 - reconstruction_loss: 265.9077 - kl_loss: 3.4260
Epoch 4/100
1094/1094 [==============================] - 7s 6ms/step - loss: 266.8168 - reconstruction_loss: 263.4100 - kl_loss: 3.4068
Epoch 5/100
1094/1094 [==============================] - 7s 6ms/step - loss: 264.9917 - reconstruction_loss: 261.5603 - kl_loss: 3.4314
Epoch 6/100
1094/1094 [==============================] - 7s 6ms/step - loss: 263.5237 - reconstruction_loss: 260.0712 - kl_loss: 3.4525
Epoch 7/100
1094/1094 [==============================] - 7s 6ms/step - loss: 262.3414 - reconstruction_loss: 258.8548 - kl_loss: 3.4865
Epoch 8/100
1094/1094 [==============================] - 7s 6ms/step - loss: 261.4241 - reconstruction_loss: 257.9104 - kl_loss: 3.5137
Epoch 9/100
1094/1094 [==============================] - 7s 6ms/step - loss: 260.6090 - reconstruction_loss: 257.0662 - kl_loss: 3.5428
Epoch 10/100
1094/1094 [==============================] - 7s 6ms/step - loss: 259.9735 - reconstruction_loss: 256.4075 - kl_loss: 3.5660
Epoch 11/100
1094/1094 [==============================] - 7s 6ms/step - loss: 259.4184 - reconstruction_loss: 255.8348 - kl_loss: 3.5836
Epoch 12/100
1094/1094 [==============================] - 7s 6ms/step - loss: 258.9688 - reconstruction_loss: 255.3724 - kl_loss: 3.5964
Epoch 13/100
1094/1094 [==============================] - 7s 6ms/step - loss: 258.5413 - reconstruction_loss: 254.9356 - kl_loss: 3.6057
Epoch 14/100
1094/1094 [==============================] - 7s 6ms/step - loss: 258.2400 - reconstruction_loss: 254.6236 - kl_loss: 3.6163
Epoch 15/100
1094/1094 [==============================] - 7s 6ms/step - loss: 257.9335 - reconstruction_loss: 254.3038 - kl_loss: 3.6298
Epoch 16/100
1094/1094 [==============================] - 7s 6ms/step - loss: 257.6331 - reconstruction_loss: 253.9993 - kl_loss: 3.6339
Epoch 17/100
1094/1094 [==============================] - 7s 6ms/step - loss: 257.4199 - reconstruction_loss: 253.7707 - kl_loss: 3.6492
Epoch 18/100
1094/1094 [==============================] - 6s 6ms/step - loss: 257.1951 - reconstruction_loss: 253.5309 - kl_loss: 3.6643
Epoch 19/100
1094/1094 [==============================] - 7s 6ms/step - loss: 256.9326 - reconstruction_loss: 253.2723 - kl_loss: 3.6604
Epoch 20/100
1094/1094 [==============================] - 7s 6ms/step - loss: 256.7551 - reconstruction_loss: 253.0836 - kl_loss: 3.6715
Epoch 21/100
1094/1094 [==============================] - 7s 6ms/step - loss: 256.5663 - reconstruction_loss: 252.8877 - kl_loss: 3.6786
Epoch 22/100
1094/1094 [==============================] - 7s 6ms/step - loss: 256.4068 - reconstruction_loss: 252.7112 - kl_loss: 3.6956
Epoch 23/100
1094/1094 [==============================] - 7s 6ms/step - loss: 256.2588 - reconstruction_loss: 252.5588 - kl_loss: 3.7000
Epoch 24/100
1094/1094 [==============================] - 7s 6ms/step - loss: 256.0853 - reconstruction_loss: 252.3794 - kl_loss: 3.7059
Epoch 25/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.9321 - reconstruction_loss: 252.2201 - kl_loss: 3.7120
Epoch 26/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.7962 - reconstruction_loss: 252.0814 - kl_loss: 3.7148
Epoch 27/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.6953 - reconstruction_loss: 251.9673 - kl_loss: 3.7280
Epoch 28/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.5534 - reconstruction_loss: 251.8248 - kl_loss: 3.7287
Epoch 29/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.4437 - reconstruction_loss: 251.7134 - kl_loss: 3.7303
Epoch 30/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.3439 - reconstruction_loss: 251.6064 - kl_loss: 3.7375
Epoch 31/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.2326 - reconstruction_loss: 251.5018 - kl_loss: 3.7308
Epoch 32/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.1356 - reconstruction_loss: 251.3933 - kl_loss: 3.7423
Epoch 33/100
1094/1094 [==============================] - 7s 6ms/step - loss: 255.0660 - reconstruction_loss: 251.3224 - kl_loss: 3.7436
Epoch 34/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.9977 - reconstruction_loss: 251.2449 - kl_loss: 3.7528
Epoch 35/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.8857 - reconstruction_loss: 251.1363 - kl_loss: 3.7494
Epoch 36/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.7980 - reconstruction_loss: 251.0481 - kl_loss: 3.7499
Epoch 37/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.7485 - reconstruction_loss: 250.9851 - kl_loss: 3.7634
Epoch 38/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.6701 - reconstruction_loss: 250.9049 - kl_loss: 3.7652
Epoch 39/100
1094/1094 [==============================] - 6s 6ms/step - loss: 254.6105 - reconstruction_loss: 250.8389 - kl_loss: 3.7716
Epoch 40/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.4979 - reconstruction_loss: 250.7333 - kl_loss: 3.7646
Epoch 41/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.4734 - reconstruction_loss: 250.7037 - kl_loss: 3.7697
Epoch 42/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.4408 - reconstruction_loss: 250.6576 - kl_loss: 3.7831
Epoch 43/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.3272 - reconstruction_loss: 250.5562 - kl_loss: 3.7711
Epoch 44/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.3110 - reconstruction_loss: 250.5354 - kl_loss: 3.7755
Epoch 45/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.1982 - reconstruction_loss: 250.4256 - kl_loss: 3.7726
Epoch 46/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.1655 - reconstruction_loss: 250.3795 - kl_loss: 3.7860
Epoch 47/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.0979 - reconstruction_loss: 250.3105 - kl_loss: 3.7875
Epoch 48/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.0801 - reconstruction_loss: 250.2973 - kl_loss: 3.7828
Epoch 49/100
1094/1094 [==============================] - 7s 6ms/step - loss: 254.0101 - reconstruction_loss: 250.2270 - kl_loss: 3.7831
Epoch 50/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.9512 - reconstruction_loss: 250.1681 - kl_loss: 3.7831
Epoch 51/100
1094/1094 [==============================] - 7s 7ms/step - loss: 253.9307 - reconstruction_loss: 250.1408 - kl_loss: 3.7899
Epoch 52/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.8858 - reconstruction_loss: 250.1059 - kl_loss: 3.7800
Epoch 53/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.8118 - reconstruction_loss: 250.0236 - kl_loss: 3.7882
Epoch 54/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.8171 - reconstruction_loss: 250.0325 - kl_loss: 3.7845
Epoch 55/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.7622 - reconstruction_loss: 249.9735 - kl_loss: 3.7887
Epoch 56/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.7338 - reconstruction_loss: 249.9380 - kl_loss: 3.7959
Epoch 57/100
1094/1094 [==============================] - 6s 6ms/step - loss: 253.6761 - reconstruction_loss: 249.8792 - kl_loss: 3.7969
Epoch 58/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.6236 - reconstruction_loss: 249.8283 - kl_loss: 3.7954
Epoch 59/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.6181 - reconstruction_loss: 249.8236 - kl_loss: 3.7945
Epoch 60/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.5509 - reconstruction_loss: 249.7587 - kl_loss: 3.7921
Epoch 61/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.5124 - reconstruction_loss: 249.7126 - kl_loss: 3.7998
Epoch 62/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.4739 - reconstruction_loss: 249.6683 - kl_loss: 3.8056
Epoch 63/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.4609 - reconstruction_loss: 249.6567 - kl_loss: 3.8042
Epoch 64/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.4066 - reconstruction_loss: 249.6020 - kl_loss: 3.8045
Epoch 65/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.3578 - reconstruction_loss: 249.5580 - kl_loss: 3.7998
Epoch 66/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.3728 - reconstruction_loss: 249.5609 - kl_loss: 3.8118
Epoch 67/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.3523 - reconstruction_loss: 249.5351 - kl_loss: 3.8171
Epoch 68/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.2646 - reconstruction_loss: 249.4452 - kl_loss: 3.8194
Epoch 69/100
1094/1094 [==============================] - 6s 6ms/step - loss: 253.2642 - reconstruction_loss: 249.4603 - kl_loss: 3.8040
Epoch 70/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.2227 - reconstruction_loss: 249.4159 - kl_loss: 3.8068
Epoch 71/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.1848 - reconstruction_loss: 249.3755 - kl_loss: 3.8094
Epoch 72/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.1812 - reconstruction_loss: 249.3737 - kl_loss: 3.8074
Epoch 73/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.1803 - reconstruction_loss: 249.3743 - kl_loss: 3.8059
Epoch 74/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.1295 - reconstruction_loss: 249.3114 - kl_loss: 3.8181
Epoch 75/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.0516 - reconstruction_loss: 249.2391 - kl_loss: 3.8125
Epoch 76/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.0736 - reconstruction_loss: 249.2582 - kl_loss: 3.8154
Epoch 77/100
1094/1094 [==============================] - 6s 6ms/step - loss: 253.0331 - reconstruction_loss: 249.2200 - kl_loss: 3.8131
Epoch 78/100
1094/1094 [==============================] - 7s 6ms/step - loss: 253.0479 - reconstruction_loss: 249.2272 - kl_loss: 3.8207
Epoch 79/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.9317 - reconstruction_loss: 249.1137 - kl_loss: 3.8179
Epoch 80/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.9578 - reconstruction_loss: 249.1483 - kl_loss: 3.8095
Epoch 81/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.9072 - reconstruction_loss: 249.0963 - kl_loss: 3.8109
Epoch 82/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.8793 - reconstruction_loss: 249.0646 - kl_loss: 3.8147
Epoch 83/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.8914 - reconstruction_loss: 249.0676 - kl_loss: 3.8238
Epoch 84/100
1094/1094 [==============================] - 6s 6ms/step - loss: 252.8365 - reconstruction_loss: 249.0121 - kl_loss: 3.8244
Epoch 85/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.8063 - reconstruction_loss: 248.9844 - kl_loss: 3.8218
Epoch 86/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.7960 - reconstruction_loss: 248.9777 - kl_loss: 3.8183
Epoch 87/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.7733 - reconstruction_loss: 248.9529 - kl_loss: 3.8204
Epoch 88/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.7303 - reconstruction_loss: 248.9055 - kl_loss: 3.8248
Epoch 89/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.7225 - reconstruction_loss: 248.8902 - kl_loss: 3.8323
Epoch 90/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.6822 - reconstruction_loss: 248.8549 - kl_loss: 3.8273
Epoch 91/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.6540 - reconstruction_loss: 248.8314 - kl_loss: 3.8227
Epoch 92/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.6540 - reconstruction_loss: 248.8239 - kl_loss: 3.8300
Epoch 93/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.6213 - reconstruction_loss: 248.7778 - kl_loss: 3.8435
Epoch 94/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.5990 - reconstruction_loss: 248.7594 - kl_loss: 3.8397
Epoch 95/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.5786 - reconstruction_loss: 248.7413 - kl_loss: 3.8373
Epoch 96/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.5839 - reconstruction_loss: 248.7411 - kl_loss: 3.8427
Epoch 97/100
1094/1094 [==============================] - 7s 7ms/step - loss: 252.5364 - reconstruction_loss: 248.6960 - kl_loss: 3.8404
Epoch 98/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.5347 - reconstruction_loss: 248.6915 - kl_loss: 3.8431
Epoch 99/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.4996 - reconstruction_loss: 248.6569 - kl_loss: 3.8428
Epoch 100/100
1094/1094 [==============================] - 7s 6ms/step - loss: 252.4938 - reconstruction_loss: 248.6405 - kl_loss: 3.8533
<tensorflow.python.keras.callbacks.History at 0x7f5467c56be0>
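
The 1094 steps per epoch shown in the log follow from the dataset size and the batch size. A quick check, assuming the standard Fashion-MNIST split of 60,000 training and 10,000 test images that the code concatenates:

```python
import math

num_images = 60_000 + 10_000  # training and test images concatenated
batch_size = 64

steps_per_epoch = math.ceil(num_images / batch_size)
last_batch_size = num_images - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch)  # 1094
print(last_batch_size)  # 48 images in the final, smaller batch
```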
  • In this step, we display the training results by decoding a grid of points taken from the latent space.

Code: 

python3

def plot_latent(encoder, decoder):
    # display a n * n 2D manifold of images
    n = 10
    img_dim = 28
    scale = 2.0
    figsize = 15
    figure = np.zeros((img_dim * n, img_dim * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of images classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
 
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            images = x_decoded[0].reshape(img_dim, img_dim)
            figure[
                i * img_dim : (i + 1) * img_dim,
                j * img_dim : (j + 1) * img_dim,
            ] = images
 
    plt.figure(figsize =(figsize, figsize))
    start_range = img_dim // 2
    # end_range is exclusive, so this yields exactly n tick positions
    end_range = n * img_dim + start_range
    pixel_range = np.arange(start_range, end_range, img_dim)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap ="Greys_r")
    plt.show()
 
 
plot_latent(encoder, decoder)

  • To get a clearer view of the latent representation, we draw a scatter plot of the training data using the latent values produced by the encoder.

Code: 

python3

def plot_label_clusters(encoder, decoder, data, test_lab):
    z_mean, _, _ = encoder.predict(data)
    plt.figure(figsize =(12, 10))
    sc = plt.scatter(z_mean[:, 0], z_mean[:, 1], c = test_lab)
    cbar = plt.colorbar(sc, ticks = range(10))
    cbar.ax.set_yticklabels([labels.get(i) for i in range(10)])
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()
 
 
labels = {
    0: "T-shirt / top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}
 
(x_train, y_train), _ = keras.datasets.fashion_mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype("float32") / 255
plot_label_clusters(encoder, decoder, x_train, y_train)

Machine-translated publication

Article written by pawangfg and translated by Barcelona Geeks. The original can be accessed here. Licence: CC BY-SA
