DeepGEMM: kernel GEMM FP8 limpio y eficiente con escalado granular
(github.com/deepseek-ai)DeepGEMM
DeepGEMM es una biblioteca para multiplicación general de matrices FP8 (GEMM) que admite el escalado granular propuesto en DeepSeek-V3. Esta biblioteca soporta GEMM agrupado general y para Mix-of-Experts (MoE), está escrita en CUDA y no requiere compilación durante la instalación. Soporta los tensor cores NVIDIA Hopper y usa una acumulación en dos etapas con CUDA cores para resolver la imprecisión de la acumulación FP8 en tensor cores. Aprovecha parcialmente conceptos de CUTLASS y CuTe, pero mantiene la simplicidad al minimizar la dependencia de plantillas o álgebra. Con una sola función de kernel central de unas 300 líneas de código, es un buen recurso para aprender sobre multiplicación de matrices FP8 en Hopper y técnicas de optimización. A pesar de su diseño ligero, iguala o supera el rendimiento de bibliotecas ajustadas por expertos en diversas formas de matrices.
Rendimiento
Se probaron en H800 SXM5 con NVCC 12.8 todas las formas que pueden usarse en la inferencia de DeepSeek-V3/R1. Todas las métricas de mejora de velocidad se calcularon comparando con una implementación optimizada internamente basada en CUTLASS 3.6. Algunas formas pueden tener un rendimiento deficiente, y se agradecen PRs de optimización.
GEMM general (modelo denso)
- Las mediciones del rendimiento de DeepGEMM en varios tamaños de matriz muestran mejoras de velocidad de hasta 2.7x en tamaños específicos.
GEMM agrupado para modelos MoE (layout continuo)
- Dependiendo del número de grupos y del tamaño de matriz de cada grupo, muestra mejoras de velocidad de hasta 1.2x.
GEMM agrupado para modelos MoE (layout con máscara)
- Usando el layout con máscara, muestra mejoras de velocidad de hasta 1.2x.
Inicio rápido
Requisitos
- GPU con arquitectura Hopper, se requiere soporte para
sm_90a - Python 3.8 o superior
- CUDA 12.3 o superior (se recomienda 12.8 o superior para el mejor rendimiento)
- PyTorch 2.1 o superior
- CUTLASS 3.6 o superior
Desarrollo
- Describe el proceso de desarrollo, incluyendo clonar submódulos, crear enlaces simbólicos, compilación JIT y pruebas de todas las implementaciones GEMM.
Instalación
- Se puede importar y usar
deep_gemmen proyectos de Python.
Interfaz
Precauciones
- Esta biblioteca solo incluye kernels GEMM y solo soporta formato NT. Las operaciones de transposición u otras tareas de casting FP8 deben implementarse por separado.
GEMM denso general (no agrupado)
- Proporciona funciones para realizar GEMM FP8 básico no agrupado.
GEMM agrupado (layout continuo)
- Está diseñado para escenarios en modelos MoE donde los expertos comparten la misma forma.
GEMM agrupado (layout con máscara)
- En la etapa de decodificación de inferencia, proporciona un tensor de máscara para calcular solo las partes válidas.
Utilidades
- Proporciona varias funciones utilitarias y variables de entorno para ayudar a optimizar el rendimiento.
Optimización
Especialización persistente de warps
- Sigue el diseño de CUTLASS, superponiendo movimiento de datos, instrucciones MMA de tensor cores y promoción con CUDA cores.
Funcionalidad TMA de Hopper
- Aprovecha TMA para acelerar el movimiento de datos.
Optimizaciones detalladas comunes
- Mejora el rendimiento mediante diversas técnicas de optimización.
Scheduler de bloques unificado y optimizado
- Proporciona un scheduler para todos los kernels no agrupados y agrupados.
Diseño JIT completo
- Mejora el rendimiento mediante un diseño JIT que no requiere compilación durante la instalación.
Tamaños de bloque no alineados
- Soporta tamaños de bloque no alineados para maximizar la utilización de SM en ciertas formas.
Intercalado de FFMA SASS
- Mejora el paralelismo a nivel de warp modificando instrucciones FFMA para aumentar el rendimiento.
Agradecimientos
- DeepGEMM está inspirado en el proyecto CUTLASS, y expresa agradecimiento y respeto a sus desarrolladores.
Licencia
- Se publica bajo la licencia MIT.
Aún no hay comentarios.