¿Cómo usar un DataLoader en PyTorch?

Operar con grandes conjuntos de datos requiere cargarlos en la memoria todos a la vez. En la mayoría de los casos, nos enfrentamos a una interrupción de la memoria debido a la cantidad limitada de memoria disponible en el sistema. Además, los programas tienden a ejecutarse lentamente debido a los grandes conjuntos de datos que se cargan una vez. PyTorch ofrece una solución para paralelizar el proceso de carga de datos con procesamiento por lotes automático utilizando DataLoader. El cargador de datos se ha utilizado para paralelizar la carga de datos, ya que aumenta la velocidad y ahorra memoria.

El constructor del cargador de datos reside en el paquete torch.utils.data. Tiene varios parámetros, entre los cuales el único argumento obligatorio que se debe pasar es el conjunto de datos que se debe cargar, y el resto son argumentos opcionales.

Sintaxis:

Cargador de datos (conjunto de datos, aleatorio = verdadero, muestra = ninguno, tamaño de lote = 32)

Cargadores de datos en conjuntos de datos personalizados:

Para implementar cargadores de datos en un conjunto de datos personalizado, debemos anular las siguientes dos funciones de subclase: 

  • La función _len_(): devuelve el tamaño del conjunto de datos.
  • La función _getitem_() : devuelve una muestra del índice dado del conjunto de datos.

Python3

# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
  
# defining the Dataset class
class data_set(Dataset):
    def __init__(self):
        numbers = list(range(0, 100, 1))
        self.data = numbers
  
    def __len__(self):
        return len(self.data)
  
    def __getitem__(self, index):
        return self.data[index]
  
  
dataset = data_set()
  
# implementing dataloader on the dataset and printing per batch
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
for i, batch in enumerate(dataloader):
    print(i, batch)

Producción:

Cargadores de datos en conjuntos de datos integrados:

Python3

# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import seaborn as sns
from torch.utils.data import TensorDataset
  
# defining the dataset consisting of 
# two columns from iris dataset
iris = sns.load_dataset('iris')
petal_length = torch.tensor(iris['petal_length'])
petal_width = torch.tensor(iris['petal_width'])
dataset = TensorDataset(petal_length, petal_width)
  
# implementing dataloader on the dataset 
# and printing per batch
dataloader = DataLoader(dataset, 
                        batch_size=5, 
                        shuffle=True)
  
for i in dataloader:
    print(i)

Producción:

Publicación traducida automáticamente

Artículo escrito por d2anubis y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *