- Los modelos de difusión se usan no solo para la generación de imágenes, sino también en problemas que requieren muestreo de distribuciones multimodales, como audio, video, 3D, diseño de proteínas y planificación de trayectorias robóticas; este tutorial conecta entrenamiento y muestreo desde una perspectiva de optimización
- El proceso de entrenamiento crea datos con ruido mediante (x_\sigma=x_0+\sigma\epsilon), y minimiza el error cuadrático medio para que la red neuronal (\epsilon_\theta(x,\sigma)) prediga la dirección del ruido
- El denoiser entrenado puede interpretarse como una proyección aproximada sobre el conjunto de datos (\mathcal{K}), y el denoiser ideal se relaciona con el gradiente de la función de distancia cuadrada suavizada por (\sigma)
- El muestreo DDIM puede verse como un descenso de gradiente aproximado sobre (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2), y el schedule de (\sigma_t) determina la cantidad de iteraciones y el costo de evaluar el denoiser
- Al combinar actualizaciones basadas en estimación de gradiente con la adición de ruido, DDIM, DDPM y el sampler mejorado de los autores pueden tratarse en conjunto mediante los parámetros
gamymu, con ejemplos en modelos toy y latent diffusion
Modelos de difusión desde la perspectiva de optimización
- Los modelos de difusión destacan para generar muestras de distribuciones multimodales, y se aplican no solo a herramientas de texto a imagen como Stable Diffusion, sino también a audio, video, generación 3D, diseño de proteínas y planificación de trayectorias robóticas
- La base teórica del tutorial es la interpretación de optimización del paper de ICML 2024 y un paper relacionado
- La implementación toma como referencia principal
smalldiffusion, y el código del texto está simplificado con fines didácticos frente a la librería original
Entrenamiento: predicción de la dirección del ruido
- Un modelo de difusión aprende el conjunto de datos (\mathcal{K}) a partir de ejemplos de entrenamiento, con el objetivo de generar muestras de ese conjunto
- En imágenes, (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) es el conjunto de valores de píxeles que corresponde a imágenes realistas
- El mismo marco también se aplica a audio, video, trayectorias robóticas y dominios discretos como texto
- El procedimiento de entrenamiento puede verse en tres pasos
- Se muestrean (x_0 \sim \mathcal{K}), (\sigma) y (\epsilon \sim N(0,I))
- Se construye dato con ruido con (x_\sigma=x_0+\sigma\epsilon)
- Se minimiza la pérdida cuadrática para que (\epsilon_\theta(x_\sigma,\sigma)) prediga (\epsilon)
- En el código,
training_loopgenerasigmayepspara cada batchx0congenerate_train_sample, y optimiza el MSE entre la salida demodel(x0 + sigma * eps, sigma)yeps - En lugar de muestrear (\sigma) uniformemente en un intervalo continuo, se toma de un schedule de (\sigma) discretizado en (N) valores
- La clase
Scheduleenvuelve la lista desigmasposibles y muestrea valores por batch durante el entrenamiento - El ejemplo del texto usa
ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10) ScheduleDDPMes un schedule para modelos de difusión en espacio de píxeles, yScheduleLDMpara modelos de latent diffusion como Stable Diffusion
- La clase
Ejemplo toy de Swissroll
- El dataset toy es un conjunto de puntos en espiral usado en uno de los primeros papers de difusión, Sohl-Dickstein et al. 2015, con (\mathcal{K}\subset\mathbb{R}^2)
- En datasets simples, el denoiser se implementa como un MLP
- La entrada concatena (x\in\mathbb{R}^2) con un embedding bidimensional de (\sigma)
- La salida es una predicción del ruido (\epsilon\in\mathbb{R}^2)
- Muchos modelos de difusión usan sinusoidal positional embedding para (\sigma), pero en este ejemplo también funciona bien un embedding bidimensional simple
- La configuración de entrenamiento del ejemplo usa
ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)yepochs=15000 - El denoiser entrenado puede visualizarse como un campo vectorial dibujando (x-\sigma\epsilon_\theta(x,\sigma))
- Cuando (\sigma) es grande, el denoiser tiende a predecir la media de los datos
- Cuando (\sigma) es pequeño y la entrada (x) está cerca de los datos, predice puntos de datos reales
Interpretar el denoising como proyección
- La función de distancia al conjunto de datos (\mathcal{K}) se define como (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
- La proyección de (x), (\mathrm{proj}_{\mathcal{K}}(x)), es el conjunto de puntos en (\mathcal{K}) que alcanza esa distancia
- Si (\mathcal{K}) es un conjunto cerrado, (x\notin\mathcal{K}) y la proyección es única, el gradiente de la función de distancia cuadrada es (x-\mathrm{proj}_{\mathcal{K}}(x))
- Como la función de distancia (\mathrm{dist}_{\mathcal{K}}) no es diferenciable en todas partes, se introduce una versión suavizada por (\sigma) de la distancia cuadrada usando softmin en lugar de
min - El gradiente de la función de distancia suavizada apunta hacia el promedio ponderado de los puntos de (\mathcal{K}), con pesos determinados por (x)
Denoiser ideal y modelo de error relativo
- El denoiser ideal (\epsilon^*) es el que minimiza exactamente la pérdida de entrenamiento para un (\sigma) dado
- Si los datos siguen una distribución uniforme discreta sobre un conjunto finito (\mathcal{K}), el denoiser ideal puede expresarse en forma cerrada
- El peso de cada punto de datos se determina según la distancia entre (x_\sigma) y ese punto
- En datasets pequeños puede calcularse directamente con
IdealDenoiser
- En datos toy, el denoiser ideal apunta a la media de los datos cuando (\sigma) es grande, y al punto de datos más cercano cuando (\sigma) es pequeño
- El teorema clave establece que para todo (\sigma>0) y (x\in\mathbb{R}^n), se cumple (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma))
- El modelo de error relativo usa la condición de que (x-\sigma\epsilon_\theta(x,\sigma)) aproxime bien a (\mathrm{proj}_{\mathcal{K}}(x))
- Se aplica cuando (\sqrt{n}\sigma) estima (\mathrm{dist}_{\mathcal{K}}(x)) dentro de un factor constante
- Se asume que el error está acotado por (\eta\mathrm{dist}_{\mathcal{K}}(x))
- Con ruido bajo, bajo la manifold hypothesis, la mayor parte del ruido adicional es ortogonal a la variedad de datos, por lo que el denoising aproxima una proyección
- Con ruido alto, si (\sigma) es mayor que el diámetro de (\mathcal{K}), incluso un denoiser que predice el promedio ponderado de los datos tiene error relativo pequeño
- CIFAR-10 tiene un tamaño que permite calcular el denoiser ideal, y en los experimentos el error relativo entre la proyección exacta y la salida del denoiser ideal sobre las trayectorias de muestreo resulta pequeño
Muestreo: denoising iterativo y DDIM
- Una vez entrenado el denoiser, a partir de (x_t) con ruido y nivel de ruido (\sigma_t), se predice (x_0) como (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t))
- El punto inicial se obtiene tomando (\sigma_T) suficientemente grande respecto al diámetro de (\mathcal{K}), y muestreando (x_T) independientemente desde (N(0,\sigma_T)) para que quede lejos de (\mathcal{K})
- Con ruido alto, una sola llamada al denoiser puede tener gran error absoluto aunque el error relativo sea pequeño, y la predicción del denoiser ideal queda cerca de la media de los datos
- Por eso el muestreo llama repetidamente al denoiser siguiendo un schedule (\sigma_t), construyendo una secuencia (x_T,\ldots,x_0)
- La actualización (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) es equivalente al algoritmo determinista de muestreo DDIM tras un cambio de coordenadas
- La demostración de equivalencia con DDIM está en el Apéndice A del paper
DDIM visto como minimización de distancia
- DDIM puede interpretarse como un descenso de gradiente aproximado sobre (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2)
- El tamaño de paso es (1-\sigma_{t-1}/\sigma_t)
- (\nabla f(x_t)) se estima con (\epsilon_\theta(x_t,\sigma_t))
- El schedule de (\sigma_t) determina la cantidad y el tamaño de los pasos de gradiente durante el muestreo
- Si hay muy pocos pasos, (\mathrm{dist}_{\mathcal{K}}(x_t)) puede no disminuir y no converger
- Si se usan muchos pasos pequeños, aumenta el número de evaluaciones del denoiser y el costo computacional
- Un admissible schedule es un schedule en el que (\sqrt{n}\sigma_t) se mantiene dentro de un factor constante de (\mathrm{dist}_{\mathcal{K}}(x_t)) en cada iteración
- Una secuencia log-linear de (\sigma_t) que decrece geométricamente es un admissible schedule
- Según el teorema, si para el (x_t) generado por DDIM existe (\nabla\mathrm{dist}{\mathcal{K}}(x)) y (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T), entonces (x_t) se genera mediante descenso de gradiente sobre la función de distancia cuadrada y se mantiene (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t)
- En el ejemplo toy, se implementa un sampler DDIM de 20 pasos submuestreando el schedule log-linear original; la mayoría de las muestras quedan cerca de los datos originales, aunque aún hay margen de mejora
Sampler mejorado basado en estimación de gradiente
- Aprovechando que (\nabla\mathrm{dist}{\mathcal{K}}(x)) es invariante entre (x) y (\mathrm{proj}{\mathcal{K}}(x)), se usa una actualización que mezcla la estimación actual con la anterior
- La actualización (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) corrige el error del paso previo usando la estimación actual
- En muestras del modelo toy, este método converge más rápido que DDIM y genera muestras más cercanas a los datos originales
- Frente a DDIM, este sampler puede interpretarse como una versión con momentum; la trayectoria puede hacer overshoot, pero también converger más rápido
- Agregar ruido durante el proceso de generación mejora empíricamente la calidad del muestreo
- Para conservar el schedule original de (\sigma_t), primero se hace denoise hasta un (\sigma_{t'}) menor y luego se vuelve a agregar ruido (w_t\sim N(0,I))
- Cuando (\mu=\frac{1}{2}), se recupera exactamente el sampler DDPM
- La actualización completa (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) generaliza tres samplers
- DDIM:
gam=1, mu=0 - DDPM:
gam=1, mu=0.5 - Sampler por estimación de gradiente:
gam=2, mu=0
- DDIM:
Modelos más grandes y materiales de referencia
- El código de entrenamiento anterior puede usarse no solo con datos toy, sino también para entrenar desde cero modelos de difusión de imágenes
- El ejemplo de FashionMNIST se ofrece como ejemplo entrenado en el dataset FashionMNIST, y logra el segundo mejor puntaje FID en el leaderboard de Papers with Code
- El código de muestreo también puede usarse sin cambios con modelos de latent diffusion preentrenados
- El ejemplo usa
ScheduleLDM(1000)yModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base') - La condición de texto se fija en
An astronaut riding a horse, se muestrea con 50 pasos de (\sigma) y luego se decodifica el latent
- El ejemplo usa
- El efecto del término de momentum (\gamma) se compara visualmente en generación de texto a imagen de alta resolución
- Material adicional recomendado
- What are diffusion models: introducción a los modelos de difusión desde una perspectiva de tiempo discreto que revierte un proceso de Markov
- Generative modeling by estimating gradients of the data distribution: introducción a los modelos de difusión desde una perspectiva de tiempo continuo que revierte ecuaciones diferenciales estocásticas
- The annotated diffusion model: explicación detallada de la implementación de un modelo de difusión en PyTorch
1 comentarios
Opiniones en Hacker News
Si tienen preguntas, puedo responderlas.
Me gustó especialmente la discusión sobre las trayectorias, porque motiva a entender una parte que a muchos les cuesta en temas como los schedulers. Aunque no es tan completo como los textos de Song o Lilian, es mucho más accesible, así que pienso recomendárselo a otras personas.
Como referencia, un amigo escribió hace tiempo una implementación mínima de difusión que, desde la perspectiva de DDPM, es un poco más “completa”, y me resultó útil: https://github.com/VSehwag/minimal-diffusion/
Como alguien que ha experimentado un poco con el procedimiento de sampling en Stable Diffusion, también me habría gustado ver una comparación de tiempo de convergencia y número de pasos frente a DDIM. Me pregunto si hay alguna relación entre momentum, convergencia y error. Por ejemplo, sería interesante una comparación del estilo: si un sampler con momentum de 16 pasos es casi equivalente a DDIM de 20 pasos ± un término de error.
get_sigma_embeds(batches, sigma)no usa el primer argumento. Me pregunto si la intención era hacer broadcast desigmaa la forma(batches, 1).Entra mucho más a fondo en los detalles matemáticos y viene con una implementación mínima de menos de 500 líneas, muy fácil de entender.
Sería genial que esto se extendiera también a la versión de transformers de difusión que impulsa a Sora y otros modelos de generación de video. Combinando este artículo con https://jaykmody.com/blog/gpt-from-scratch/ se podría hacer una introducción de “transformers de difusión desde cero”.
En cambio, si de verdad quieres profundizar, recomiendo leer los trabajos de Kingma, Gao, Ricky Tian Qi Chen y los estudiantes de Max Welling (Tomczak es posdoc, Hoogeboom, etc.), además del héroe poco reconocido Aapo Hyvärinen. Un ejemplo del lado relativamente más ligero de Kingma & Gao, y relacionado también con el paper de SD3, está aquí: https://arxiv.org/abs/2303.00848
Lo lamentable es que hay una gran dependencia de conocer y entender trabajos previos, lo que reduce la accesibilidad, aunque también es difícil llamar a eso una crítica significativa. Es investigación, no material educativo para el público general.
n_embd; el proceso de difusión en sí puede quedar igual.[1] https://yang-song.net/blog/2021/score/
[2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
Desde nuestra perspectiva, la razón por la que los modelos de difusión son fáciles de entrenar es que, en lugar de predecir el gradiente de la función de distancia exacta, usan un objetivo de entrenamiento que predice el gradiente de una función de distancia suavizada. El sampling de modelos de difusión se parece a dar varios pasos de gradiente aproximados.
Para entender más a fondo los modelos de difusión, recomiendo leer todos estos posts y aprender las distintas interpretaciones.
Dicho eso, el enfoque de este artículo parece permitir experimentos más interesantes, como el análisis de errores del eliminador de ruido.
[1] https://arxiv.org/pdf/2305.03486.pdf
Por ejemplo, ¿por qué a los generadores de imágenes les cuesta crear teclas de piano? Para generar la estructura en la que las teclas negras alternan en grupos de dos y de tres, parece que habría que representar mejor restricciones de distancia intermedia.