Prerrequisitos: Red Adversaria Generativa
Este artículo demostrará cómo construir una Red Adversaria Generativa utilizando la biblioteca de Keras. El conjunto de datos que se utiliza es el conjunto de datos de imagen CIFAR10 que está precargado en Keras. Puede leer sobre el conjunto de datos aquí .
Paso 1: Importación de las bibliotecas requeridas
import numpy as np import matplotlib.pyplot as plt import keras from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam,SGD
Paso 2: Cargando los datos
#Loading the CIFAR10 data (X, y), (_, _) = keras.datasets.cifar10.load_data() #Selecting a single class images #The number was randomly chosen and any number #between 1 to 10 can be chosen X = X[y.flatten() == 8]
Paso 3: Definición de parámetros a utilizar en procesos posteriores
#Defining the Input shape image_shape = (32, 32, 3) latent_dimensions = 100
Paso 4: Definición de una función de utilidad para construir el Generador
def build_generator(): model = Sequential() #Building the input layer model.add(Dense(128 * 8 * 8, activation="relu", input_dim=latent_dimensions)) model.add(Reshape((8, 8, 128))) model.add(UpSampling2D()) model.add(Conv2D(128, kernel_size=3, padding="same")) model.add(BatchNormalization(momentum=0.78)) model.add(Activation("relu")) model.add(UpSampling2D()) model.add(Conv2D(64, kernel_size=3, padding="same")) model.add(BatchNormalization(momentum=0.78)) model.add(Activation("relu")) model.add(Conv2D(3, kernel_size=3, padding="same")) model.add(Activation("tanh")) #Generating the output image noise = Input(shape=(latent_dimensions,)) image = model(noise) return Model(noise, image)
Paso 5: Definición de una función de utilidad para construir el Discriminador
def build_discriminator(): #Building the convolutional layers #to classify whether an image is real or fake model = Sequential() model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=image_shape, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=3, strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0,1),(0,1)))) model.add(BatchNormalization(momentum=0.82)) model.add(LeakyReLU(alpha=0.25)) model.add(Dropout(0.25)) model.add(Conv2D(128, kernel_size=3, strides=2, padding="same")) model.add(BatchNormalization(momentum=0.82)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(256, kernel_size=3, strides=1, padding="same")) model.add(BatchNormalization(momentum=0.8)) model.add(LeakyReLU(alpha=0.25)) model.add(Dropout(0.25)) #Building the output layer model.add(Flatten()) model.add(Dense(1, activation='sigmoid')) image = Input(shape=image_shape) validity = model(image) return Model(image, validity)
Paso 6: Definición de una función de utilidad para mostrar las imágenes generadas
def display_images(): r, c = 4,4 noise = np.random.normal(0, 1, (r * c,latent_dimensions)) generated_images = generator.predict(noise) #Scaling the generated images generated_images = 0.5 * generated_images + 0.5 fig, axs = plt.subplots(r, c) count = 0 for i in range(r): for j in range(c): axs[i,j].imshow(generated_images[count, :,:,]) axs[i,j].axis('off') count += 1 plt.show() plt.close()
Paso 7: Construyendo la Red Adversaria Generativa
# Building and compiling the discriminator discriminator = build_discriminator() discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002,0.5), metrics=['accuracy']) #Making the Discriminator untrainable #so that the generator can learn from fixed gradient discriminator.trainable = False # Building the generator generator = build_generator() #Defining the input for the generator #and generating the images z = Input(shape=(latent_dimensions,)) image = generator(z) #Checking the validity of the generated image valid = discriminator(image) #Defining the combined model of the Generator and the Discriminator combined_network = Model(z, valid) combined_network.compile(loss='binary_crossentropy', optimizer=Adam(0.0002,0.5))
Paso 8: Entrenamiento de la red
num_epochs=15000 batch_size=32 display_interval=2500 losses=[] #Normalizing the input X = (X / 127.5) - 1. #Defining the Adversarial ground truths valid = np.ones((batch_size, 1)) #Adding some noise valid += 0.05 * np.random.random(valid.shape) fake = np.zeros((batch_size, 1)) fake += 0.05 * np.random.random(fake.shape) for epoch in range(num_epochs): #Training the Discriminator #Sampling a random half of images index = np.random.randint(0, X.shape[0], batch_size) images = X[index] #Sampling noise and generating a batch of new images noise = np.random.normal(0, 1, (batch_size, latent_dimensions)) generated_images = generator.predict(noise) #Training the discriminator to detect more accurately #whether a generated image is real or fake discm_loss_real = discriminator.train_on_batch(images, valid) discm_loss_fake = discriminator.train_on_batch(generated_images, fake) discm_loss = 0.5 * np.add(discm_loss_real, discm_loss_fake) #Training the Generator #Training the generator to generate images #which pass the authenticity test genr_loss = combined_network.train_on_batch(noise, valid) #Tracking the progress if epoch % display_interval == 0: display_images()
Época 0:
Época 2500:
Época 5000:
Época 7500:
Época 10000:
Época 12500:
Tenga en cuenta que la calidad de las imágenes aumenta con cada época.
Paso 8: Evaluación del desempeño
El rendimiento de la red se evaluará comparando visualmente las imágenes generadas en la última época con las imágenes originales.
a) Trazado de las imágenes originales
#Plotting some of the original images s=X[:40] s = 0.5 * s + 0.5 f, ax = plt.subplots(5,8, figsize=(16,10)) for i, image in enumerate(s): ax[i//8, i%8].imshow(image) ax[i//8, i%8].axis('off') plt.show()
b) Trazado de las imágenes generadas en la última época
#Plotting some of the last batch of generated images noise = np.random.normal(size=(40, latent_dimensions)) generated_images = generator.predict(noise) generated_images = 0.5 * generated_images + 0.5 f, ax = plt.subplots(5,8, figsize=(16,10)) for i, image in enumerate(generated_images): ax[i//8, i%8].imshow(image) ax[i//8, i%8].axis('off') plt.show()
Al comparar visualmente los dos conjuntos de imágenes, se puede concluir que la red está funcionando a un nivel aceptable. La calidad de las imágenes se puede mejorar entrenando la red durante más tiempo o ajustando los parámetros de la red.
Publicación traducida automáticamente
Artículo escrito por AlindGupta y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA