Felafax BlogTune Llama3 405B on AMD MI300x (nuestro recorrido)
Introducción
- A medida que los modelos de código abierto crecen, aumenta la necesidad de una infraestructura potente capaz de manejar entrenamiento de IA a gran escala
- Felafax ajustó finamente el modelo LLaMA 3.1 405B en GPU de AMD, demostrando la eficiencia del hardware de AMD
- Todo el trabajo fue publicado como código abierto en GitHub
- Las GPU AMD MI300X ofrecen alto rendimiento en comparación con el hardware de IA de NVIDIA
- El proyecto fue posible gracias al apoyo de TensorWave
Qué es JAX y por qué lo eligieron
- JAX es una potente librería de machine learning que combina una API similar a NumPy, diferenciación automática y el compilador XLA de Google
- Ofrece una excelente API para paralelismo de modelos, por lo que es ideal para entrenar modelos de gran escala
Ventajas de JAX
- Funciones puras: JAX fomenta escribir funciones puras, lo que facilita estructurar, depurar y leer el código
- Paralelismo avanzado: la flexible API de JIT de JAX soporta paralelismo avanzado de datos y de modelos, algo esencial para el entrenamiento a gran escala
- Base de código limpia: la filosofía de diseño de JAX promueve escribir código portable entre distintas plataformas de hardware
Por qué JAX destaca en hardware que no es de NVIDIA
- Enfoque independiente del hardware: JAX aprovecha el compilador XLA para compilar los cálculos a una representación intermedia independiente del hardware
- Optimización independiente de la plataforma: el compilador XLA realiza optimizaciones sin depender del hardware
- Portabilidad sencilla: al usar JAX, los cambios de código al pasar de NVIDIA a AMD son mínimos
Configuración de JAX en GPU de AMD
- Se descarga la imagen de Docker, se inicia el contenedor y luego se verifica la instalación
- Se entrenó el modelo LLaMA 405B usando 8 GPU AMD MI300x
Entrenamiento de LLaMA 405B: rendimiento y escalabilidad
- Se entrenó el modelo LLaMA 405B en GPU de AMD usando JAX
- Mediante ajuste fino con LoRA, se ajustaron los pesos del modelo y los parámetros de LoRA con precisión bfloat16
- Tamaño del modelo: ocupa aproximadamente 800 GB de VRAM
- Pesos de LoRA y estado del optimizador: ocupan aproximadamente 400 GB de VRAM
- Uso total de VRAM: aproximadamente 1200 GB
- Velocidad de entrenamiento: alrededor de 35 tokens por segundo
- Eficiencia de memoria: se mantuvo en torno al 70%
- Escalabilidad: usando JAX, escaló de forma casi lineal en 8 GPU
Nuestra configuración de entrenamiento
- Se convirtió LLaMA 3.1 de PyTorch a JAX
- Se distribuyó de forma eficiente mediante carga del modelo y sharding de parámetros
Sharding de parámetros en JAX
- Se usó la función de device mesh de JAX para distribuir eficientemente el modelo en 8 GPU AMD
- Se definieron reglas de sharding de parámetros para fragmentar cada dimensión de los tensores según los ejes de la malla
Implementación del entrenamiento con LoRA
- LoRA reduce la cantidad de parámetros entrenables al descomponer las actualizaciones de pesos en matrices de bajo rango
- Se implementó una capa
LoRADense que incluye parámetros de LoRA
- Los parámetros de LoRA se distribuyeron de forma eficiente para optimizar el uso de memoria y la eficiencia de cómputo
Conclusión
- La experiencia de ajustar finamente el modelo LLaMA 3.1 405B con GPU de AMD y JAX fue muy positiva
- Aprovechando las potentes capacidades de paralelismo de JAX y su enfoque independiente del hardware, el modelo se distribuyó de forma eficiente
- Esto demuestra que las GPU de AMD son una alternativa sólida para el entrenamiento de IA a gran escala
- Se puede revisar el código completo en el repositorio de GitHub y ejecutarlo directamente
Resumen de GN⁺
- Este artículo explica cómo entrenar eficientemente modelos de IA de gran escala usando GPU de AMD y JAX
- Destaca que el hardware de AMD es una alternativa rentable frente a NVIDIA
- El enfoque independiente del hardware de JAX mejora la portabilidad del código y facilita el mantenimiento
- Ofrece información útil y código práctico para quienes estén interesados en el entrenamiento de modelos a gran escala
- Proyectos con capacidades similares incluyen CUDA y PyTorch de NVIDIA
1 comentarios
Opiniones en Hacker News
Comparten resultados de haber ajustado finamente el modelo Llama3.1 405B usando JAX en 8 GPU AMD MI300x
Sugieren explorar cómo superar las limitaciones de memoria y ejecutar una versión compilada con JIT
Comparten experiencia sobre GPU AMD y el soporte de ROCm
Comparten experiencia experimentando con el modelo 405B desde el lado de inferencia
torch.cudano está tan malrocm:pytorches tan fácil como usar el contenedorrocm:jaxPreguntan por la ausencia de datos de rendimiento
Dudan de por qué Obsidian (la app para tomar notas) está haciendo esto
Piden a @dang que incluya el nombre de usuario en la URL