Explicación de cómo funcionan los transformers: entender las matemáticas detrás
(osanseviero.github.io)- Reduce el proceso de inferencia de un transformer a un ejemplo de traducción Hello World → Hola Mundo, para poder seguir a mano desde la tokenización hasta el encoder, el decoder y el cálculo de probabilidades del siguiente token
- En lugar de la configuración grande del paper original, usa embeddings de 4 dimensiones, 2 cabezas de attention y una capa feedforward de 8 dimensiones para hacer más pequeño el flujo de multiplicaciones de matrices y softmax
- El encoder suma codificación posicional a los embeddings de tokens y luego pasa por self-attention multihead y una capa feedforward para crear una representación contextual de la secuencia de entrada
- El decoder empieza en
SOS, usa tanto los tokens generados previamente como la salida del encoder, y en encoder-decoder attention la query se calcula desde el decoder, mientras que key/value se calculan desde la salida del encoder - El último embedding del decoder pasa por una capa lineal y softmax para convertirse en probabilidades del siguiente token, pero como el ejemplo usa pesos aleatorios, no hay que esperar calidad real de traducción
Objetivo y premisas
- Verifica con un ejemplo end-to-end cómo se conectan las matemáticas durante la inferencia dentro de un modelo transformer
- Reduce mucho el tamaño del modelo para que sea fácil seguir los cálculos a mano
- En lugar de la dimensión de embedding 512 del paper original, el ejemplo usa 4 dimensiones
- En lugar de las 8 cabezas de attention del paper original, usa 2 cabezas
- En lugar de la dimensión feedforward 2048 del paper original, usa 8 dimensiones
- La premisa necesaria es álgebra lineal básica, y la mayoría de los cálculos se hacen con multiplicación de matrices
- Se enfoca menos en “qué es” un transformer y más en cómo avanzan los cálculos reales
- Para una explicación intuitiva, conviene leerlo junto con The Illustrated Transformer, y el paper original es Attention is all you need
Crear la entrada del encoder
-
Tokenización
- Como los modelos de machine learning procesan números y no texto, el texto de entrada se convierte en IDs de tokens
- Para simplificar, el ejemplo divide
"Hello World"en dos tokens de palabra:"Hello"y"World" - Los métodos reales de tokenización pueden ser basados en palabras, en caracteres o en subwords
- El enfoque basado en palabras requiere un vocabulario grande y trata
"dog"y"dogs"como tokens distintos - El enfoque basado en caracteres tiene un vocabulario pequeño, pero puede contener menos información semántica
- La tokenización por subwords está en un punto intermedio entre palabras y caracteres, y entrena el tokenizador mediante un proceso estadístico
-
Embeddings de tokens
- Como el ID de token en sí no tiene significado, cada token se transforma en un embedding, un vector de tamaño fijo
- Los embeddings del ejemplo usan valores arbitrarios
Hello -> [1, 2, 3, 4]World -> [2, 3, 4, 5]
- En un transformer real, el mapeo de embeddings también se aprende, y el modelo aprende representaciones de tokens adecuadas para la tarea
- Los dos embeddings se agrupan en una sola matriz y se usan después en multiplicaciones de matrices
-
Codificación posicional
- Solo con embeddings no se puede saber la posición dentro de la oración de cada palabra, por lo que se suma una codificación posicional
- El paper original usa codificación posicional fija con sine/cosine, y el ejemplo sigue el mismo método
- La codificación posicional del ejemplo se calcula así
Hello -> [0, 1, 0, 1]World -> [0.84, 0.99, 0, 1]
- Al sumar los embeddings de tokens y la codificación posicional se crea la matriz de entrada del encoder
Hello -> [1, 3, 3, 5]World -> [2.84, 3.99, 4, 6]
Cálculo de self-attention
-
Crear Q, K, V
- Self-attention calcula query(Q), key(K) y value(V) a partir de los embeddings de entrada
- El ejemplo usa 2 cabezas de attention, y cada cabeza tiene sus propias matrices
WQ,WK,WV - Cada matriz de pesos transforma el embedding de 4 dimensiones en query/key/value de 3 dimensiones
- En la primera cabeza, se multiplica la matriz de entrada por
WK1,WV1,WQ1para obtenerK1,V1,Q1
-
Fórmula de attention
- Los scores de attention se calculan en cuatro pasos
- Se calcula el producto punto entre la query y cada key
- Se divide por la raíz cuadrada de la dimensión de las keys
- Se convierte con softmax en pesos positivos cuya suma es 1
- Se calcula la suma ponderada de los vectores value usando esos pesos
- Este proceso se resume en la fórmula del paper original
- [
- Attention(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V
- ]
- En el ejemplo, por las dimensiones pequeñas y los valores iniciales arbitrarios, el resultado de softmax queda sesgado casi a 0 y 1
- Los valores grandes de dot product pueden amplificarse más en softmax, por lo que hace falta escalar dividiendo por la raíz cuadrada de la dimensión de las keys
- Para la explicación también se usa temporalmente una variante que divide por 30 en lugar de
sqrt(3), pero no es una solución de largo plazo
- Los scores de attention se calculan en cuatro pasos
-
Salida de multihead attention
- Los resultados de attention de cada cabeza se concatenan y luego se multiplican por una matriz de pesos aprendida para volver a la dimensión de embedding
- En el ejemplo se combinan los resultados de dos cabezas para crear una matriz de 6 dimensiones, que luego se transforma en una salida de 4 dimensiones
- Esta salida se pasa a la siguiente etapa del bloque encoder: la capa feedforward
Capa feedforward y bloque encoder
-
Capa feedforward
- Después de self-attention hay una red neuronal feedforward (FFN)
- La FFN está compuesta por dos transformaciones lineales y una activación ReLU entre ellas
- La primera capa lineal expande la dimensión, y la segunda la reduce de nuevo al tamaño original
- ReLU convierte los valores negativos en 0 y deja los positivos igual, agregando no linealidad
- En el ejemplo, una entrada de 4 dimensiones se expande a 8 dimensiones y luego se reduce de nuevo a 4
- [
- \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
- ]
-
Bloque encoder
- Un bloque encoder está compuesto por multihead attention y una FFN
- El paper original apila 6 encoders, y el código del ejemplo también repite el encoder con
n=6 - Si simplemente se pasan los valores por varios bloques encoder, pueden crecer demasiado, causar overflow en el cálculo de softmax y producir
nan
Residual connection y layer normalization
-
Problema de explosión de valores
- En el ejemplo, al pasar por 6 encoders aparecen las advertencias
overflow encountered in expeinvalid value encountered in divide, y la salida se vuelvenan - Que los valores crezcan demasiado y se hagan aún más grandes en las capas siguientes es un problema común en redes neuronales profundas
- Cuando el gradiente crece demasiado durante backpropagation, se llama gradient explosion
- En el ejemplo, al pasar por 6 encoders aparecen las advertencias
-
Residual connection
- Una residual connection suma la entrada de una capa a la salida de esa capa
- [
- \text{Residual}(x) = x + \text{Layer}(x)
- ]
- En el ejemplo, se aplica una residual connection tanto a la salida de attention como a la salida de la FFN
- Las residual connections se usan para mitigar el problema de vanishing gradient
-
Layer normalization
- Layer normalization normaliza cada dimensión del embedding para que tenga media 0 y desviación estándar 1
- La fórmula es la siguiente
- [
- \text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \times \gamma + \beta
- ]
- (\epsilon) es un valor pequeño para evitar la división por cero cuando la desviación estándar es 0
- (\gamma) y (\beta) son parámetros aprendibles que controlan el scaling y el shifting
- Después de agregar residual connection y layer normalization, incluso al pasar por 6 encoders se obtienen valores normales sin
nan
Estructura del decoder
-
Entrada del decoder y forma de generación
- El decoder recibe como entrada la salida del encoder y la secuencia de salida generada hasta el momento
- Durante la inferencia, empieza con el token SOS(start-of-sequence)
- El decoder genera tokens de a uno de forma autoregresiva
- 1.ª iteración: recibe
SOScomo entrada y genera"hola" - 2.ª iteración: recibe
SOS + holacomo entrada y genera"mundo" - 3.ª iteración: recibe
SOS + hola + mundocomo entrada y generaEOS
- 1.ª iteración: recibe
- Cuando se genera el token
EOS(end-of-sequence), se detiene el decoding - El encoder puede crear representaciones con un solo forward pass, pero el decoder es más lento porque necesita varios forward passes
-
Componentes del bloque decoder
- El bloque decoder es más complejo que el bloque encoder y está compuesto en este orden
- masked self-attention
- residual connection y layer normalization
- encoder-decoder attention
- residual connection y layer normalization
- capa feedforward
- residual connection y layer normalization
- En el ejemplo de inferencia se suma codificación posicional al embedding de
SOSy se usa[1, 1, 0, 1] - Durante el entrenamiento, se usa masked self-attention, que enmascara los scores de attention con
-infpara impedir ver tokens futuros
- El bloque decoder es más complejo que el bloque encoder y está compuesto en este orden
Encoder-decoder attention
- Encoder-decoder attention es la etapa que permite al decoder concentrarse en las partes relevantes de la oración de entrada
- El cálculo es igual que en self-attention, pero la entrada usada para crear Q/K/V es distinta
- La query se calcula a partir de la salida de la capa anterior del decoder
- La key y el value se calculan a partir de la salida del encoder
- Gracias a esta estructura, cada posición del decoder puede consultar todas las posiciones de la secuencia de entrada
- Es útil para tareas como traducción, donde los tokens de salida dependen de posiciones relevantes de la oración de entrada
Generación del token de salida
-
Linear layer y softmax
- Como la salida del decoder no es una palabra directamente, el último embedding se pasa por una capa lineal para convertirlo en un vector de logits del tamaño del vocabulario
- El tamaño del vocabulario del ejemplo es 10, y los candidatos para el siguiente token son los siguientes
hello,mundo,world,how,?,EOS,SOS,a,hola,c
- Los logits pasan por softmax y se vuelven una distribución de probabilidad sobre cada token
- En las probabilidades del ejemplo,
"hola"tiene la probabilidad más alta, por lo que se elige como siguiente token - Elegir siempre el token con mayor probabilidad es greedy decoding, y no siempre es lo óptimo
- Las técnicas de generación se pueden ver con más detalle en el artículo de Hugging Face
-
Bucle completo de generación
- El procedimiento completo de generación sigue este flujo
- Convierte la secuencia de entrada en embeddings
- El encoder genera una representación contextual de toda la entrada
- El decoder empieza en
SOSy usa tanto los tokens generados previamente como la salida del encoder - Aplica linear layer y softmax al último embedding del decoder
- Elige el siguiente token más probable y lo agrega a la secuencia
- Repite hasta que aparezca
EOSo se alcance la longitud máxima
- La ejecución de ejemplo genera
SOS hola mundo worldpara la entradahello world - Como todos los pesos y embeddings se usan de forma aleatoria, el resultado no es una buena traducción, y eso es lo esperado
- El procedimiento completo de generación sigue este flujo
Conclusión y alcance
- El ejemplo conecta en un solo flujo los componentes centrales del transformer: embeddings, codificación posicional, self-attention, multihead attention, FFN, residual connection, layer normalization, encoder-decoder attention y salida con softmax
- Las arquitecturas transformer modernas agregan varias técnicas, pero las matemáticas centrales se basan en la estructura tratada en este ejemplo
- La pila usada puede variar según el tipo de tarea
- Para tareas centradas en comprensión, como clasificación, se puede colocar una linear layer encima de la pila de encoders
- Para tareas centradas en generación, como traducción, se pueden usar juntas las pilas de encoder y decoder
- Para tareas de generación libre como ChatGPT o Mistral, se puede usar solo la pila de decoders
- No cubre el proceso de entrenamiento, sino que se enfoca en entender las matemáticas de inferencia al usar un modelo existente
- Para material matemático más formal, se puede consultar este PDF
1 comentarios
Opiniones de Hacker News
El “misterio” de Transformer está en que, en lugar de multiplicar pesos y valores estáticos en orden lineal en cada capa, crea 3 matrices obtenidas al multiplicar la misma entrada por pesos aprendidos, y luego multiplica esas matrices entre sí.
Funciona bien porque aumenta el paralelismo, pero la fórmula de atención en sí es fija, así que es muy limitada.
Para avanzar más, parece necesario encontrar una forma de generalizar el propio grafo de cómputo como parámetros aprendibles. No sé si eso es posible con métodos tradicionales de gradiente por el efecto caótico donde pequeños cambios llevan a grandes cambios de rendimiento; quizá internamente haga falta algo como algoritmos genéticos u optimización por enjambre de partículas.
Frente a las RNN, la gran ventaja teórica es que permite esto sin pérdida. Cada elemento puede acceder a toda la información de todos los demás elementos de la secuencia o, en orden temporal, de todos los elementos anteriores.
En cambio, las RNN y los “Transformers lineales” comprimen los valores pasados, por lo que normalmente es difícil que el último elemento de una secuencia larga acceda a toda la información del primero; y es imposible a menos que el estado interno sea enorme y no descarte ninguna información.
El problema es que no se gana mucho con eso. Las operaciones que no sean multiplicación de matrices probablemente sean más lentas o tengan una velocidad similar.
Eso sí, si se incorpora control de flujo, existe el riesgo de que en la práctica se convierta en una máquina de Turing, y entonces, como se dijo, el entrenamiento se vuelve el problema. Aun así, quizá no sea un problema completamente intratable.
Si quieres una explicación más seca, formal y concisa, está “The Transformer Model in Equations” de John Thickstun [0].
En notación matemática estándar, todo cabe en una página.
[0] https://johnthickstun.com/docs/transformers.pdf
Muchas veces parece que los investigadores de machine learning nunca hubieran estudiado matemáticas.
La explicación de “aparece NaN, los valores son demasiado grandes y explotan al pasar al siguiente encoder; esto es explosión del gradiente” es incorrecta, según entiendo.
Aquí no se calcula ningún gradiente en ningún punto, así que no es explosión del gradiente.
El problema parece estar en la implementación de softmax, y aquí [0] se explica cómo implementar softmax de forma numéricamente estable.
[0]: https://jaykmody.com/blog/stable-softmax/
Aun así, toda la red neuronal es sensible a valores grandes, por lo que un softmax numéricamente estable no basta para resolverlo. Para que la red funcione, la normalización es clave.
Los tutoriales sobre Transformers quizá se conviertan en los nuevos tutoriales sobre Monad. Es un concepto difícil de entender, pero de esos que se entienden lidiando con ejemplos y practicando.
Como ocurre con muchas cosas en ciencias de la computación.
Leí apenas seis párrafos y ya tengo una pregunta.
En
Hello -> [1,2,3,4] World -> [2,3,4,5], se dice que los vectores son aleatorios, pero parece haber un patrón. Me pregunto si el2que aparece en ambos vectores significa algo, o si es el conjunto completo el que genera la unicidad.Aquí están separados por unos 60 grados y apuntan más o menos en la misma dirección, pero como se intentó no poner números negativos en el ejemplo, los vectores quedaron más parecidos de lo que serían en realidad.
El hecho de que se hayan reutilizado números no significa nada por sí mismo. El
1en la primera posición casi no tiene relación con un1en la segunda posición. Tampoco se está haciendo convolución sobre este vector.Después del entrenamiento, las palabras parecidas tendrán cierta similitud coseno, pero casi nunca una similitud coseno tan alta como la de
[1,2,3,4]y[2,3,4,5].No es una pregunta del todo relacionada, pero estoy buscando algún artículo o paper que trate por qué un Transformer puede manejar preguntas como las siguientes aun cuando funciona simplemente como un “predictor del siguiente token”:
"sdsfs_ff","fsdf_value"como columnasSiento que debe ser una pregunta común, pero no encuentro las palabras clave para buscarla. También me servirían enlaces que profundicen en embeddings posicionales, y todavía no encontré una respuesta satisfactoria sobre por qué se usan seno/coseno y sobre multiplicación vs. suma
Si el modelo lo considera necesario, puede reproducir una secuencia desconocida copiando tokens de caracteres individuales, o inventarla si tiene sentido en el contexto.
P(X_1=x_1, X_2=x_2, X_3=x_3) = P(X_3=x_3 | X_1=X_1, X_2=x_2) • P(X_1=x_1, X_2=x_2)= P(X_3=x_3 | X_1=X_1, X_2=x_2) • P(X_2=x_2 | X_1=x_1) • P(X_1=x_1)Es decir, si se tiene la distribución de probabilidad condicional correcta para el siguiente token dados los tokens anteriores, también se obtiene la distribución de probabilidad correcta para toda la secuencia de tokens.
Una “distribución de probabilidad correcta para una secuencia de tokens”, o la distribución de probabilidad condicional correcta de una secuencia de tokens dadas ciertas condiciones, en la práctica permite describir casi cualquier tipo de comportamiento de entrada/salida en esos términos.
Por eso, que “funcione prediciendo el siguiente token” no es, en principio, una gran restricción sobre qué comportamientos de entrada/salida puede realizar.
Por más impresionante que sea lo que haga, no contradice el hecho de que su salida provenga de
P(X_{n+1}=x_{n+1} | X_1=x_1, ..., X_n=x_n), es decir, de la “predicción del siguiente token”.Predecir el siguiente token es una tarea más inteligente de lo que parece.
Estoy de acuerdo con que “la complejidad viene de la cantidad de pasos y de parámetros”.
Los modelos Transformer lo bastante simples como para que los entendamos no hacen nada interesante, y los Transformers lo bastante complejos como para hacer cosas interesantes parecen demasiado complejos para que los entendamos.
Me gustaría investigar modelos de escala intermedia que sean lo bastante simples como para entenderlos, pero lo bastante complejos como para hacer cosas interesantes.
Si se usan conceptos sin definirlos ni introducirlos, es difícil entender. La sección Encoder arranca directamente sin explicar qué es ni en qué parte del proceso completo encaja.
Entiendo lo que el autor intenta hacer, pero falta la estructura básica de un texto: presentar primero las ideas, explicarlas y luego usarlas.
Si no sos alguien que ya está estudiando este tema y lo entiende a medias, todo el artículo se siente confuso.
Aunque alguna vez escribí una ANN desde cero y no usé TensorFlow, esta explicación igual me resulta confusa.
Le pedí a ChatGPT que explicara cómo cambiar una ANN básica para implementar self-attention sin usar las palabras
MatrixniVector, y me dio una explicación bastante simple. Todavía no lo implementé.Me resulta mejor pensar todo en términos de nodos, pesos y capas. Las matrices y los vectores hacen más difícil conectarlo con lo que realmente ocurre dentro de una ANN.
En la forma familiar de escribir una ANN, cada nodo de entrada es un escalar, pero el algoritmo de propagación hacia adelante multiplica todos los nodos de entrada por pesos y los suma, así que se parece a una multiplicación vector-matriz. Siento que estoy abordando estas explicaciones con la mentalidad equivocada, y puede que me falten conocimientos de base necesarios.