PyTorch torch.permute() reorganiza el tensor original según el orden deseado y devuelve un nuevo tensor rotado multidimensional. El tamaño del tensor devuelto sigue siendo el mismo que el del original.
Sintaxis: torch.permute(*dims)
Parámetros:
- dims: secuencia de índices en el orden deseado de dimensiones del tensor (la indexación comienza desde cero).
Retorno: tensor con orden de dimensiones deseado.
Veamos este concepto con la ayuda de algunos ejemplos:
Ejemplo 1: Crear un tensor bidimensional de tamaño 2 × 4 y luego permutarlo.
Python3
# import pytorch library import torch # create a tensor of size 2 x 4 input_var = torch.randn(2,4) # print size print(input_var.size()) print(input_var) # dimensions permuted input_var = input_var.permute(1, 0) # print size print(input_var.size()) print(input_var)
Producción:
torch.Size([2, 4]) tensor([[ 0.9801, 0.5296, 0.5449, -1.1481], [-0.6762, -0.1161, 0.6360, -0.5371]]) torch.Size([4, 2]) tensor([[ 0.9801, -0.6762], [ 0.5296, -0.1161], [ 0.5449, 0.6360], [-1.1481, -0.5371]])
Ejemplo 2: Crear un tensor tridimensional de tamaño 3 × 5 × 2 y luego permutarlo.
Python3
# import pytorch library import torch # creating a tensor with random # values of dimension 3 X 5 X 2 input_var = torch.randn(3, 5, 2) # print size print(input_var.size()) print(input_var) # dimensions permuted input_var = input_var.permute(2, 0, 1) # print size print(input_var.size()) print(input_var)
Producción:
torch.Size([3, 5, 2]) tensor([[[ 0.2059, -0.7165], [-1.1305, 0.5886], [-0.1247, -0.4969], [-0.5788, 0.0159], [ 1.4304, 0.6014]], [[ 2.4882, -0.3910], [-0.5558, 0.6903], [-0.4219, -0.5498], [-0.5346, -0.0703], [ 1.1497, -0.3252]], [[-0.5075, 0.5752], [ 1.3738, -0.3321], [-0.3317, -0.9209], [-1.6677, -1.1471], [-0.9269, -0.6493]]]) torch.Size([2, 3, 5]) tensor([[[ 0.2059, -1.1305, -0.1247, -0.5788, 1.4304], [ 2.4882, -0.5558, -0.4219, -0.5346, 1.1497], [-0.5075, 1.3738, -0.3317, -1.6677, -0.9269]], [[-0.7165, 0.5886, -0.4969, 0.0159, 0.6014], [-0.3910, 0.6903, -0.5498, -0.0703, -0.3252], [ 0.5752, -0.3321, -0.9209, -1.1471, -0.6493]]])