2 puntos por GN⁺ 2023-08-10 | 1 comentarios | Compartir por WhatsApp
  • Brian Kitano construye personalmente una versión reducida de Llama con TinyShakespeare, y resume que, para implementar un paper de forma segura, conviene partir de un modelo pequeño, reemplazar las piezas una por una y entrenar y evaluar en cada paso
  • Primero prepara funciones auxiliares de verificación como división de datos, generación de batches, evaluación de pérdida y función de generación; luego confirma con un modelo simple que pueda compilar y entrenar antes de agregar componentes de Llama
  • Al incorporar RMSNorm, RoPE y SwiGLU en orden, verifica mediante shapes de tensores, propiedades matemáticas y mapas de atención que cada capa funcione como se espera
  • En la atención RoPE, al quitar la causal mask, la pérdida de validación baja hasta 0.16, pero la calidad de generación empeora; la causa fue una fuga de información al mirar tokens futuros
  • La versión reducida final de Llama tiene 4 bloques y unos 2.37 millones de parámetros, reduce la pérdida de validación hasta cerca de 1.0 y también requiere revisar el flujo de gradients y el schedule de learning rate

Empezar pequeño y ganar confianza iterativamente

  • La clave para implementar un paper es empezar con un modelo pequeño, cambiar los componentes uno por uno y repetir entrenamiento y evaluación en cada cambio
  • Primero se preparan funciones auxiliares para verificar cuantitativamente el modelo
    • División de datos
    • Loop de entrenamiento
    • Visualización de pérdida
    • Evaluación de pérdida de validación
  • En vez de trasladar todos los componentes del paper de una sola vez, también se prepara una función de evaluación cualitativa para ver los resultados generados con un modelo simple, rápido y ya conocido
  • Las capas de tensores se verifican con .shape, assert y plt.imshow; antes de optimizar multiplicaciones de matrices desde el inicio, se comprueban a mano los resultados esperados y luego se eficientiza con funciones de torch
  • Hay que probar cambiando batch size, longitud de secuencia y dimensión de embedding; el código que solo funciona con un tamaño puede romperse en tiempo de inferencia

Dataset y configuración básica

  • El objetivo de implementación es una versión muy reducida de Llama de Meta AI, y los datos de entrenamiento son TinyShakespeare
  • Llama se entrena con 1.4T tokens, pero aquí se usa TinyShakespeare, de alrededor de 1.11 millones de caracteres
  • El Llama original usa un tokenizer byte-pair encoding de SentencePiece, pero esta implementación usa un tokenizer simple a nivel de caracteres
    • el vocabulary size es 65
    • como el dataset es pequeño, no se optimiza por separado la forma de almacenarlo en memoria
  • Con el diccionario MASTER_CONFIG se administran configuraciones del modelo como vocab_size, batch_size, context_window y d_model
    • El objetivo es reducir constantes y magic numbers, y hacer el código más legible
  • La función get_batches divide los datos en train 80%, val 10% y test 10%, y desde un punto inicial aleatorio genera la entrada x y la etiqueta y de un carácter posterior

Verificar compilación y entrenamiento con un modelo básico

  • El primer modelo es SimpleBrokenModel, compuesto por embeddings y una red feed-forward simple
    • nn.Embedding
    • Linear
    • ReLU
    • Linear
  • En la implementación de un paper, decir que un modelo “funciona” implica cumplir ambas condiciones
    • Compila: los shapes de tensores encajan entre capas
    • Entrena: la pérdida realmente baja
  • La función evaluate_loss samplea 10 batches en los splits train y val, y calcula la pérdida promedio
  • Tras entrenar 1000 epochs, SimpleBrokenModel tenía una pérdida de validación de alrededor de 3.94, casi sin bajar desde la cross-entropy inicial de 4.17
  • La causa fue pasarle a F.cross_entropy valores ya procesados con softmax
    • F.cross_entropy de PyTorch recibe directamente logits no normalizados
    • SimpleModel, tras quitar softmax, reduce la pérdida de validación hasta alrededor de 2.51
  • Luego se agrega la función generate para revisar directamente los caracteres creados por el modelo, y el modelo básico queda en un estado imperfecto, pero con pérdida de validación descendente

Componente 1 de Llama: RMSNorm

  • Comparado con el Transformer original, Llama usa tres modificaciones principales de arquitectura
    • RMSNorm pre-normalization
    • Rotary embeddings
    • SwiGLU activation function
  • El Transformer original usa BatchNormalization, pero Llama usa RMSNorm, que escala el vector por la variance sin centering
  • Mientras el Transformer original aplica normalization a la salida de la attention layer, en un esquema de post-normalization, Llama usa pre-normalization, aplicándola primero a la entrada
  • La implementación de RMSNorm asume un shape de entrada (batch, seq_len, d_model)
  • El resultado de RMSNorm se prueba con la propiedad de que la layer norm se vuelve la raíz cuadrada del número de elementos de la capa
    • assert
    • row-wise comparison
    • torch.allclose
  • SimpleModel_RMS, que agrega RMSNorm al modelo básico, reduce ligeramente la pérdida de validación hasta alrededor de 2.5015

Componente 2 de Llama: RoPE y causal mask

  • RoPE es un método de positional encoding para Transformers, que expresa la posición de los tokens como una rotación del embedding
  • get_rotary_matrix genera matrices de rotación por posición para el context window y la embedding dimension
  • La implementación de RoPE se prueba con la siguiente propiedad
    • El producto interno de dos vectores rotados en las posiciones m y n debe coincidir con la rotación de posición relativa n-m
  • RoPEAttentionHead crea w_q, w_k y w_v, aplica la rotación RoPE a query y key, y luego usa F.scaled_dot_product_attention
  • Hay que cuidar la diferencia de shapes de tensores entre entrenamiento e inferencia
    • En entrenamiento, muchas veces coinciden con la configuración, por ejemplo (config['batch_size'], config['context_window'], config['d_model'])
    • En inferencia, puede procesarse un único ejemplo como (1, 1, config['d_model'])
    • Dentro de forward, hay que indexar con base en el shape obtenido desde la entrada, no en los valores de configuración del modelo
  • El modelo que agrega RoPE multi-head attention sin causal mask reduce drásticamente la pérdida de validación hasta 0.1623, pero los resultados generados son malos, como OOOO... o IIII...
  • Al revisar el attention map, todas las posiciones referenciaban a todas las posiciones, y en la predicción del siguiente token ocurría una fuga de información al mirar tokens futuros
  • Al cambiar a RoPEMaskedAttentionHead aplicando is_causal=True en F.scaled_dot_product_attention, la attention upper triangular correspondiente al futuro queda casi en 0
  • Tras aplicar la causal mask, la pérdida de validación queda en 2.0815, y al entrenar por más tiempo baja hasta 1.8985

Componente 3 de Llama: SwiGLU y apilar bloques

  • Llama reemplaza la no linealidad ReLU por la SwiGLU activation function
  • La implementación de SwiGLU es una Swish-gated linear unit, y usa dos transformaciones lineales y un parámetro beta learnable
  • RopeModel, con SwiGLU en la parte feed-forward, tiene 592,706 parámetros y una pérdida de validación de alrededor de 1.8963
  • Luego se crea LlamaBlock para agrupar la siguiente configuración en un bloque
    • RMSNorm pre-normalization
    • masked RoPE multi-head attention
    • residual connection
    • RMSNorm pre-normalization
    • SwiGLU feed-forward
    • residual connection
  • El modelo final Llama se configura con n_layers=4 y apila 4 LlamaBlock con nn.Sequential basado en OrderedDict
  • El modelo final tiene 2,370,246 parámetros, y los resultados de entrenamiento son los siguientes
    • Pérdida de validación de 1.5532 tras el entrenamiento inicial de 4 layers
    • Pérdida de validación de 1.1479 tras seguir entrenando durante 10,000 epochs
    • Pérdida de validación de 0.9997 tras entrenamiento adicional
    • La pérdida de un batch del split test es 1.2358

Resultados de generación y checklist de debugging

  • El modelo final crea nombres, saltos de línea y fragmentos de palabras similares al formato de Shakespeare, pero la calidad real de las oraciones es limitada
  • La pérdida de cross-entropy puede interpretarse intuitivamente desde la perspectiva de selección de tokens
    • La pérdida inicial de 4.17 se acerca a una elección aleatoria con vocabulary size 65
    • Una pérdida de 1.08 se interpreta como un nivel equivalente a elegir aleatoriamente entre unos 2.9 tokens
  • El flujo de gradients se revisa con la función show_grads
    • Calcula la proporción de gradients con valor absoluto pequeño en cada parámetro
    • Si los gradients de la mayoría de los parámetros no están cerca de 0, el flujo está en buen estado
  • El Llama original usa un learning schedule de Cosine Annealing, pero en esta implementación los resultados experimentales fueron peores
  • En el experimento con Cosine Annealing, incluso con una tolerance muy baja, el attention bias casi no recibía señal; como el motivo no está claro, en una implementación real es más seguro empezar de forma simple

1 comentarios

 
GN⁺ 2023-08-10
Opiniones en Hacker News
  • Parece haber un bug en la implementación de SwiGLU: en el paper de referencia, beta de la feed-forward network no es un valor entrenable sino una constante, y se define como FFnSwiGLU = Swish1...
    Según la ecuación 6 de https://arxiv.org/pdf/2002.05202.pdf
    En la implementación oficial de llama también se eliminó la beta constante: https://github.com/facebookresearch/llama/blob/main/llama/mo...
    Si se ven las líneas "feedforward.1.beta', 0.0" del log del blog, durante el entrenamiento beta degeneró a 0, cuando originalmente debería ser la constante 1

    • Esto muestra lo difícil que es implementar correctamente una red neuronal Transformer. Uno puede equivocarse en varias etapas, y por lo general solo se manifiesta como “un rendimiento un poco peor que el original”, así que es difícil saberlo con certeza
      Muchas veces la red se adapta a los cambios, sean intencionales o no, y después del entrenamiento varias variantes de arquitectura también se comportan de forma similar, por lo que a veces es ambiguo si realmente debe coincidir con la original
      Una forma de encontrar estos errores es hacer coincidir exactamente los valores de salida con una implementación de referencia. Incluso con pesos aleatorios, como en los modelos tiny-random de HuggingFace, la salida debería ser exactamente la misma; si no lo es, es señal de bug
      Sin embargo, este método funciona bien sobre todo para bugs que aparecen durante la inferencia; los problemas que solo ocurren en el procesamiento de datos, el optimizador o durante el entrenamiento son más difíciles de detectar
    • En los Transformers, creo que los valores de sesgo en general no encajan muy bien
      Personalmente pienso que es por su naturaleza autorregresiva y similar a una ODE, pero no estoy tan seguro como para afirmarlo
  • El trabajo es excelente, pero los primeros SimpleBrokenModel y SimpleModel tienen bastantes operaciones desperdiciadas. El orden es embedding 65 -> 128, linear 128 -> 128, ReLU, linear 128 -> 65; como no hay no linealidad entre las dos primeras capas y ambas son lineales, la segunda capa lineal en realidad no sirve de mucho
    Este modelo termina siendo equivalente a un MLP clásico de una sola capa oculta y, en términos de FLOPS, desperdicia 128*128=16k operaciones de un total de 128*128+65*128=24k

    • Parece que no soy el único que todavía está aprendiendo sobre no linealidades. Me pregunto si la mejor corrección aquí sería poner ReLU o SwiGLU entre el embedding y la primera capa lineal, o simplemente eliminar la capa lineal
      La capa de embedding es una estructura especial que convierte índices de tokens en vectores de embedding, así que no creo que se pueda eliminar
  • En general muestra muy bien los principios básicos. En especial me gusta la frase “usa .shape religiosamente. assert y plt.imshow son tus amigos”, y las precondiciones y poscondiciones de shape siempre deberían verificarse con assert
    También me pregunto si bear o typeguard soportan este tipo de validaciones mediante decoradores
    Pero en la parte de “elige un modelo pequeño, simple y rápido, y crea helpers para evaluarlo cualitativamente”, sospecho que en realidad se refiere a una evaluación cuantitativa. Así se obtiene una línea base numérica para comparar con técnicas más avanzadas
    El consejo de implementar uno por uno los componentes del paper también debería ser más preciso. Los papers normalmente prueban varios cambios a la vez y luego muestran la contribución de cada elemento mediante estudios de ablación; por eso creo que es mejor empezar por los cambios centrales de arquitectura y evaluar cada cambio atómico, respetando las dependencias, en el orden de mayor impacto observado en las ablaciones

    • En lugar de bear o typeguard, gracias a https://peps.python.org/pep-0646/, parte de esto se puede empujar directamente a las anotaciones de tipo de Python
      Por ejemplo, se puede expresar el shape por eje en el tipo con algo como ndarray[float, Dim1, *Shape], y sobrecargar el shape de retorno según el valor de axis
    • No conozco bien PyTorch, pero la última vez que revisé no era así; Jax sí soporta validaciones básicas de shape de matrices en runtime mediante bear / typeguard
      Aun así, parece difícil que Python sea tan bueno como Julia. El sistema de tipos de Julia permite garantizar con mucha más facilidad que los tamaños de las matrices coincidan
  • Me pregunto cuál es el principio para usar SwiGLU en lugar de ReLU. No sé si los autores simplemente probaron todas las funciones no lineales posibles o si hay una razón más profunda

    • Como pasa con mucha investigación, si no hay una explicación clara respaldada por un estudio riguroso, probablemente hicieron una búsqueda hill-climbing aleatoria de cambios de una línea que se veían interesantes, y se detuvieron cuando llegó el momento de escribir el paper y hacer los estudios de ablación
  • Como bearblog está recibiendo un DDoS, dejo el repositorio: https://github.com/bkitano/llama-from-scratch

  • Desde la perspectiva de alguien que está aprendiendo IA, hice un resumen simple de los términos que aparecen en el texto. Un token es un identificador entero que representa un fragmento de texto y, en los LLM, se agrupan fragmentos de caracteres de uso frecuente dentro de un tamaño de vocabulario limitado
    La función de pérdida es un valor que mide la diferencia entre la predicción y la respuesta correcta, y mientras más bajo sea, mejor. PyTorch es una biblioteca para trabajar con tensores y redes neuronales, y un tensor es un arreglo multidimensional de números que incluye escalares, vectores y matrices
    Una red neuronal es una estructura de conexiones de neuronas con pesos y sesgos, y una capa lineal es una estructura simple en la que todas las entradas y salidas están conectadas. ReLU es una función de activación como Math.max(0, x): si solo apilas capas lineales, al final equivale a una sola función lineal, así que se introduce no linealidad para aumentar la capacidad de aprendizaje
    El gradiente es una cantidad de cambio numérico que se calcula durante el entrenamiento para hacer que el modelo sea más preciso, y la normalización por lotes es un método que ayuda al entrenamiento ajustando los números que fluyen. La codificación posicional le indica al modelo las posiciones relativas de los tokens mediante vectores
    El operador @ de Python es un alias de __matmul__ y se usa para multiplicación de matrices. Una época es entrenar una vez sobre todo el dataset, y un lote es la cantidad de datos que se ingresan de una sola vez antes de actualizar los parámetros
    La atención es el componente central que hace funcionar a los LLM: procesa los tokens de entrada en paralelo para crear tensores intermedios y luego los usa para generar tokens de salida

    • Fuera del área, puede que no se sepa qué significa “Karpathy”. Si se presenta a Andrej Karpathy con contexto, como “comunicador científico e investigador”, queda más claro que la idea es consultar sus textos o videos
    • Más que decir que un token es simplemente un identificador entero de un fragmento de texto, para principiantes es más preciso verlo como un fragmento de palabra lo bastante común como para ser útil por sí mismo
      Por ejemplo, writ, que aparece en común en writing, written y writer, puede ser un token, y writer puede tokenizarse como writ y er
      El embedding es la etapa que convierte esos tokens en representaciones numéricas propias
    • Si compones funciones lineales, el resultado vuelve a ser una función lineal. Por eso, si todo es lineal, aunque apiles muchas capas, todas salvo una son un desperdicio; para evitarlo hace falta no linealidad
    • Además de la serie de videos de Karpathy y el accompanying repo, me pregunto si hay otros materiales o libros que hayan sido especialmente útiles en el proceso de aprendizaje
    • Me da curiosidad qué hace exactamente la normalización por lotes y cómo ayuda
  • Si existe una implementación previa del modelo y checkpoints, la forma más efectiva de comprobar si tu implementación es correcta es cargar ese checkpoint y comparar los valores de salida
    Si la salida no coincide, por lo general significa que algún detalle de implementación está mal, y puedes seguir sistemáticamente cada capa para encontrar la diferencia real. En el camino, incluso podrías descubrir algo extraño en la implementación existente
    Esto se refiere al modelo en sí; el entrenamiento es otro eje aparte. Aun así, si ajustaste los hiperparámetros de forma más o menos similar, cuando la implementación del modelo es correcta, en general las cosas salen bien

  • Tanto la forma de leer papers como el contenido de ese paper están muy buenos, y también recomiendo la serie Makemore de Karpathy

  • Los consejos resumidos son muy buenos, y creo que el de hacer assert de los shapes de los tensores aplica a cualquier biblioteca general de álgebra lineal. Al escribir código complejo de álgebra lineal, es muy importante avanzar en pasos pequeños y programar de forma defensiva
    Programar álgebra lineal en lenguajes mainstream es terrible porque no hay verificación de shapes en tiempo de compilación. El shape de un tensor debería formar parte del tipo, y si intentas multiplicar 3x4 por 3x4 sin transponer, ni siquiera debería compilar
    Es realmente lo peor correr un cálculo largo y que luego falle por una operación con dimensiones incompatibles
    También creo que los tensores de PyTorch deberían tener el dispositivo tipado estáticamente. Hoy, si intentas multiplicar un tensor en memoria de CPU por uno en memoria de GPU, aparece un error en tiempo de ejecución