Cómo escalar tu modelo: una perspectiva de sistemas sobre los LLM en TPU
(jax-ml.github.io)- Optimizar el rendimiento de deep learning a gran escala puede parecer una “alquimia”, pero en realidad es posible mejorar la eficiencia del modelo con principios simples y comprensibles.
- Desde un solo acelerador hasta decenas de miles de aceleradores, se aplican principios relativamente simples en todos los casos, y entenderlos permite realizar tareas útiles como las siguientes:
- Estimar aproximadamente qué tan cerca está cada parte del modelo de su valor óptimo teórico
- Tener una base para elegir entre varias técnicas de paralelización a distintas escalas
- Estimar el costo y el tiempo necesarios para entrenar y ejecutar modelos Transformer grandes
- Diseñar algoritmos que aprovechen las características de un hardware específico
- Diseñar hardware con una comprensión clara de los límites del rendimiento de los algoritmos actuales
- Conocimientos previos necesarios
- Se requiere una comprensión básica de los LLM y de la arquitectura Transformer
- No es indispensable entender cómo operan los sistemas a gran escala
- Ayuda tener conocimientos básicos de entrenamiento de LLM y experiencia usando JAX
- Se recomienda consultar una publicación de blog sobre la arquitectura Transformer y diapositivas sobre escalado de LLM en JAX
- Objetivos
- Desarrollar la capacidad de estimar cómo conviene paralelizar un modelo en un hardware dado
- Desarrollar la capacidad de calcular de forma aproximada el tiempo y el costo del entrenamiento y la inferencia
Por qué debería importarte
- Hasta hace 3 o 4 años, la mayoría de los investigadores de ML no necesitaban conocer a fondo este tipo de optimización a gran escala.
- Hoy, incluso los modelos “pequeños” operan cerca de los límites del hardware, así que entender formas eficientes de trabajo a gran escala se volvió esencial
- La historia del ML puede verse como una evolución entrelazada de innovación en sistemas y mejoras de software
- Como los modelos Transformer recientes aprovechan el hardware hasta sus límites, si no se entiende la eficiencia del modelo, es muy probable que una nueva arquitectura o línea de investigación fracase en su aplicación real
- Incluso si se obtiene una mejora del 20% en benchmarks, si la eficiencia de hardware cae un 20%, al final su utilidad práctica será baja
- El objetivo central del escalado de modelos es lograr que el throughput aumente linealmente al incrementar la cantidad de chips (aceleradores).
- A esto se le llama "strong scaling"
- Agregar chips reduce el tiempo de cómputo, pero genera costos de comunicación entre chips
- Si la comunicación tarda más que el cómputo, se entra en un estado "communication bound" y el strong scaling deja de ser posible
- Si se entiende suficientemente bien el hardware como para predecir dónde aparecerán estos cuellos de botella, se puede diseñar o reconfigurar el modelo para evitarlos
- El objetivo de este libro es explicar cómo funciona el hardware TPU (y GPU) y cómo la arquitectura Transformer ha evolucionado para funcionar bien en el hardware actual
- Se espera que sea útil tanto para investigadores que diseñan nuevas arquitecturas como para ingenieros que buscan ejecutar rápidamente los LLM de la generación actual
Panorama general
- Este texto está organizado de la siguiente manera.
- La sección 1 explica, mediante análisis roofline, los factores que determinan el límite de rendimiento del modelo (comunicación, cómputo y memoria).
- Las secciones 2 y 3 abordan la estructura interna de TPU y GPU, así como la forma en que se conectan entre chips.
- Con esto se responden preguntas como:
- ¿Qué tan rápido puede ejecutarse teóricamente una multiplicación de matrices de cierto tamaño?
- ¿En qué punto el cómputo queda limitado por el ancho de banda de memoria o de comunicación?
- ¿Cómo está conectado un clúster de TPU y cuánto tiempo toma aproximadamente mover datos de un chip a otro?
- ¿Cómo se pueden multiplicar matrices distribuidas de forma eficiente?
- Con esto se responden preguntas como:
- La sección 4 trata en detalle las fórmulas de la arquitectura Transformer (tamaños de matrices, número de parámetros, FLOPs).
- La sección 5 y la sección 7 son el núcleo, e introducen varias formas de paralelizar modelos en múltiples chips.
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- También cubren técnicas de ahorro de memoria como ZeRO, Rematerialisation, Host offload y Gradient accumulation
- La sección 6 y la sección 8 presentan, con el ejemplo del entrenamiento e inferencia de un modelo LLaMA-3 en TPU, los costos, tiempos y configuraciones reales.
- Por último, la sección 9 y la sección 10 cubren métodos prácticos para perfilar, depurar y aplicar procesamiento paralelo a modelos en JAX.
Detalles: resumen de las principales secciones del libro
-
Parte 1: Preliminaries
-
Sección 1: Introducción a un análisis Roofline simple
- Los tres factores que restringen un algoritmo: cómputo, comunicación y memoria
- A partir de eso, aprender cómo estimar el límite superior de la velocidad de cómputo
-
Sección 2: Una forma de ver las TPU
- Cómo realizan cómputo las TPU
- Qué es una estructura de systolic array
- Comprensión básica de cómo las TPU ofrecen ancho de banda de memoria y comunicación
-
Sección 3: Matrices distribuidas y multiplicación distribuida
- La técnica de almacenar los parámetros del modelo repartidos entre varios chips (Sharding)
- Cómo tratar la comunicación y los cuellos de botella que surgen al operar con matrices distribuidas
-
-
Parte 2: Transformers
-
Sección 4: Resumen de las fórmulas de Transformer necesarias
- Qué forma toman concretamente las multiplicaciones de matrices en Transformer
- Cómo calcular el número de parámetros, los FLOPs, el tamaño de la caché KV y otros elementos
- Entender cuántas operaciones requiere Attention en comparación con los bloques Feed-Forward
-
Sección 5: Estrategias de paralelización para entrenar Transformers
- Introducción a técnicas como Data parallel, Tensor parallel, Pipeline parallel y Expert parallel
- Opciones de ahorro de memoria como ZeRO(FSDP), Rematerialisation, Gradient accumulation y Host offload
- Establecer conceptos para configurar la paralelización según el tamaño del modelo y la cantidad de chips
-
Sección 6: Aplicación al entrenamiento de LLaMA 3 en TPU
- Estimación del tiempo y costo requeridos asumiendo que se entrena un modelo LLaMA 3 en un entorno TPU real
- Presenta ejemplos concretos sobre batch size, forma de paralelización, uso de memoria, etc.
-
Sección 7: Todo sobre la inferencia de Transformer
- En inferencia aparece un nuevo factor importante: la latencia
- Uso de memoria y problemas de comunicación causados por la caché KV y otros elementos
- Discusión sobre cómo asignar y conectar varios chips para servir el modelo
-
Sección 8: Aplicación al serving de LLaMA 3 en TPU
- Análisis aproximado de costos, latencia y trade-offs de throughput asumiendo serving de LLaMA 3 en TPU v5e
-
-
Parte 3: Practical Tutorials
-
Sección 9: Cómo perfilar código en TPU
- Comprender el stack JAX+XLA
- Identificar problemas reales de degradación de rendimiento y sus soluciones
- Cómo usar el profiler de JAX/TensorBoard
-
Sección 10: Programar TPU con JAX
- Cómo usar las API(primitives) de paralelización de JAX
- Aprender conceptos de cómputo paralelo mediante ejemplos y ejercicios
-
Sección 11: Conclusión y recursos adicionales
- Lecturas adicionales sobre TPU y LLM
- Cierre breve del contenido general y mención de perspectivas futuras
-
1 comentarios
Comentarios de Hacker News