Reimplementan Stable Diffusion 3.5 desde cero en PyTorch puro
(github.com/yousef-rafat)- El proyecto miniDiffusion es un proyecto open source que reimplementa desde cero el modelo Stable Diffusion 3.5 usando únicamente PyTorch
- La estructura de este proyecto se caracteriza por estar enfocada en fines educativos y en experimentación y hacking
- Todo el código base tiene alrededor de 2800 líneas y está compuesto con código mínimo, desde el VAE hasta el DiT, además de scripts de entrenamiento y de datasets
- Entre sus componentes principales se incluyen VAE, CLIP, codificadores de texto T5, transformer de difusión multimodal, joint attention y más
- Aún incluye funcionalidades experimentales, por lo que necesita más pruebas
Introducción al proyecto miniDiffusion
miniDiffusion es un proyecto open source que reimplementa las funciones principales de Stable Diffusion 3.5 usando solo PyTorch
Frente a Stable Diffusion 3.5 existente, este proyecto tiene las siguientes ventajas
- El código base tiene unas 2,800 líneas, por lo que es pequeño y muy adecuado para analizar su estructura directamente y aprender de ella
- Puede aprovecharse de forma muy útil para diversos experimentos de aprendizaje automático y hacking de modelos
- Tiene muy pocas dependencias y usa solo un conjunto mínimo de librerías
Estructura principal y archivos de configuración
- dit.py : implementación principal del modelo Stable Diffusion
- dit_components.py : embeddings, normalización, patch embedding y funciones auxiliares de DiT
- attention.py : implementación del algoritmo de Joint Attention
- noise.py : incluye el scheduler Euler ODE para Rectified Flow
- t5_encoder.py, clip.py : implementación de los codificadores de texto T5 y CLIP
- tokenizer.py : implementación de tokenizadores Byte-Pair y Unigram
- metrics.py : implementación de la métrica de evaluación FID (Fréchet inception distance)
- common.py : proporciona funciones auxiliares necesarias para el entrenamiento
- common_ds.py : implementación de un dataset iterable que convierte imágenes en datos de entrenamiento para DiT
- carpeta model : guarda checkpoints del modelo y logs después del entrenamiento
- carpeta encoders : guarda checkpoints de módulos separados como VAE y CLIP
⚠️ Funciones experimentales y necesidad de pruebas miniDiffusion aún incluye funciones experimentales y requiere más pruebas
Desglose detallado por funcionalidad principal
Core Image Generation Modules
- Implementación de VAE, CLIP y codificadores de texto T5
- Implementación de tokenizadores Byte-Pair y Unigram
Componentes de SD3
- Modelo Transformer de Difusión Multimodal
- Implementación de Flow-Matching Euler Scheduler
- Logit-Normal Sampling
- Incorporación del algoritmo de Joint Attention
Scripts de entrenamiento e inferencia del modelo
- Incluye scripts de entrenamiento e inferencia para SD3 (Stable Diffusion 3.5)
Licencia
- Se publica bajo la licencia MIT y fue creado con fines educativos y experimentales
Significado y ventajas de este proyecto open source
- Permite entrenar y hacer hacking directamente sobre una arquitectura de modelo moderno de generación de imágenes al nivel de Stable Diffusion 3.5 usando PyTorch puro
- Su código es conciso e independiente, por lo que está optimizado para análisis de arquitectura / ajuste de modelos / investigación de nuevos algoritmos
- Permite practicar directamente técnicas modernas de multimodalidad, transformers y attention
- Proporciona una base para experimentar con seguridad al margen de proyectos comerciales
1 comentarios
Comentarios de Hacker News
La implementación de referencia de Flux tiene una estructura realmente minimalista, así que si a alguien le interesa vale la pena echarle un vistazo
GitHub de Flux
El proyecto minRF tiene la ventaja de que permite empezar fácilmente a entrenar modelos pequeños de difusión usando rectified flow
GitHub de minRF
La implementación de referencia de Stable Diffusion 3.5 también está escrita de forma bastante concisa, así que sirve bien como referencia
GitHub de SD 3.5
Las implementaciones de referencia muchas veces no están bien mantenidas y suelen tener bastantes bugs
cudagraphsy similaresMe surgió la duda de si el proyecto miniDiffusion significa que usa el modelo Stable Diffusion 3.5
Código relacionado
El dataset de entrenamiento es muy pequeño e incluye solo fotos relacionadas con moda
Dataset de moda
Ese dataset es para practicar el fine-tuning de un modelo de difusión
Me pregunto si usar PyTorch puro trae ventajas de rendimiento en GPUs que no sean de NVIDIA, o si PyTorch está tan optimizado para CUDA que otros fabricantes de GPU no pueden competir
PyTorch funciona bastante bien también en Apple Silicon
También es posible correr cargas de trabajo de ML en dispositivos no NVIDIA como AMD a través de Vulkan
El soporte de PyTorch para ROCm avanza muy lentamente, y aun cuando logras hacerlo funcionar, es lento
PyTorch sí funciona bien en ROCm, pero no sé si llega a funcionar al nivel de ser realmente "equivalente"
En el código de PyTorch, en vez de
sugieren que estaría bien probar algo como
Parece un buen material para gente que está aprendiendo
Me pregunto si hay algún tutorial o guía que también pueda seguir un principiante
En fast.ai hay una clase donde implementan Stable Diffusion directamente
Me surge la duda de si esto significa que se puede usar Stable Diffusion sin restricciones de licencia
La verdad, aunque me da un poco de pena, me pregunto qué es lo nuevo que ganamos al comparar el antes y el después de que existieran estos repositorios
Personalmente he evitado meterme a crear modelos y más bien he visto los resultados desde fuera
Ya daba por hecho que antes también existían scripts públicos de inferencia/entrenamiento basados en PyTorch
Al menos pensaba que los scripts de inferencia venían junto con la distribución del modelo, y que también habría scripts de fine-tuning/entrenamiento
No me queda claro si este proyecto es una reescritura tipo "clean room" o "dirty room", o si incluso el código PyTorch existente era tan complejo por todo lo CUDA/C que una versión en PyTorch puro sí tiene mucho valor
En cualquier caso, no lo tengo claro y agradecería si alguien lo puede explicar
El valor principal de este proyecto es que es una "implementación con dependencias mínimas"
transformers, así que en trabajo real esto sí resulta bastante problemáticoStability AI distribuye los modelos Stable Diffusion bajo la Stability AI Community License, que a diferencia de MIT no es "totalmente libre"
Cuando pienso en SD 3.5, o en cualquier versión, veo como lo más importante la parte de los pesos generados durante el entrenamiento
Me pregunto qué tan utilizable en la práctica es el código fuente académico original que publicó el grupo CompViz de Ludwig Maximilian University
Me pregunto si la implementación de diffusion transformer (DiT) aquí implementa correctamente la cross-token attention como en la versión completa de SD 3.5, o si la simplificaron para hacer el código más legible