2 puntos por GN⁺ 2024-09-24 | 1 comentarios | Compartir por WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (nuestro recorrido)

Introducción

  • A medida que los modelos de código abierto crecen, aumenta la necesidad de una infraestructura potente capaz de manejar entrenamiento de IA a gran escala
  • Felafax ajustó finamente el modelo LLaMA 3.1 405B en GPU de AMD, demostrando la eficiencia del hardware de AMD
  • Todo el trabajo fue publicado como código abierto en GitHub
  • Las GPU AMD MI300X ofrecen alto rendimiento en comparación con el hardware de IA de NVIDIA
  • El proyecto fue posible gracias al apoyo de TensorWave

Qué es JAX y por qué lo eligieron

  • JAX es una potente librería de machine learning que combina una API similar a NumPy, diferenciación automática y el compilador XLA de Google
  • Ofrece una excelente API para paralelismo de modelos, por lo que es ideal para entrenar modelos de gran escala

Ventajas de JAX

  • Funciones puras: JAX fomenta escribir funciones puras, lo que facilita estructurar, depurar y leer el código
  • Paralelismo avanzado: la flexible API de JIT de JAX soporta paralelismo avanzado de datos y de modelos, algo esencial para el entrenamiento a gran escala
  • Base de código limpia: la filosofía de diseño de JAX promueve escribir código portable entre distintas plataformas de hardware

Por qué JAX destaca en hardware que no es de NVIDIA

  • Enfoque independiente del hardware: JAX aprovecha el compilador XLA para compilar los cálculos a una representación intermedia independiente del hardware
  • Optimización independiente de la plataforma: el compilador XLA realiza optimizaciones sin depender del hardware
  • Portabilidad sencilla: al usar JAX, los cambios de código al pasar de NVIDIA a AMD son mínimos

Configuración de JAX en GPU de AMD

  • Se descarga la imagen de Docker, se inicia el contenedor y luego se verifica la instalación
  • Se entrenó el modelo LLaMA 405B usando 8 GPU AMD MI300x

Entrenamiento de LLaMA 405B: rendimiento y escalabilidad

  • Se entrenó el modelo LLaMA 405B en GPU de AMD usando JAX
  • Mediante ajuste fino con LoRA, se ajustaron los pesos del modelo y los parámetros de LoRA con precisión bfloat16
  • Tamaño del modelo: ocupa aproximadamente 800 GB de VRAM
  • Pesos de LoRA y estado del optimizador: ocupan aproximadamente 400 GB de VRAM
  • Uso total de VRAM: aproximadamente 1200 GB
  • Velocidad de entrenamiento: alrededor de 35 tokens por segundo
  • Eficiencia de memoria: se mantuvo en torno al 70%
  • Escalabilidad: usando JAX, escaló de forma casi lineal en 8 GPU

Nuestra configuración de entrenamiento

  • Se convirtió LLaMA 3.1 de PyTorch a JAX
  • Se distribuyó de forma eficiente mediante carga del modelo y sharding de parámetros

Sharding de parámetros en JAX

  • Se usó la función de device mesh de JAX para distribuir eficientemente el modelo en 8 GPU AMD
  • Se definieron reglas de sharding de parámetros para fragmentar cada dimensión de los tensores según los ejes de la malla

Implementación del entrenamiento con LoRA

  • LoRA reduce la cantidad de parámetros entrenables al descomponer las actualizaciones de pesos en matrices de bajo rango
  • Se implementó una capa LoRADense que incluye parámetros de LoRA
  • Los parámetros de LoRA se distribuyeron de forma eficiente para optimizar el uso de memoria y la eficiencia de cómputo

Conclusión

  • La experiencia de ajustar finamente el modelo LLaMA 3.1 405B con GPU de AMD y JAX fue muy positiva
  • Aprovechando las potentes capacidades de paralelismo de JAX y su enfoque independiente del hardware, el modelo se distribuyó de forma eficiente
  • Esto demuestra que las GPU de AMD son una alternativa sólida para el entrenamiento de IA a gran escala
  • Se puede revisar el código completo en el repositorio de GitHub y ejecutarlo directamente

Resumen de GN⁺

  • Este artículo explica cómo entrenar eficientemente modelos de IA de gran escala usando GPU de AMD y JAX
  • Destaca que el hardware de AMD es una alternativa rentable frente a NVIDIA
  • El enfoque independiente del hardware de JAX mejora la portabilidad del código y facilita el mantenimiento
  • Ofrece información útil y código práctico para quienes estén interesados en el entrenamiento de modelos a gran escala
  • Proyectos con capacidades similares incluyen CUDA y PyTorch de NVIDIA

1 comentarios

 
GN⁺ 2024-09-24
Opiniones en Hacker News
  • Comparten resultados de haber ajustado finamente el modelo Llama3.1 405B usando JAX en 8 GPU AMD MI300x

    • Lograron un rendimiento sobresaliente gracias a la API avanzada de sharding de JAX
    • Comparten enlaces al post del blog y al código open source: enlace de GitHub
    • Son una startup que construye infraestructura de IA para ajustar finamente y servir LLM en TPU, AMD y Trainium, en lugar de hardware de NVIDIA
    • Consideran que muchas empresas intentan hacer funcionar PyTorch sobre GPU AMD, pero que ese es un camino difícil
    • PyTorch está profundamente ligado al ecosistema de NVIDIA, por lo que hacerlo funcionar en hardware que no sea de NVIDIA requiere muchas modificaciones
    • Creen que JAX es más adecuado para hardware que no sea de NVIDIA
    • En JAX, el código del modelo de ML se compila en un grafo HLO independiente del hardware, y el compilador XLA realiza optimizaciones específicas para cada hardware
    • El mismo código de JAX puede ejecutarse en Google TPU y en GPU AMD sin cambios
    • La estrategia de la empresa es portar los modelos a JAX y aprovechar los kernels de XLA para extraer el máximo rendimiento en backends que no sean de NVIDIA
    • Portaron Llama 3.1 por primera vez de PyTorch a JAX, y ahora ese mismo modelo en JAX funciona bien tanto en TPU como en GPU AMD
    • Les gustaría conocer opiniones sobre la visión y el repositorio
  • Sugieren explorar cómo superar las limitaciones de memoria y ejecutar una versión compilada con JIT

    • Podría aportar mejoras adicionales de rendimiento
  • Comparten experiencia sobre GPU AMD y el soporte de ROCm

    • Hace un año intentaron usar GPU AMD y soporte ROCm, pero sintieron que AMD todavía está lejos de alcanzar a NVIDIA
    • Elegir JAX parece un enfoque interesante, pero se preguntan qué dificultades hubo al alejarse de PyTorch
  • Comparten experiencia experimentando con el modelo 405B desde el lado de inferencia

    • Piensan que torch.cuda no está tan mal
    • Consideran que es solo un tema de nombre, ya que la versión de PyTorch para AMD lo traduce
    • Usar el contenedor rocm:pytorch es tan fácil como usar el contenedor rocm:jax
    • Señalan que no se han publicado muchos datos de rendimiento
    • Tienen curiosidad por las cifras de MFU (utilización del modelo)
  • Preguntan por la ausencia de datos de rendimiento

    • Cuestionan la posibilidad de extraer valor debido a los pedidos masivos de GPU AMD
    • Les queda la impresión de que la respuesta es "no"
  • Dudan de por qué Obsidian (la app para tomar notas) está haciendo esto

    • Al principio pensaron que era una publicación de Obsidian
    • Se preguntan por qué aún no distinguen entre GitHub.com y GitHub.io
  • Piden a @dang que incluya el nombre de usuario en la URL

    • Esta publicación trata de un blog generado por un usuario, no de Obsidian en sí