14 puntos por xguru 2024-08-19 | 8 comentarios | Compartir por WhatsApp
  • La razón por la que PyTorch provoca pérdida de productividad y desperdicio de tiempo de desarrollo no es que "el framework en sí sea malo, sino que no fue diseñado para ajustarse a los casos de uso que se le aplican hoy"

La filosofía de PyTorch

  • La filosofía de PyTorch es ser dinámico, fácil de depurar y Pythonic
  • En cambio, TensorFlow 1.x buscaba ser un framework estático pero de alto rendimiento, apoyándose fuertemente en el compilador XLA
  • Los desarrolladores de TensorFlow se dieron cuenta de que a la comunidad no le gustaba la API 1.x, así que decidieron usar Keras como interfaz principal y reducir el papel del compilador XLA
  • PyTorch se mantuvo fiel a sus raíces y, a diferencia del enfoque estático y diferido de TensorFlow, adoptó un enfoque más dinámico de "ejecución inmediata", donde torch.Tensor se evalúa al instante
  • Esto dio resultados y mucha investigación se trasladó a PyTorch
  • Con la llegada de GPT-3 en 2021, el rendimiento y la escalabilidad se volvieron preocupaciones principales
  • PyTorch respondió relativamente bien a esa demanda, pero como no fue diseñado con esa filosofía en mente, la deuda se fue acumulando y sus cimientos comenzaron a tambalearse
  • Los desarrolladores de PyTorch no quisieron aceptar ningún punto medio y eligieron perseguir dos caminos al mismo tiempo
    • usar el compilador XLA como backend base con gran rendimiento y estabilidad
    • construir la pila de torch.compile para dar a los usuarios la libertad de invocar el compilador cuando lo necesiten
  • La falta de una estrategia de largo plazo es un problema grave
  • PyTorch no quiere comprometerse con una filosofía centrada en el compilador (como JAX), pero tampoco se ve una buena alternativa
  • ¿Cuál es la solución que proponen los productos competidores para este problema?

Desarrollo basado en compilador en JAX

  • JAX aprovecha XLA, la poderosa pila de compilación de TensorFlow
  • XLA es un compilador potente, pero todo eso está abstraído para el usuario final
  • Siempre que una función sea pura (pure), se puede usar el decorador @jax.jit para compilarla con JIT y hacer que funcione con XLA
  • XLA se encarga detrás de escena de verificar que el grafo generado sea correcto, del particionador GSPMD que maneja el paralelismo automático con sharding en JAX, la optimización de grafos, la fusión de operadores y kernels, la planificación para ocultar latencia, el solapamiento asíncrono de comunicación y la generación de código para otros backends como triton
  • Mientras se respeten las restricciones de JAX, XLA se encarga automáticamente
  • Por ejemplo, al paralelizar no se necesitan primitivas de comunicación como torch.distributed.barrier()
  • El soporte para DDP se logra con código simple
  • El enfoque de XLA es que el cómputo sigue al sharding. Por tanto, si un arreglo de entrada está fragmentado a lo largo de cierto eje, XLA lo maneja automáticamente para los subcómputos
  • La idea de "desarrollo basado en compilador" es similar a cómo funciona el compilador de Rust
  • Limitaciones de PyTorch
    • Hay descontento con la decisión de los desarrolladores de PyTorch de integrar y depender de la pila de compilación para nuevas funciones, en lugar de mantener la filosofía central de flexibilidad y libertad
    • Según la hoja de ruta oficial de PyTorch 2.x, queda claro que hay un plan de largo plazo para integrar XLA por completo con Torch
    • Esa es una idea terrible. Es como decir que forzar código C++ dentro del compilador de Rust sería una mejor experiencia que usar Rust directamente
    • A diferencia de JAX, Torch no fue diseñado alrededor de XLA
    • Si PyTorch decide usar una pila de compilación basada en XLA, ¿no sería ideal un framework diseñado y construido específicamente alrededor de eso?
    • Incluso si PyTorch persigue un enfoque de "multi-backend" en el que pueda elegirse el backend de compilación deseado, ¿no empeoraría eso el problema de fragmentación y terminaría arruinando por completo la API al intentar respetar las limitaciones de todas las pilas de compilación?
    • Cualquiera que haya usado Torch/XLA en TPU sufre de un PTSD severo

Multi-Backend fracasó

  • PyTorch fracasa miserablemente al intentar hacerlo todo al mismo tiempo
  • La decisión de diseño de "multi-backend" agrava este problema de forma exponencial
  • En teoría suena como si se pudiera elegir la pila deseada, pero en la práctica es un enredo caótico de tracebacks incomprensibles y problemas de incompatibilidad
  • Restricciones entre backends y choques con la API de PyTorch
    • La dificultad no es solo hacer funcionar esos backends, sino que las restricciones que esperan no encajan bien con la API flexible y Pythonic de PyTorch
    • Hay un trade-off entre mantener la consistencia de la API y seguir las limitaciones del backend
    • Como resultado, los desarrolladores intentan depender más de la generación de código en lugar de integrarse o comprometerse realmente con un solo backend
  • La falta de estrategia de PyTorch
    • Como PyTorch rechaza hacer trade-offs significativos, cada decisión se siente como una concesión
    • No hay consistencia ni una estrategia general
    • Al final eso causa mucha frustración a los usuarios y se siente como una mezcolanza de funciones que no encajan bien entre sí
    • No hay forma más rápida de matar un ecosistema
  • Por qué no debería seguirse el enfoque de JAX
    • PyTorch no debería intentar seguir el enfoque de JAX de "compilador y backend integrados"
    • Porque JAX fue diseñado explícitamente para trabajar con XLA
    • Reemplazar el frontend de PyTorch por el de JAX no puede ser la estrategia
    • Es prácticamente imposible idear una API mejor que la de JAX sobre la base de XLA
    • No se critica a los desarrolladores por intentar ideas nuevas y distintas
    • Pero si PyTorch quiere resistir el paso del tiempo, debería enfocarse más en reforzar sus cimientos que en ofrecer funciones nuevas y llamativas que se desmoronan de inmediato fuera de las condiciones ideales de los tutoriales

La fragmentación de PyTorch y la programación funcional de JAX

  • La API funcional de JAX
    • Las funciones de JAX deben ser puras (pure), es decir, no deben tener efectos secundarios globales
    • Como funciones matemáticas, con los mismos datos siempre deben devolver la misma salida sin importar el contexto de ejecución
    • Gracias a esta filosofía de diseño, las funciones de JAX son componibles e interoperan bien entre sí
    • Se reduce la complejidad de desarrollo, y las funciones se definen con firmas específicas y tareas concretas bien definidas
    • Si los tipos se respetan, se garantiza que la función funcionará de inmediato
    • Esto encaja bien con el tipo de trabajo necesario en cómputo científico, especialmente en deep learning
  • Ejemplo de la API de optax
    • Gracias al enfoque funcional, en optax existe algo llamado "chain"
    • Esto incluye múltiples funciones que se aplican secuencialmente a los gradientes
    • El componente fundamental es GradientTransformation
    • Eso crea una API poderosa y expresiva
    • Por ejemplo, tareas como recortar gradientes, tomar la EMA de los gradientes o combinar optimizadores se vuelven muy simples
  • Ventajas del diseño funcional
    • Otro resultado genial del diseño funcional es vmap
    • Significa map "vectorizado" y describe exactamente lo que hace
    • Se puede mapear todo, y mientras sea con vmap, XLA fusiona y optimiza automáticamente
    • No hace falta pensar en la dimensión de batch al escribir funciones
    • Basta con aplicar vmap a todo el código
    • Eso significa que se necesitan menos operaciones tipo ein-*
    • Manipular tensores 2D/3D se vuelve más intuitivo y mucho más legible
    • Como solo hay que aislar componentes individuales y razonar sobre ellos, resulta más fácil escribir código complejo que funcione bien
    • Si se respetan las restricciones de pureza y se tienen las firmas correctas, se obtienen también todas las demás ventajas como la componibilidad
  • Problemas del ecosistema de PyTorch
    • En torch, sin importar la pila que se use (FSDP + multinodo + torch.compile, etc.), siempre existe la posibilidad de que algo se rompa
    • Muchas cosas tienen que funcionar bien juntas, y si cualquier componente falla, toca depurar hasta las 3 a. m.
    • Como no se pueden probar todas las combinaciones de las decenas de funciones que ofrece PyTorch, siempre habrá bugs no detectados durante el desarrollo
    • Es imposible escribir código que funcione bien sin un esfuerzo considerable
    • El ecosistema de torch se volvió muy inflado y lleno de bugs
    • Como no hay una abstracción compartida, aparecen nuevas bibliotecas y frameworks que no fueron diseñados para interactuar con otras "soluciones"
    • Eso se degrada rápidamente en un caos de dependencias y requirements.txt
    • El 70-80% de los issues de GitHub o discusiones en foros se debe simplemente a errores entre distintas bibliotecas
    • Casi no hay forma de resolverlo
  • Ausencia de solución
    • Este es un problema de OOP y de diseño
    • Se cree que un objeto básico y al estilo PyTorch como PyTree habría ayudado a construir una base común para la abstracción
    • Tampoco puede adoptarse el paradigma de programación funcional
    • Hacerlo rompería la compatibilidad hacia atrás de todas las bases de código existentes de torch mientras convergería hacia una versión peor en rendimiento de JAX
    • PyTorch parece estar completamente roto en este aspecto

La ventaja de JAX en reproducibilidad

  • Manejo de seeds
    • El manejo de seeds en PyTorch no es ideal
    • Por lo general hay que ejecutar varias líneas de código
    • Es fácil olvidarlo o configurarlo mal
    • JAX obliga a crear claves explícitas y pasarlas a cada función que necesite aleatoriedad
    • Este enfoque elimina por completo el problema porque el RNG siempre queda sembrado de manera estática
    • JAX tiene su propia versión de NumPy (jax.numpy), así que no hace falta establecer seeds por separado
    • Estas pequeñas decisiones de QoL pueden mejorar muchísimo la experiencia de uso del framework completo
  • Portabilidad
    • Uno de los mayores problemas al usar bases de código de PyTorch es la falta de portabilidad
    • Las bases de código escritas para CUDA/GPU no funcionan bien cuando se ejecutan en hardware no Nvidia como TPU, NPU o GPU de AMD
    • Es difícil portar código PyTorch escrito para un nodo a múltiples nodos
    • El multinodo suele requerir decenas de horas de desarrollo y cambios importantes en el código
    • El enfoque centrado en el compilador de JAX tiene ventajas en esto
    • XLA se encarga de cambiar entre backends de dispositivos y funciona bien en GPU/TPU/multinodo/multislice con cambios mínimos de código
    • Facilita a los proveedores de hardware dar soporte a sus dispositivos y hace más fácil cambiar entre dispositivos
    • No todo el mundo tiene acceso al mismo hardware, así que una base de código portable entre distintos tipos de hardware puede ser un pequeño paso para hacer el deep learning más accesible a personas principiantes e intermedias
  • Escalado automático
    • Una base de código que pueda escalar automáticamente bien por sí sola ayuda mucho a la reproducibilidad
    • Idealmente esto debería ocurrir automáticamente, sin importar los límites de red y con cambios mínimos en el código
    • JAX hace esto bien
    • Al escribir código en JAX no hace falta especificar primitivas de comunicación ni poner torch.distributed.barrier() por todos lados
    • XLA inserta eso automáticamente teniendo en cuenta el hardware disponible
    • Cualquier dispositivo que JAX pueda detectar se usa automáticamente, sin importar red, topología, configuración, etc.
    • Sincroniza y prepara el cómputo automáticamente y aplica pases de optimización para maximizar la ejecución asíncrona de kernels y minimizar la latencia
    • Lo único que una persona tiene que hacer es especificar el sharding del tensor que quiere distribuir entre dispositivos, como la dimensión de batch de los arreglos de entrada
    • Gracias al enfoque de XLA de que "el cómputo sigue al sharding", el resto se resuelve automáticamente
    • Eso permite ejecutar fácilmente, incluso como hobby, experimentos validados a escala para explorar y potencialmente iterar
    • Esto puede facilitar redescubrir ideas olvidadas y fomentar esos experimentos, ya que se pueden probar fácilmente a mayor escala con muy poco esfuerzo

Desventajas de JAX

  • Estructura de gobernanza
    • Actualmente XLA está bajo la gobernanza de TensorFlow
    • Ha habido discusiones sobre crear un organismo organizativo separado, similar al de PyTorch, pero no se han hecho muchos esfuerzos concretos
    • La confianza en Google no es muy alta por su reputación de descontinuar productos impopulares
    • Técnicamente JAX es un proyecto de DeepMind y tiene un significado central para el impulso general de IA de Google, pero parece que también traería grandes beneficios de largo plazo para todo el ecosistema
    • Un organismo de gobernanza independiente daría dirección al desarrollo del proyecto
    • Eso proporcionaría una estructura concreta y evitaría de una vez muchos problemas al separarlo de la notoria burocracia de Google
    • No es que JAX necesariamente requiera este tipo de estructura formal, pero sería bueno tener garantías de que su desarrollo continuará por mucho tiempo sin importar las decisiones de la alta dirección de Google
    • Eso ayudaría claramente a su adopción en empresas y grandes laboratorios de investigación que dudan en invertir recursos en integrar herramientas que algún día podrían dejar de mantenerse
  • La transición de XLA a open source
    • Durante mucho tiempo XLA fue un proyecto de código cerrado
    • Sin embargo, se hicieron esfuerzos para volverlo open source, y actualmente OpenXLA muestra un rendimiento muy superior al build interno de XLA
    • Aun así, sigue faltando documentación sobre el interior de XLA
    • La mayoría de los recursos son charlas en vivo y algún paper ocasional, y a menudo están desactualizados
    • Sería útil contar con una hoja de ruta pública y accesible para que la gente pueda seguir el progreso y contribuir especialmente en lo que le resulte interesante
    • También sería bueno tener mini posts de blog al estilo de Edward Yang que analicen cada etapa de la pila de compilación de XLA y expliquen los detalles, para que quienes lo usan en la práctica puedan evaluar mejor qué puede y qué no puede hacer XLA
    • Se entiende que eso consume muchos recursos y que tal vez podrían comunicarse mejor en otros formatos, pero la gente confía más en las herramientas cuando las entiende, y eso tendría un efecto positivo en cadena sobre todo el ecosistema, beneficiando a todos
  • Integración del ecosistema
    • flax es un dolor de cabeza en el ecosistema de JAX
    • Tiene una API poco intuitiva, una sintaxis críptica y es un infierno total para principiantes que vienen de PyTorch
    • Se recomienda usar equinox
    • Ha habido intentos del equipo de desarrollo por resolver los defectos de flax, pero al final es una pérdida de tiempo
    • Si se quiere una API al estilo de equinox, es mejor usar equinox
    • No hay muchas cosas que flax haga especialmente mejor, y no es difícil replicarlas con equinox
    • En este momento gran parte del ecosistema de JAX está diseñado alrededor de flax
    • Como equinox interactúa fundamentalmente con PyTree, es interoperable con todas las bibliotecas, aunque requiere algo de eqx.partition y filter
    • Se quiere cambiar el status quo. equinox debería tener soporte de primera clase en todas partes
    • Es una opinión polémica, pero este es un caso clásico de la falacia del costo hundido
    • equinox funciona mejor, de la forma en que el framework JAX siempre debió funcionar
    • Como se resume en la documentación de equinox, al comparar equinox con flax, equinox es mejor
    • Está bien que los responsables del ecosistema JAX reconozcan la popularidad de equinox y se ajusten en consecuencia, pero también se espera ver oficialmente más apoyo de parte de Google y del equipo de flax
    • Si quieres probar JAX, se recomienda usar equinox
  • Aristas peligrosas
    • Debido a decisiones de diseño de la API y a restricciones de XLA, JAX tiene "aristas peligrosas" con las que hay que tener cuidado
    • Esto está explicado de forma muy concisa en una documentación muy bien escrita
    • Conviene leerla al menos una vez antes de usar JAX
    • Como siempre, hacer RTFM puede ahorrar muchísimo tiempo y energía

Conclusión

  • Esta entrada de blog buscaba corregir el mito repetido una y otra vez de que PyTorch es lo más adecuado para cargas reales de investigación, especialmente en GPU. Ya no es así
  • De hecho, llega al extremo de sostener que portar todo el código de PyTorch a JAX sería enormemente beneficioso para todo el campo
    • El paralelismo automático, la reproducibilidad, una API funcional limpia y demás no son funciones triviales, y ayudarían mucho a muchas bases de código de investigación
  • Si quieres hacer que este campo sea aunque sea un poco mejor, considera reescribir tu base de código en JAX

8 comentarios

 
xguru 2024-08-25

El mundo sigue avanzando. jaja

Comparación entre PyTorch y TensorFlow en 2022

 
hilft 2024-08-21

Me quedaré con torch y onnx.

 
flrngel 2024-08-21

Lo escribió un estudiante de licenciatura... wow

 
cosine20 2024-08-21

PyTorch de verdad estaría muerto sin Huggingface, jajaja

 
lemonmint 2024-08-19

¡Larga vida a JAX! Lo probé hace poco y me gustó muchísimo la API de NNX.

 
stareta1202 2024-08-19

El mayor problema de JAX es que es de Google. Google es bastante famosa por abandonar proyectos open source (Tflite, Android Things, Dart, Angular, Bazel, etc.); incluso TensorFlow, en algún momento, empezó a recibir menos actualizaciones. En cambio, Torch nació en Facebook, que opera una enorme cantidad de open source, y se ha gestionado muy bien; de hecho, ya está bajo una fundación de Torch. Está claro que algunas críticas a Torch sí son válidas, pero cuando se trata de quién puede mantener sosteniblemente ese open source, JAX parece arrancar ya con un riesgo importante.

 
dalinaum 2024-08-20

Al menos parece que Dart seguirá viviendo bastante bien por un tiempo gracias a Flutter.

 
ilotoki0804 2024-08-20

Facebook, con React, Django y demás, al menos parece seguir contribuyendo con cierta lealtad (?) al stack tecnológico que usa, pero Google da la impresión de que en cuanto algo se vuelve un poco obsoleto lo tira como si fuera un trapo viejo...