-
La importancia de Attention
- Attention es la capa central de la arquitectura Transformer y genera cuellos de botella en los modelos de lenguaje grandes y en aplicaciones de contexto largo.
- FlashAttention y FlashAttention-2 fueron pioneros en un enfoque para acelerar Attention en GPU minimizando las lecturas y escrituras de memoria.
- Gracias a esto, la longitud de contexto de los LLM aumentó de forma considerable.
-
Tecnologías principales de FlashAttention-3
- Aprovechamiento de la asincronía: aprovecha la asincronía de Tensor Cores y TMA para superponer todo el cálculo y el movimiento de datos.
- Operaciones por bloques: intercala la multiplicación de matrices por bloques y las operaciones de softmax.
- Procesamiento de baja precisión: mejora el rendimiento aprovechando el soporte de baja precisión FP8.
-
Mejoras de rendimiento de FlashAttention-3
- Eficiencia en el uso de la GPU: aprovecha hasta el 75% del rendimiento máximo de la GPU H100 y es entre 1.5 y 2 veces más rápido que la versión anterior.
- Rendimiento con baja precisión: usa FP8 para aumentar la velocidad de procesamiento y reducir el uso de memoria.
- Procesamiento de contexto largo: acelera el mecanismo de Attention para procesar textos más largos de forma eficiente.
-
Resumen de FlashAttention
- FlashAttention reorganiza el cálculo de Attention y aprovecha el tiling y la recomputación para aumentar mucho la velocidad y reducir el uso de memoria.
- Mediante tiling, carga bloques de entrada, ejecuta Attention sobre esos bloques y luego actualiza la salida.
- Reduce la cantidad de lecturas y escrituras de memoria al no escribir en memoria la matriz intermedia de Attention.
-
Nuevas funciones de hardware de las GPU Hopper
- WGMMA: ofrece alto throughput aprovechando los nuevos Tensor Cores.
- TMA: unidad de hardware que acelera la transferencia de datos entre la memoria global y la memoria compartida.
- FP8 de baja precisión: usa FP8 para duplicar el throughput de Tensor Core.
-
Asincronía: superposición de GEMM y Softmax
- Necesidad de la superposición: ejecuta GEMM y softmax en paralelo para maximizar el rendimiento.
- Programación ping-pong: dos grupos de warps alternan entre GEMM y softmax para mejorar el rendimiento.
- Superposición dentro del grupo de warps: ejecuta GEMM y softmax en paralelo dentro del mismo grupo de warps para aumentar el throughput.
-
Baja precisión: reducción del error de cuantización con procesamiento incoherente
- Procesamiento incoherente: usa la transformada de Hadamard para reducir el error de cuantización.
- Resultados experimentales: el procesamiento incoherente reduce el error de cuantización en 2.6 veces.
-
Benchmark de Attention
- FP16: alrededor de 1.6 a 1.8 veces más rápido que FlashAttention-2.
- FP8: alcanza hasta 1.2 PFLOPS.
Resumen de GN⁺
- FlashAttention-3 mejora de forma importante el rendimiento del mecanismo de Attention al aprovechar nuevas funciones de hardware de la GPU.
- Puede procesar contexto largo de manera eficiente, maximizando el rendimiento de los modelos de lenguaje grandes.
- Es muy probable que se integre en frameworks principales como PyTorch, por lo que tendrá un gran impacto en la investigación y las aplicaciones de IA.
- Entre los proyectos con funciones similares están Triton y cuDNN.
1 comentarios
Opiniones en Hacker News
Parece que Tri Dao empezó a trabajar en FA3 desde abril de 2022
Me pregunto qué tanto depende el algoritmo Flash Attention del hardware
Me pregunto si un compilador podría encontrar por sí solo optimizaciones como FlashAttention
Piden que quien quiera portarlo a ROCm/AMD MI300x se ponga en contacto
TMA (Tensor Memory Accelerator) es una unidad de hardware que acelera la transferencia de datos entre memoria global y memoria compartida
FlashAttention-3 está optimizado para GPUs Hopper (por ejemplo, H100)
Se menciona que funciones de activación como sigmoid son muy lentas en los LLM modernos
Me pregunto por qué Flash Attention es 5 veces más lento con enmascaramiento variable que sin él
Me pregunto si FlashAttention puede reemplazar la operación de attention en los LLM
Se necesita hardware costoso