Llama desde cero: cómo implementar un paper sin llorar
(blog.briankitano.com)- Brian Kitano construye personalmente una versión reducida de Llama con TinyShakespeare, y resume que, para implementar un paper de forma segura, conviene partir de un modelo pequeño, reemplazar las piezas una por una y entrenar y evaluar en cada paso
- Primero prepara funciones auxiliares de verificación como división de datos, generación de batches, evaluación de pérdida y función de generación; luego confirma con un modelo simple que pueda compilar y entrenar antes de agregar componentes de Llama
- Al incorporar RMSNorm, RoPE y SwiGLU en orden, verifica mediante shapes de tensores, propiedades matemáticas y mapas de atención que cada capa funcione como se espera
- En la atención RoPE, al quitar la causal mask, la pérdida de validación baja hasta 0.16, pero la calidad de generación empeora; la causa fue una fuga de información al mirar tokens futuros
- La versión reducida final de Llama tiene 4 bloques y unos 2.37 millones de parámetros, reduce la pérdida de validación hasta cerca de 1.0 y también requiere revisar el flujo de gradients y el schedule de learning rate
Empezar pequeño y ganar confianza iterativamente
- La clave para implementar un paper es empezar con un modelo pequeño, cambiar los componentes uno por uno y repetir entrenamiento y evaluación en cada cambio
- Primero se preparan funciones auxiliares para verificar cuantitativamente el modelo
- División de datos
- Loop de entrenamiento
- Visualización de pérdida
- Evaluación de pérdida de validación
- En vez de trasladar todos los componentes del paper de una sola vez, también se prepara una función de evaluación cualitativa para ver los resultados generados con un modelo simple, rápido y ya conocido
- Las capas de tensores se verifican con
.shape,assertyplt.imshow; antes de optimizar multiplicaciones de matrices desde el inicio, se comprueban a mano los resultados esperados y luego se eficientiza con funciones detorch - Hay que probar cambiando batch size, longitud de secuencia y dimensión de embedding; el código que solo funciona con un tamaño puede romperse en tiempo de inferencia
Dataset y configuración básica
- El objetivo de implementación es una versión muy reducida de Llama de Meta AI, y los datos de entrenamiento son TinyShakespeare
- Llama se entrena con 1.4T tokens, pero aquí se usa TinyShakespeare, de alrededor de 1.11 millones de caracteres
- El Llama original usa un tokenizer byte-pair encoding de SentencePiece, pero esta implementación usa un tokenizer simple a nivel de caracteres
- el vocabulary size es 65
- como el dataset es pequeño, no se optimiza por separado la forma de almacenarlo en memoria
- Con el diccionario
MASTER_CONFIGse administran configuraciones del modelo comovocab_size,batch_size,context_windowyd_model- El objetivo es reducir constantes y magic numbers, y hacer el código más legible
- La función
get_batchesdivide los datos en train 80%, val 10% y test 10%, y desde un punto inicial aleatorio genera la entradaxy la etiquetayde un carácter posterior
Verificar compilación y entrenamiento con un modelo básico
- El primer modelo es
SimpleBrokenModel, compuesto por embeddings y una red feed-forward simplenn.EmbeddingLinearReLULinear
- En la implementación de un paper, decir que un modelo “funciona” implica cumplir ambas condiciones
- Compila: los shapes de tensores encajan entre capas
- Entrena: la pérdida realmente baja
- La función
evaluate_losssamplea 10 batches en los splits train y val, y calcula la pérdida promedio - Tras entrenar 1000 epochs,
SimpleBrokenModeltenía una pérdida de validación de alrededor de 3.94, casi sin bajar desde la cross-entropy inicial de 4.17 - La causa fue pasarle a
F.cross_entropyvalores ya procesados con softmaxF.cross_entropyde PyTorch recibe directamente logits no normalizadosSimpleModel, tras quitar softmax, reduce la pérdida de validación hasta alrededor de 2.51
- Luego se agrega la función
generatepara revisar directamente los caracteres creados por el modelo, y el modelo básico queda en un estado imperfecto, pero con pérdida de validación descendente
Componente 1 de Llama: RMSNorm
- Comparado con el Transformer original, Llama usa tres modificaciones principales de arquitectura
- RMSNorm pre-normalization
- Rotary embeddings
- SwiGLU activation function
- El Transformer original usa BatchNormalization, pero Llama usa RMSNorm, que escala el vector por la variance sin centering
- Mientras el Transformer original aplica normalization a la salida de la attention layer, en un esquema de post-normalization, Llama usa pre-normalization, aplicándola primero a la entrada
- La implementación de
RMSNormasume un shape de entrada(batch, seq_len, d_model) - El resultado de RMSNorm se prueba con la propiedad de que la layer norm se vuelve la raíz cuadrada del número de elementos de la capa
assert- row-wise comparison
torch.allclose
SimpleModel_RMS, que agrega RMSNorm al modelo básico, reduce ligeramente la pérdida de validación hasta alrededor de 2.5015
Componente 2 de Llama: RoPE y causal mask
- RoPE es un método de positional encoding para Transformers, que expresa la posición de los tokens como una rotación del embedding
get_rotary_matrixgenera matrices de rotación por posición para el context window y la embedding dimension- La implementación de RoPE se prueba con la siguiente propiedad
- El producto interno de dos vectores rotados en las posiciones
myndebe coincidir con la rotación de posición relativan-m
- El producto interno de dos vectores rotados en las posiciones
RoPEAttentionHeadcreaw_q,w_kyw_v, aplica la rotación RoPE a query y key, y luego usaF.scaled_dot_product_attention- Hay que cuidar la diferencia de shapes de tensores entre entrenamiento e inferencia
- En entrenamiento, muchas veces coinciden con la configuración, por ejemplo
(config['batch_size'], config['context_window'], config['d_model']) - En inferencia, puede procesarse un único ejemplo como
(1, 1, config['d_model']) - Dentro de
forward, hay que indexar con base en el shape obtenido desde la entrada, no en los valores de configuración del modelo
- En entrenamiento, muchas veces coinciden con la configuración, por ejemplo
- El modelo que agrega RoPE multi-head attention sin causal mask reduce drásticamente la pérdida de validación hasta 0.1623, pero los resultados generados son malos, como
OOOO...oIIII... - Al revisar el attention map, todas las posiciones referenciaban a todas las posiciones, y en la predicción del siguiente token ocurría una fuga de información al mirar tokens futuros
- Al cambiar a
RoPEMaskedAttentionHeadaplicandois_causal=TrueenF.scaled_dot_product_attention, la attention upper triangular correspondiente al futuro queda casi en 0 - Tras aplicar la causal mask, la pérdida de validación queda en 2.0815, y al entrenar por más tiempo baja hasta 1.8985
Componente 3 de Llama: SwiGLU y apilar bloques
- Llama reemplaza la no linealidad ReLU por la SwiGLU activation function
- La implementación de
SwiGLUes una Swish-gated linear unit, y usa dos transformaciones lineales y un parámetrobetalearnable - RopeModel, con SwiGLU en la parte feed-forward, tiene 592,706 parámetros y una pérdida de validación de alrededor de 1.8963
- Luego se crea
LlamaBlockpara agrupar la siguiente configuración en un bloque- RMSNorm pre-normalization
- masked RoPE multi-head attention
- residual connection
- RMSNorm pre-normalization
- SwiGLU feed-forward
- residual connection
- El modelo final
Llamase configura conn_layers=4y apila 4LlamaBlockconnn.Sequentialbasado enOrderedDict - El modelo final tiene 2,370,246 parámetros, y los resultados de entrenamiento son los siguientes
- Pérdida de validación de 1.5532 tras el entrenamiento inicial de 4 layers
- Pérdida de validación de 1.1479 tras seguir entrenando durante 10,000 epochs
- Pérdida de validación de 0.9997 tras entrenamiento adicional
- La pérdida de un batch del split test es 1.2358
Resultados de generación y checklist de debugging
- El modelo final crea nombres, saltos de línea y fragmentos de palabras similares al formato de Shakespeare, pero la calidad real de las oraciones es limitada
- La pérdida de cross-entropy puede interpretarse intuitivamente desde la perspectiva de selección de tokens
- La pérdida inicial de 4.17 se acerca a una elección aleatoria con vocabulary size 65
- Una pérdida de 1.08 se interpreta como un nivel equivalente a elegir aleatoriamente entre unos 2.9 tokens
- El flujo de gradients se revisa con la función
show_grads- Calcula la proporción de gradients con valor absoluto pequeño en cada parámetro
- Si los gradients de la mayoría de los parámetros no están cerca de 0, el flujo está en buen estado
- El Llama original usa un learning schedule de Cosine Annealing, pero en esta implementación los resultados experimentales fueron peores
- En el experimento con Cosine Annealing, incluso con una tolerance muy baja, el attention bias casi no recibía señal; como el motivo no está claro, en una implementación real es más seguro empezar de forma simple
1 comentarios
Opiniones en Hacker News
Parece haber un bug en la implementación de SwiGLU: en el paper de referencia, beta de la feed-forward network no es un valor entrenable sino una constante, y se define como
FFnSwiGLU = Swish1...Según la ecuación 6 de https://arxiv.org/pdf/2002.05202.pdf
En la implementación oficial de llama también se eliminó la beta constante: https://github.com/facebookresearch/llama/blob/main/llama/mo...
Si se ven las líneas
"feedforward.1.beta', 0.0"del log del blog, durante el entrenamiento beta degeneró a 0, cuando originalmente debería ser la constante 1Muchas veces la red se adapta a los cambios, sean intencionales o no, y después del entrenamiento varias variantes de arquitectura también se comportan de forma similar, por lo que a veces es ambiguo si realmente debe coincidir con la original
Una forma de encontrar estos errores es hacer coincidir exactamente los valores de salida con una implementación de referencia. Incluso con pesos aleatorios, como en los modelos tiny-random de HuggingFace, la salida debería ser exactamente la misma; si no lo es, es señal de bug
Sin embargo, este método funciona bien sobre todo para bugs que aparecen durante la inferencia; los problemas que solo ocurren en el procesamiento de datos, el optimizador o durante el entrenamiento son más difíciles de detectar
Personalmente pienso que es por su naturaleza autorregresiva y similar a una ODE, pero no estoy tan seguro como para afirmarlo
El trabajo es excelente, pero los primeros
SimpleBrokenModelySimpleModeltienen bastantes operaciones desperdiciadas. El orden esembedding 65 -> 128,linear 128 -> 128,ReLU,linear 128 -> 65; como no hay no linealidad entre las dos primeras capas y ambas son lineales, la segunda capa lineal en realidad no sirve de muchoEste modelo termina siendo equivalente a un MLP clásico de una sola capa oculta y, en términos de FLOPS, desperdicia
128*128=16koperaciones de un total de128*128+65*128=24kLa capa de embedding es una estructura especial que convierte índices de tokens en vectores de embedding, así que no creo que se pueda eliminar
En general muestra muy bien los principios básicos. En especial me gusta la frase “usa
.shapereligiosamente.assertyplt.imshowson tus amigos”, y las precondiciones y poscondiciones de shape siempre deberían verificarse con assertTambién me pregunto si
bearotypeguardsoportan este tipo de validaciones mediante decoradoresPero en la parte de “elige un modelo pequeño, simple y rápido, y crea helpers para evaluarlo cualitativamente”, sospecho que en realidad se refiere a una evaluación cuantitativa. Así se obtiene una línea base numérica para comparar con técnicas más avanzadas
El consejo de implementar uno por uno los componentes del paper también debería ser más preciso. Los papers normalmente prueban varios cambios a la vez y luego muestran la contribución de cada elemento mediante estudios de ablación; por eso creo que es mejor empezar por los cambios centrales de arquitectura y evaluar cada cambio atómico, respetando las dependencias, en el orden de mayor impacto observado en las ablaciones
bearotypeguard, gracias a https://peps.python.org/pep-0646/, parte de esto se puede empujar directamente a las anotaciones de tipo de PythonPor ejemplo, se puede expresar el shape por eje en el tipo con algo como
ndarray[float, Dim1, *Shape], y sobrecargar el shape de retorno según el valor deaxisbear/typeguardAun así, parece difícil que Python sea tan bueno como Julia. El sistema de tipos de Julia permite garantizar con mucha más facilidad que los tamaños de las matrices coincidan
Me pregunto cuál es el principio para usar SwiGLU en lugar de ReLU. No sé si los autores simplemente probaron todas las funciones no lineales posibles o si hay una razón más profunda
Como bearblog está recibiendo un DDoS, dejo el repositorio: https://github.com/bkitano/llama-from-scratch
Desde la perspectiva de alguien que está aprendiendo IA, hice un resumen simple de los términos que aparecen en el texto. Un token es un identificador entero que representa un fragmento de texto y, en los LLM, se agrupan fragmentos de caracteres de uso frecuente dentro de un tamaño de vocabulario limitado
La función de pérdida es un valor que mide la diferencia entre la predicción y la respuesta correcta, y mientras más bajo sea, mejor. PyTorch es una biblioteca para trabajar con tensores y redes neuronales, y un tensor es un arreglo multidimensional de números que incluye escalares, vectores y matrices
Una red neuronal es una estructura de conexiones de neuronas con pesos y sesgos, y una capa lineal es una estructura simple en la que todas las entradas y salidas están conectadas. ReLU es una función de activación como
Math.max(0, x): si solo apilas capas lineales, al final equivale a una sola función lineal, así que se introduce no linealidad para aumentar la capacidad de aprendizajeEl gradiente es una cantidad de cambio numérico que se calcula durante el entrenamiento para hacer que el modelo sea más preciso, y la normalización por lotes es un método que ayuda al entrenamiento ajustando los números que fluyen. La codificación posicional le indica al modelo las posiciones relativas de los tokens mediante vectores
El operador
@de Python es un alias de__matmul__y se usa para multiplicación de matrices. Una época es entrenar una vez sobre todo el dataset, y un lote es la cantidad de datos que se ingresan de una sola vez antes de actualizar los parámetrosLa atención es el componente central que hace funcionar a los LLM: procesa los tokens de entrada en paralelo para crear tensores intermedios y luego los usa para generar tokens de salida
Por ejemplo,
writ, que aparece en común enwriting,writtenywriter, puede ser un token, ywriterpuede tokenizarse comowrityerEl embedding es la etapa que convierte esos tokens en representaciones numéricas propias
Si existe una implementación previa del modelo y checkpoints, la forma más efectiva de comprobar si tu implementación es correcta es cargar ese checkpoint y comparar los valores de salida
Si la salida no coincide, por lo general significa que algún detalle de implementación está mal, y puedes seguir sistemáticamente cada capa para encontrar la diferencia real. En el camino, incluso podrías descubrir algo extraño en la implementación existente
Esto se refiere al modelo en sí; el entrenamiento es otro eje aparte. Aun así, si ajustaste los hiperparámetros de forma más o menos similar, cuando la implementación del modelo es correcta, en general las cosas salen bien
Tanto la forma de leer papers como el contenido de ese paper están muy buenos, y también recomiendo la serie Makemore de Karpathy
Los consejos resumidos son muy buenos, y creo que el de hacer assert de los shapes de los tensores aplica a cualquier biblioteca general de álgebra lineal. Al escribir código complejo de álgebra lineal, es muy importante avanzar en pasos pequeños y programar de forma defensiva
Programar álgebra lineal en lenguajes mainstream es terrible porque no hay verificación de shapes en tiempo de compilación. El shape de un tensor debería formar parte del tipo, y si intentas multiplicar
3x4por3x4sin transponer, ni siquiera debería compilarEs realmente lo peor correr un cálculo largo y que luego falle por una operación con dimensiones incompatibles
También creo que los tensores de PyTorch deberían tener el dispositivo tipado estáticamente. Hoy, si intentas multiplicar un tensor en memoria de CPU por uno en memoria de GPU, aparece un error en tiempo de ejecución