2 puntos por GN⁺ 2024-09-24 | 1 comentarios | Compartir por WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (nuestro recorrido)

Introducción

  • A medida que los modelos de código abierto crecen, aumenta la necesidad de una infraestructura potente capaz de manejar entrenamiento de IA a gran escala
  • Felafax ajustó finamente el modelo LLaMA 3.1 405B en GPU de AMD, demostrando la eficiencia del hardware de AMD
  • Todo el trabajo fue publicado como código abierto en GitHub
  • Las GPU AMD MI300X ofrecen alto rendimiento en comparación con el hardware de IA de NVIDIA
  • El proyecto fue posible gracias al apoyo de TensorWave

Qué es JAX y por qué lo eligieron

  • JAX es una potente librería de machine learning que combina una API similar a NumPy, diferenciación automática y el compilador XLA de Google
  • Ofrece una excelente API para paralelismo de modelos, por lo que es ideal para entrenar modelos de gran escala

Ventajas de JAX

  • Funciones puras: JAX fomenta escribir funciones puras, lo que facilita estructurar, depurar y leer el código
  • Paralelismo avanzado: la flexible API de JIT de JAX soporta paralelismo avanzado de datos y de modelos, algo esencial para el entrenamiento a gran escala
  • Base de código limpia: la filosofía de diseño de JAX promueve escribir código portable entre distintas plataformas de hardware

Por qué JAX destaca en hardware que no es de NVIDIA

  • Enfoque independiente del hardware: JAX aprovecha el compilador XLA para compilar los cálculos a una representación intermedia independiente del hardware
  • Optimización independiente de la plataforma: el compilador XLA realiza optimizaciones sin depender del hardware
  • Portabilidad sencilla: al usar JAX, los cambios de código al pasar de NVIDIA a AMD son mínimos

Configuración de JAX en GPU de AMD

  • Se descarga la imagen de Docker, se inicia el contenedor y luego se verifica la instalación
  • Se entrenó el modelo LLaMA 405B usando 8 GPU AMD MI300x

Entrenamiento de LLaMA 405B: rendimiento y escalabilidad

  • Se entrenó el modelo LLaMA 405B en GPU de AMD usando JAX
  • Mediante ajuste fino con LoRA, se ajustaron los pesos del modelo y los parámetros de LoRA con precisión bfloat16
  • Tamaño del modelo: ocupa aproximadamente 800 GB de VRAM
  • Pesos de LoRA y estado del optimizador: ocupan aproximadamente 400 GB de VRAM
  • Uso total de VRAM: aproximadamente 1200 GB
  • Velocidad de entrenamiento: alrededor de 35 tokens por segundo
  • Eficiencia de memoria: se mantuvo en torno al 70%
  • Escalabilidad: usando JAX, escaló de forma casi lineal en 8 GPU

Nuestra configuración de entrenamiento

  • Se convirtió LLaMA 3.1 de PyTorch a JAX
  • Se distribuyó de forma eficiente mediante carga del modelo y sharding de parámetros

Sharding de parámetros en JAX

  • Se usó la función de device mesh de JAX para distribuir eficientemente el modelo en 8 GPU AMD
  • Se definieron reglas de sharding de parámetros para fragmentar cada dimensión de los tensores según los ejes de la malla

Implementación del entrenamiento con LoRA

  • LoRA reduce la cantidad de parámetros entrenables al descomponer las actualizaciones de pesos en matrices de bajo rango
  • Se implementó una capa LoRADense que incluye parámetros de LoRA
  • Los parámetros de LoRA se distribuyeron de forma eficiente para optimizar el uso de memoria y la eficiencia de cómputo

Conclusión

  • La experiencia de ajustar finamente el modelo LLaMA 3.1 405B con GPU de AMD y JAX fue muy positiva
  • Aprovechando las potentes capacidades de paralelismo de JAX y su enfoque independiente del hardware, el modelo se distribuyó de forma eficiente
  • Esto demuestra que las GPU de AMD son una alternativa sólida para el entrenamiento de IA a gran escala
  • Se puede revisar el código completo en el repositorio de GitHub y ejecutarlo directamente

Resumen de GN⁺

  • Este artículo explica cómo entrenar eficientemente modelos de IA de gran escala usando GPU de AMD y JAX
  • Destaca que el hardware de AMD es una alternativa rentable frente a NVIDIA
  • El enfoque independiente del hardware de JAX mejora la portabilidad del código y facilita el mantenimiento
  • Ofrece información útil y código práctico para quienes estén interesados en el entrenamiento de modelos a gran escala
  • Proyectos con capacidades similares incluyen CUDA y PyTorch de NVIDIA

1 comentarios

 
GN⁺ 2024-09-24
Opiniones de Hacker News
  • Hace poco hicimos fine-tuning del modelo llama3.1 405B en 8 GPU AMD MI300x usando JAX en lugar de PyTorch.
    Gracias a la API avanzada de sharding de JAX obtuvimos buen rendimiento, y resumimos en el blog la técnica de sharding que usamos. También publicamos el código: https://github.com/felafax/felafax
    Somos una startup pequeña que construye infraestructura de IA para fine-tuning y serving de LLM en hardware que no es NVIDIA (TPU, AMD, Trainium).
    Muchas empresas intentan correr PyTorch en GPU AMD, pero PyTorch está profundamente entrelazado con el ecosistema de NVIDIA, con cosas como torch.cuda o scaled_dot_product_attention, así que creemos que hace falta bastante “des-NVIDIAficación”.
    Creemos que JAX encaja mejor con hardware que no es NVIDIA, porque el código del modelo se compila a un grafo HLO independiente del hardware, luego el compilador XLA lo optimiza y después aplica optimizaciones específicas para cada hardware. El mismo código JAX de LLaMA3 funcionó sin cambios en Google TPU y en GPU AMD.
    La estrategia de la empresa es portar primero los modelos a JAX y luego aprovechar el framework JAX y los kernels XLA para extraer el máximo rendimiento en backends que no sean NVIDIA. Por eso primero pasamos Llama 3.1 de PyTorch a JAX, y el mismo modelo JAX funciona bien en TPU y GPU AMD.

    • No tuvimos mayores problemas para correr PyTorch en GPU AMD sin cambios en el código CUDA. También vale la pena ver el blog de MosaicML: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
    • Me da curiosidad cómo están validando la precisión del port a JAX de Llama 3.1.
      Personalmente, la razón principal por la que uso PyTorch es que el modelo original fue creado en PyTorch. Aunque la lógica parezca la misma entre distintas versiones del modelo, a escalas enormes de datos errores minúsculos de punto flotante pueden acumularse y producir deriva del modelo.
      Depurar estas discrepancias de precisión en modelos grandes fue casi más doloroso que el décimo círculo del infierno.
    • Me pregunto si JAX tiene implementaciones propias de multiplicación de matrices o FlashAttention, o si usa implementaciones de ROCm como PyTorch. Por ejemplo, cosas como hipblaslt o Composable Kernel FA.
      No conozco muy bien JAX, pero creo que una buena parte de la razón por la que el rendimiento de entrenamiento de PyTorch en MI300x es desastroso es la lentitud de las bibliotecas ROCm que usa internamente.
    • Me pregunto si también funciona en tarjetas de consumo como la 7900 XTX.
      Y con “funciona” no me refiero a pasar dos semanas peleando con los drivers y luego quedar en un estado en el que nunca más puedas actualizar el servidor.
    • Si se trata de una migración, me gustaría saber si tienen números reales comparados con la versión en PyTorch del mismo modelo. La tabla comparativa del artículo parece más bien de aspectos técnicos.
      También me dan curiosidad los problemas técnicos que encontraron.
  • Para ser claros, este rendimiento es bastante malo. Probablemente se deba a que no lograron que la compilación funcionara correctamente.
    En el modelo 405B obtienen 35 tokens/s, lo que equivale a unos 85 teraflops. Ocho GPU MI300x están en el orden de 10,4 petaflops, así que el MFU es de alrededor de 0,8%.
    Es una cifra entre 40 y 50 veces menor que un rendimiento de entrenamiento decente de 30 a 40% de MFU, así que AMD seguramente esperaría que el cuello de botella sea el stack de software.

    • Yo también quería preguntar exactamente eso.
      En la página de GitHub dicen que “se puede ajustar LLaMa3.1 en Google Cloud TPU con un 30% menos de costo”, pero no mencionan el rendimiento.
  • Excelente trabajo. Hace alrededor de un año estuve probando un poco GPU AMD y soporte ROCm, y quedó claro que AMD todavía tiene mucho camino por recorrer para alcanzar a Nvidia.
    El enfoque de elegir JAX es interesante; me pregunto qué dificultades encontraron al alejarse de PyTorch, que es casi la biblioteca estándar de machine learning.

    • Hace unas semanas publicamos un Show HN explicando nuestro recorrido: https://news.ycombinator.com/item?id=41512142
      Al principio el objetivo era hacer fine-tuning de LLaMA 3 en TPU, pero PyTorch XLA era tosco, así que decidimos reescribir el modelo en JAX.
      Como dije antes, vemos a JAX como una mejor plataforma para GPU que no son NVIDIA, y queremos construir infraestructura para GPU no NVIDIA sobre JAX+openXLA.
    • No he logrado hacer funcionar AMD ROCm en mi sistema Debian 12, así que parece que Ollama usa la CPU en vez de la GPU. Parece que todavía falta mucho.
  • Buen trabajo. El fin de semana pasado yo también estuve probando la parte de inferencia de 405B [0].
    No estoy tan convencido de que torch.cuda sea tan malo, porque PyTorch para AMD lo traduce por debajo. Me parece más un problema de nombre que un problema fundamental.
    De hecho, traer el contenedor rocm:pytorch es tan fácil como traer el contenedor rocm:jax.
    No hay muchos números publicados; me da curiosidad qué MFU obtuvieron.
    [0] https://x.com/HotAisle/status/1837580046732874026

    • Bien.
      Hay que calcular el MFU. Los detalles de GPU y VRAM se pueden ver en el repositorio: https://dub.sh/amd-405b-res
      El próximo fin de semana planeo volver a intentar la ejecución de entrenamiento, compilar con JIT todo el paso de entrenamiento y calcular el MFU entonces.
  • Cuando lo medimos en ZML, MI300X fue 30% más rápida que H100. Son chips excelentes.

  • Me pregunto si hay algún proveedor de nube donde se pueda alquilar un host 8xAMD MI300.
    En el trabajo usamos mucho AWS y quería probar alguna vez una GPU AMD.

    • Como referencia, nuestra empresa alquila 8xMI300x, así que puedes contactarnos.
    • Oracle las ofrece. Es muy probable que otros también sigan, pero creo que tendrá más sentido tratar con proveedores más pequeños.
  • ¿Dónde están los datos de rendimiento?

    • Agregué datos de uso de GPU y VRAM al repositorio de GitHub: https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...
      Por el código y las restricciones de VRAM, no pudimos ejecutar la versión compilada con JIT del modelo 405B. Hay que investigar más esa parte.
      Toda la ejecución de entrenamiento se hizo en modo eager de JAX, así que hay bastante margen de mejora de rendimiento.
      Incluso en modo eager, el uso de GPU fue en general de alrededor de 30 a 40%, lo cual está bastante bien. Creo que con JIT el uso de GPU podría subir fácilmente a 50 o 60%.
  • Si es posible, sería interesante explorar cómo superar las restricciones de memoria y ejecutar la versión compilada con JIT. Podría llevar a mejoras adicionales de rendimiento.

    • De acuerdo. Todavía queda mucho rendimiento por exprimir.
      Necesitamos un paso de entrenamiento compilado con JIT, carga de datos y sharding más optimizados, acumulación de gradientes y activation checkpointing.
      Seguiremos construyendo e implementaremos todas las mejoras, y pronto volveremos a publicar un blog.
  • Me pregunto si AMD está siquiera un poco más cerca de extraer valor de esto mediante pedidos masivos de GPU y escasez de suministro.
    Mi impresión es más bien que “no”.

    • Entiendo el sarcasmo. Pero si en este momento no quieres dejar todo el hardware y software de IA en manos de un único proveedor, hay que empezar a moverse hacia alternativas.
      El rival tiene una ventaja inicial enorme, y claramente hay mucho trabajo por hacer en el lado del software. Llevará tiempo.
  • ¿Por qué Obsidian, la app de notas, está haciendo esto?

    • No es eso. Esta empresa está usando Obsidian Publish para publicar la documentación.