Retropropagación a través del tiempo – RNN

Introducción:
Las Redes Neuronales Recurrentes son aquellas redes que tratan con datos secuenciales. Predicen las salidas utilizando no solo las entradas actuales, sino también teniendo en cuenta las que ocurrieron antes. En otras palabras, la salida actual depende de la salida actual, así como de un elemento de memoria (que tiene en cuenta las entradas anteriores).
Para entrenar este tipo de redes, usamos la buena y antigua retropropagación pero con un ligero giro. No entrenamos el sistema de forma independiente en un tiempo específico «t» . Lo entrenamos en un tiempo específico “t” así como todo lo que ha pasado antes del tiempo “t” como t-1, t-2, t-3.

Considere la siguiente representación de un RNN:

Arquitectura RNN

S1, S2, S3 son los estados ocultos o unidades de memoria en el tiempo t1, t2, t3 respectivamente, y Ws es la array de peso asociada con él.
X1, X2, X3 son las entradas en el tiempo t1, t2, t3 respectivamente, y Wx es la array de ponderación asociada.
Y1, Y2, Y 3 son las salidas en el tiempo t1, t2, t3 respectivamente, y Wy es la array de ponderación asociada.
Para cualquier tiempo, t, tenemos las siguientes dos ecuaciones:

 \begin{equation*} S_{t} = g_{1}(W_{x}x_{t} + W_{s}S_{t-1})                     \end{equation*} \begin{equation*}                     Y_{t} = g_{2}(W_{Y}S_{t})                         \end{equation*}

donde g1 y g2 son funciones de activación.
Realicemos ahora la propagación hacia atrás en el tiempo t = 3.
Sea la función de error:

 \begin{equation*} E_{t} = (d_{t} - Y_{t})^{2} \end{equation*}

, entonces en t = 3,

 \begin{equation*}  E_{3} = (d_{3} - Y_{3})^{2}                         \end{equation*}

*Aquí estamos usando el error cuadrático, donde d3 es la salida deseada en el tiempo t = 3 .
Para realizar la propagación hacia atrás, tenemos que ajustar los pesos asociados con las entradas, las unidades de memoria y las salidas.
Ajuste de Wy
Para una mejor comprensión, consideremos la siguiente representación:

Ajuste Wy

Fórmula:

 \begin{equation*} \frac{\partial E_{3}}{\partial W_{y}} = \frac{\partial E_{3}}{\partial Y_{3}} . \frac{\partial Y_{3}}{\partial W_{Y}} \end{equation*}

Explicación:
E3 es una función de Y3 . Por lo tanto, diferenciamos E3 wrt Y3 .
Y3 es una función de WY . Por lo tanto, diferenciamos Y3 wrt WY .

Ajuste de Ws
Para una mejor comprensión, consideremos la siguiente representación:

Ajuste de W

Fórmula:

 \begin{equation*}      \frac{\partial E_{3}}{\partial W_{S}} = (\frac{\partial E_{3}}{\partial Y_{3}} . \frac{\partial Y_{3}}{\partial S_{3}} . \frac{\partial S_{3}}{\partial W_{S}})     +   \end{equation*}

 \begin{equation*}     (\frac{\partial E_{3}}{\partial Y_{3}} . \frac{\partial Y_{3}}{\partial S_{3}} . \frac{\partial S_{3}}{\partial S_{2}} . \frac{\partial S_{2}}{\partial W_{S}})      +  \end{equation*}

 \begin{equation*}      (\frac{\partial E_{3}}{\partial Y_{3}} . \frac{\partial Y_{3}}{\partial S_{3}} . \frac{\partial S_{3}}{\partial S_{2}} . \frac{\partial S_{2}}{\partial S_{1}} . \frac{\partial S_{1}}{\partial W_{S}})   \end{equation*}

Explicación:
E3 es una función de Y3 . Por lo tanto, diferenciamos E3 wrt Y3 .
Y3 es una función de S3 . Por lo tanto, diferenciamos Y3 wrt S3 .
S3 es una función de WS . Por lo tanto, diferenciamos S3 wrt WS .
Pero no podemos quedarnos con esto; también tenemos que tener en cuenta, los pasos de tiempo anteriores. Entonces, diferenciamos (parcialmente) la función Error con respecto a las unidades de memoria S2 y S1 teniendo en cuenta la array de peso WS .
Tenemos que tener en cuenta que una unidad de memoria, digamos S t, es una función de su unidad de memoria anterior S t-1 .
Por lo tanto, diferenciamos S3 con S2 y S2 con S1 .
En general, podemos expresar esta fórmula como:

 \begin{equation*}  \frac{\partial E_{N}}{\partial W_{S}} = \sum_{i=1}^{N} \frac{\partial E_{N}}{\partial Y_{N}} . \frac{\partial Y_{N}}{\partial S_{i}} . \frac{\partial S_{i}}{\partial W_{S}}  \end{equation*}

Ajuste de WX:
Para una mejor comprensión, consideremos la siguiente representación:

Ajuste de Wx

Fórmula:

 \begin{equation*}      \frac{\partial E_{3}}{\partial W_{X}} = (\frac{\partial E_{3}}{\partial Y_{3}} . \frac{\partial Y_{3}}{\partial S_{3}} . \frac{\partial S_{3}}{\partial W_{X}})     +   \end{equation*}

 \begin{equation*}     (\frac{\partial E_{3}}{\partial Y_{3}} . \frac{\partial Y_{3}}{\partial S_{3}} . \frac{\partial S_{3}}{\partial S_{2}} . \frac{\partial S_{2}}{\partial W_{X}})      +   \end{equation*}

 \begin{equation*}      (\frac{\partial E_{3}}{\partial Y_{3}} . \frac{\partial Y_{3}}{\partial S_{3}} . \frac{\partial S_{3}}{\partial S_{2}} . \frac{\partial S_{2}}{\partial S_{1}} . \frac{\partial S_{1}}{\partial W_{X}})   \end{equation*}

Explicación:
E3 es una función de Y3 . Por lo tanto, diferenciamos E3 wrt Y3 .
Y3 es una función de S3 . Por lo tanto, diferenciamos Y3 wrt S3 .
S3 es una función de WX . Por lo tanto, diferenciamos S3 wrt WX .
Una vez más, no podemos detenernos con esto; también tenemos que tener en cuenta, los pasos de tiempo anteriores. Por lo tanto, diferenciamos (parcialmente) la función de error con respecto a las unidades de memoria S2 y S1 teniendo en cuenta la array de peso WX.
En general, podemos expresar esta fórmula como:

 \begin{equation*}  \frac{\partial E_{N}}{\partial W_{S}} = \sum_{i=1}^{N} \frac{\partial E_{N}}{\partial Y_{N}} . \frac{\partial Y_{N}}{\partial S_{i}} . \frac{\partial S_{i}}{\partial W_{X}}  \end{equation*}

Limitaciones:
este método de propagación hacia atrás a través del tiempo (BPTT) se puede usar hasta un número limitado de pasos de tiempo como 8 o 10. Si retropropagamos más, el gradiente se \deltavuelve demasiado pequeño. Este problema se llama el problema del «gradiente de fuga». El problema es que la contribución de la información decae geométricamente con el tiempo. Entonces, si el número de pasos de tiempo es > 10 (digamos), esa información se descartará efectivamente.

Ir más allá de las RNN:
una de las soluciones famosas a este problema es usar lo que se llama celdas de memoria a largo plazo a corto plazo (LSTM, por sus siglas en inglés) en lugar de las celdas RNN tradicionales. Pero podría surgir otro problema aquí, llamado el problema del gradiente explosivo , donde el gradiente crece incontrolablemente.
Solución: se puede usar un método popular llamado recorte de gradiente en el que, en cada paso de tiempo, podemos verificar si el gradiente es \delta> umbral. Si es así, entonces normalízalo.

Publicación traducida automáticamente

Artículo escrito por KeshavBalachandar y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *