- Cuando un LLM de propósito general resulta excesivo para tareas especializadas, ajustar finamente Llama-2 directamente puede mejorar a la vez calidad, costo y latencia con un modelo más pequeño y barato
- Tras el ajuste fino, Llama-2 13B elevó la precisión de representaciones funcionales de ViGGO de 58%→98%, la generación de SQL de 42%→89% y GSM8k de 28%→47%
- En tareas donde el formato de salida es importante, como ViGGO y generación de SQL, los modelos pequeños de Llama-2 superaron a GPT-4, pero en razonamiento matemático no alcanzaron su nivel
- Los experimentos se realizaron con scripts basados en Ray Train, Ray Data, DeepSpeed y Accelerate, y entrenaron 7B·13B en 16xA10G y 70B en 32xA10G
- La clave de la mejora no fue tanto el tamaño del modelo como la calidad de los datos y el pipeline de evaluación; hay que comparar por tarea el trade-off de costo y calidad entre prompt engineering y ajuste fino
Efecto del ajuste fino en tres tareas
- Los grandes modelos generalistas como GPT-4 y Claude-2 son útiles para prototipado rápido, pero para requisitos acotados como resumir o clasificar tickets de soporte pueden ser excesivos en costo y rendimiento
- El experimento comparó cuánto mejoran los modelos Llama-2 al aplicarles ajuste fino de parámetros completos para adaptarlos a tres tareas de tipo real
- ViGGO: extracción de representaciones funcionales desde texto no estructurado
- SQL-create-context: generación de SQL a partir de lenguaje natural y contexto
CREATE TABLE - GSM8k: resolución de problemas matemáticos de nivel primaria
- En Llama-2 13B, los cambios de precisión fueron los siguientes
- Representación funcional de ViGGO: 58% → 98%
- Generación de SQL: 42% → 89%
- GSM8k: 28% → 47%
- En ViGGO y generación de SQL, los modelos pequeños de Llama-2 lograron mejores resultados que GPT-4, mientras que en tareas de razonamiento matemático como GSM8k no llegaron a su rendimiento ni después del ajuste fino
Método de ajuste fino e infraestructura de entrenamiento
- En las tres tareas se usó ajuste fino estándar de parámetros completos
- El entrenamiento se hizo con predicción del siguiente token
- Todos los parámetros del modelo fueron objetivo de actualización por gradiente
- LoRA y los enfoques que congelan parte de los bloques transformer quedaron fuera del alcance del experimento
- Los scripts experimentales se construyeron sobre Ray Train, Ray Data, DeepSpeed y Accelerate
- Soportan ejecuciones con Llama-2 7B, 13B y 70B
- TorchTrainer de Ray Train distribuye el loop de entrenamiento entre varios procesos worker y recursos GPU
- Ray Train maneja el sharding de datos, y cada worker accede a su fragmento asignado con
session.get_dataset_shard("train"),session.get_dataset_shard("valid")
- El sharding del modelo se resolvió con DeepSpeed ZeRO stage 3 y offloading del estado del optimizador
- Como los fragmentos del modelo están repartidos entre varios workers, cuando se necesita acceder al modelo completo, por ejemplo para guardar checkpoints, hay que desempaquetarlo con
accelerator.unwrap_model(model)
- Como los fragmentos del modelo están repartidos entre varios workers, cuando se necesita acceder al modelo completo, por ejemplo para guardar checkpoints, hay que desempaquetarlo con
- Los recursos de cómputo fueron los siguientes
- 7B·13B: 16xA10G
- 70B: 32xA10G, en 4 instancias
g5.48xlarge - Con Ray, no es obligatorio usar A100 para hacer ajuste fino de parámetros completos
- El entrenamiento se ejecutó hasta 10 epochs como máximo y se eligió el checkpoint con menor perplexity en el conjunto de validación
Fijar la estructura de entrada y salida con tokens especiales
- Los datos de ajuste fino expresan la estructura de la tarea con tokens especiales en lugar de prompts tipo instrucción
- Ejemplo:
<START_Q>{question}<END_Q><START_A>{answer}<END_A>
- Ejemplo:
- Los tokens especiales ayudan al modelo a distinguir las secciones de entrada y salida y a aprender con claridad dónde debe detener la generación
- En el ejemplo,
<END_A>se define como stopping token para detener la salida al completar la tarea
- En el ejemplo,
- El tokenizer de Llama produce por defecto 32,000 IDs de tokens
- Al agregar cuatro tokens especiales, pasa a producir 32,004 IDs
<START_Q>recibe un nuevo ID como 32000 y<END_Q>uno como 32001
- El script agrega los tokens especiales con
tokenizer.add_tokens(special_tokens, special_tokens=True)y crea nuevos parámetros entrenables conmodel.resize_token_embeddings(len(tokenizer))
ViGGO: convertir texto no estructurado en representación funcional
- ViGGO es originalmente un dataset en inglés para convertir representaciones funcionales basadas en atributo-valor en texto natural; en el experimento se invirtió la dirección para transformar texto no estructurado en una representación funcional estructurada
- El dominio es opiniones sobre videojuegos
- La representación resultante puede usarse para indexación y aplicaciones posteriores
- El modelo debe generar la función y los valores de atributos adecuados para cada oración
- Entre las funciones candidatas están
inform,request,give_opinion,confirm,verify_attribute,suggest,request_explanation,recommend,request_attribute - Entre los atributos candidatos están
name,release_year,esrb,genres,platforms,available_on_steam,has_linux_release,has_mac_release,specifier,rating,player_perspective,has_multiplayer,developer,exp_release_date, entre otros
- Entre las funciones candidatas están
- Para la entrada de ejemplo
What's a really fast-paced game with multiplayer that you like to play?, la salida esperada esrequest(has_multiplayer[yes], specifier[fast-paced]) - Los modelos generales no seguían bien el formato de salida deseado, y por el contexto de entrada largo el tiempo de procesamiento de entrada terminaba siendo mayor que el de generación de salida
- Esta tarea se basa más en reconocimiento de patrones y comprensión básica del lenguaje que en razonamiento lógico complejo
- Es una tarea grounded donde toda la información necesaria está incluida en la entrada
- El hecho de que los prompts few-shot ayuden se tomó como señal de que también era posible mejorar con ajuste fino en modelos pequeños de Llama-2
Evaluación y resultados de ViGGO
- La evaluación no usó solo coincidencia exacta de caracteres
- Verifica si la función de salida es correcta
- Verifica si los tipos de atributos son correctos
- Verifica si los atributos dentro de la función siguen el orden de prioridad definido
- En modelos instruction-following como GPT y Llama-2-chat, como la regla de orden de atributos estaba explícita en el prompt, se evaluó bajo la condición de que debían respetarla
- Para acelerar la evaluación, se usaron juntos la API de batch inference de Ray y Aviary de Anyscale
- Encadenando la generación del LLM y el postprocesamiento, y distribuyéndolos en varias máquinas
- Los modelos 7B y 13B mejoraron notablemente su precisión tras el ajuste fino
- En GPT-4, la precisión caía mucho al incluir en la evaluación la prioridad de atributos
- Los modelos ajustados finamente siempre respetaron esa prioridad, y su precisión no cambió al añadir esa restricción
- Los resultados de ViGGO muestran que el ajuste fino puede ser un medio estable y eficiente para tareas que requieren formatos estructurados
- No se trata solo de encajar con regex o formato JSON, sino de decidir qué argumentos incluir y respetar también el orden de los argumentos incluidos
- Como los resultados se obtuvieron con modelos 7B·13B, el costo de serving podría ser menor que invocar un endpoint de GPT-4
Generación de SQL: crear consultas desde lenguaje natural y contexto de tablas
- La tarea de generación de SQL consiste en recibir una consulta en lenguaje natural y una sentencia SQL
CREATE TABLE, y producir una consulta SQL ejecutable - El dataset usado, b-mc2/sql-create-context, es un dataset de Hugging Face que combina WikiSQL y Spider
- Cada datapoint consta de una consulta en lenguaje natural, una sentencia SQL
CREATE TABLEy la consulta SQL correspondiente - En total contiene 78,577 datapoints
- Cada datapoint consta de una consulta en lenguaje natural, una sentencia SQL
- El dataset tenía problemas en el SQL de referencia
- En
CREATE TABLE, muchos atributos enteros aparecían comoVARCHAR, pero en las consultas SQL se trataban como enteros - Se eliminaron todas las consultas SQL que asumían atributos enteros, reduciendo el dataset de unas 70k a 45k
- En
- Esta tarea también es adecuada para ajuste fino porque convierte lenguaje natural en una representación estructurada, en este caso SQL
- A diferencia de ViGGO, aquí puede haber varias consultas SQL distintas que produzcan el resultado correcto, por lo que la tarea es más ambigua
Evaluación y resultados de SQL
- Evaluar generación de SQL con simple comparación de cadenas no es apropiado
- La comparación carácter por carácter puede producir muchos falsos negativos
- Comparar AST también puede ser sensible a factores como el orden de nombres de variables
- El método más confiable es ejecutar el código sobre un dataset sintético y comparar si la salida coincide
- En el experimento se usó el endpoint GPT-3.5 de OpenAI para generar tablas sintéticas para pruebas unitarias sobre cientos de ejemplos
- GPT-3.5 veía la pregunta, el esquema de tabla y la respuesta correcta, y generaba una tabla sintética de 10 datapoints
- Luego se comparaban los resultados ejecutando el SQL de referencia y el del modelo con
sqlglot.executor.execute
- Para validar la calidad de las tablas generadas por GPT-3.5, primero se ejecutó el SQL correcto
- Si la tabla de resultados quedaba vacía o tenía la misma longitud que la tabla original, ese ejemplo se descartaba
- En este proceso se filtró alrededor del 50% de las tablas sintéticas generadas por GPT
- Llama-2 7B y 13B ajustados finamente obtuvieron mejor rendimiento que 70B-chat y GPT-4
- Un error común en los modelos chat de Llama era no colocar de forma consistente el SQL dentro de etiquetas
<SQL>, pese a que el prompt lo pedía - Este problema era más frecuente en los modelos chat 7B·13B que en 70B
- Un error común en los modelos chat de Llama era no colocar de forma consistente el SQL dentro de etiquetas
- Algunas consultas en lenguaje natural del dataset no estaban escritas en inglés perfecto, y ese ruido pudo haber afectado los resultados de GPT-4
- Los modelos ajustados finamente se adaptaron rápidamente a esas rarezas del dataset
GSM8k: razonamiento matemático más difícil que aprender estructura
- GSM8k es un benchmark académico estándar para evaluar razonamiento y comprensión matemática
- Mientras que las dos tareas anteriores trataban principalmente de aprendizaje de estructura, GSM8k sirve para ver cuánto puede mejorar el modelo en el proceso de razonamiento necesario para resolver problemas matemáticos
- Un problema de ejemplo pregunta por el total vendido si en abril se vendieron 48 unidades y en mayo la mitad de eso, y la respuesta correcta termina en formato
#### 72junto con cálculos intermedios - Los LLM actuales no suelen calcular internamente la respuesta final y emitirla directamente; necesitan generar parte del proceso de pensamiento en la salida para que los siguientes tokens puedan apoyarse en una secuencia lógica
- Esta tarea requiere no solo cálculo simple, sino una chain of thought lógica que vaya de las premisas a conclusiones intermedias y finalmente a la respuesta
Método de evaluación y líneas base de GSM8k
- La evaluación requiere una forma confiable de extraer la respuesta final desde la salida del modelo
- Como los modelos de lenguaje generales no siempre respetan el formato de salida esperado, la evaluación automática puede ser difícil
- Para eso se usó la API de function calling de OpenAI
gpt-3.5-turbo-0613llamaba a la funciónreport_answerpara extraer la respuesta entera final desde la generación de otro modelo- Por ejemplo, aunque un modelo respondiera “The answer is four”, se podía parsear como
4
- Este método se validó probándolo sobre las respuestas correctas del dataset, pero tiene la desventaja de agregar costo de tokens de OpenAI a la evaluación
- Los modelos ajustados finamente aprendieron rápido el patrón de respuesta objetivo, así que incluso cuando fallaban, la estructura de salida era predecible
- La evaluación de los modelos ajustados se resolvió con la regex
#### {answer}, evitando el postprocesamiento con endpoints de OpenAI
- La evaluación de los modelos ajustados se resolvió con la regex
- Las líneas base fueron las siguientes
- Resultados de 8-shot prompting de modelos base preentrenados publicados en artículos
- Varias plantillas con prompt engineering aplicadas a las variantes chat-tuned de Llama-2, entrenadas por Meta con RLHF para funcionar como asistentes generales
Resultados de GSM8k y ajuste fino en dos etapas
- El ajuste fino de modelos base mejoró de forma consistente el rendimiento en GSM8k, pero no siempre produjo resultados claramente superiores a los modelos chat-tuned
- Es probable que los modelos chat hubieran visto ejemplos matemáticos durante el proceso de chat-tuning, por eso mostraban mayor precisión que los modelos base
- Poner prompts sobre el modelo ajustado no siempre dio mejores resultados que el modelo base
- Por ejemplo, Llama-2-70B-chat puede quedar por debajo de un modelo base con prompt de 8 ejemplos
- Los modelos ajustados finamente fueron consistentemente mejores que los modelos base con prompt de 8 ejemplos
- En términos de costo de serving, los modelos ajustados pueden tener ventaja
- Los enfoques basados en prompts agregan costo por tokens del prompt en cada solicitud
- En los modelos ajustados, en la práctica el costo refleja casi solo los tokens de la pregunta
- Como el dataset de entrenamiento de GSM8k tiene apenas unas 8k muestras, se consideró insuficiente para explotar todo el potencial de Llama-13B
- Un enfoque de dos etapas, ajustando primero Llama-13B base con MathQA y luego otra vez con GSM8k, logró mejoras adicionales
- Ajustar solo con GSM8k mejoró 10 puntos porcentuales frente al modelo base
- El ajuste en dos etapas con MathQA seguido de GSM8k agregó otros 10 puntos porcentuales sobre ese resultado, para una mejora total de 20 puntos frente al base
- MathQA consta de 30,000 pares de pregunta/respuesta, pero tiene más ruido y una estructura distinta a GSM8k
- La calidad de las respuestas es menor y la respuesta final tiene formato multiple choice
- Aun así, el ajuste fino en dos etapas resultó efectivo para mejorar el resultado final en GSM8k aprovechando MathQA
Criterios a observar en implementación real
- Los modelos cerrados como GPT-4 y Claude-2 son fuertes para prototipado y validación inicial de valor, pero no siempre bastan para operar apps LLM en producción
- El ajuste fino de LLM para tareas nicho puede aportar valor no solo en privacidad, sino también en latencia, costo y calidad
- En los ejemplos de ViGGO y SQL, incluso superó a GPT-4 en calidad
- En ajuste fino, el foco importante no está tanto en los detalles de implementación de infraestructura como en recopilar datos y construir el pipeline de evaluación
- El pipeline de evaluación es la base para comparar, según los requisitos del negocio, el trade-off entre distintas soluciones
- Los experimentos se realizaron usando la plataforma de fine-tuning y serving de Anyscale y Anyscale Endpoints
- El mismo proceso puede repetirse con datos propios y en la nube propia, ya que está construido sobre las soluciones de fine-tuning y serving de Anyscale sobre Ray
1 comentarios
Opiniones en Hacker News
Hace unas semanas, en un livestream de programación, cubrí bastante cómo hacer fine-tuning de Llama 2 con un dataset propio, y lo hice en una sola GPU de Colab.
En mi caso, el dataset era mi propio código.
Fine-tuning Llama stream: https://www.youtube.com/watch?v=TYgtG2Th6fI&t=2282s
También tengo algunas sesiones más de fine-tuning con QLoRA, donde explico los conceptos desde la perspectiva de un ingeniero de software con 8 años de experiencia que recientemente pasó a machine learning y aprendió por su cuenta.
QloRa fine-tuning stream: https://www.youtube.com/watch?v=LitybCiLhSc&t=4584s
Intento explicarlo de la forma más simple posible, tanto para mis proyectos personales como para una startup basada en IA en la que estoy trabajando actualmente. Una serie sobre hacer fine-tuning del LLM más pequeño para desarrollo web también parece estar teniendo buena recepción; llevo alrededor de un mes haciendo streams y planeo subir mucho más contenido.
Tampoco entiendo bien el enfoque de tener modelos fine-tuned separados. Me pregunto si se necesitan un LLM para Terraform, un LLM para SQL y un LLM para Python por separado, o si basta con un solo LLM de “código”.
Se necesitan demasiados detalles de implementación, así que no es muy accesible salvo que sea un caso de uso significativo. Parece que privateGPT va avanzando lentamente hacia ese punto.
Es una parte que muchos otros tutoriales se saltan bastante. En particular, me interesa cómo prepararlos según distintos objetivos, como seguridad o precisión.
Estoy teniendo el mismo problema con Llama 2. Es casi imposible hacer que emita solo el texto que quiero; siempre agrega algo antes o después de la respuesta.
Me pregunto si hay alguna técnica de prompting para corregir esto.
airoboros admite un token PLAINFORMAT para evitar backticks, explicaciones, etc., y hacer que emita solo código.
https://huggingface.co/TheBloke/airoboros-l2-70B-GPT4-2.0-GG...
Si necesitas garantía, lo mejor es hacer fine-tuning con un dataset pequeño, de alrededor de 1.000 ejemplos, y mejorar desde ahí.
Mi caso de uso era una tarea simple de extraer/sintetizar información de texto, más que escritura creativa. El modelo base puede no encajar bien con todas las tareas.
contento dentro de JSON.Si es JSON, puedes identificar el inicio y el final, así que puedes eliminar todo lo que quede fuera del JSON.
Me alegra ver un artículo así. Ha habido muchísima discusión en línea sobre personalización de modelos, y este artículo elimina bastante bien el ruido.
También me gusta la metodología de evaluación, y el texto parece estar bien escrito.
Me parece raro que LoRA y el entrenamiento cuantizado no se traten con más seriedad. Son mucho más baratos, toman menos tiempo y hay bastante evidencia de que funcionan bien.
No creo que deban quedar relegados como una opción secundaria para probar más adelante.
Me alegra ver que una tarea parecida a NER haya obtenido el mejor rendimiento. Justo estaba por hacer una prueba similar para compararla con un modelo BERT fine-tuned.
Me pregunto cuál fue el costo de entrenamiento de esta tarea.
Podríamos haber reducido el tamaño de bloque, pero era más fácil dejar el código sin cambios. El 7B tardó unos 15 minutos por época en 16xA10G, y el 13B unos 25 minutos. Por lo tanto, el costo on-demand por época es de aproximadamente $7,2 para 7B y $12 para 13B. Estos valores consideran solo el tiempo usado para entrenar, sin incluir el tiempo de arranque/apagado del clúster.
Dicen que usaron 16xA10G para 7B y 13B, y 32xA10G para 70B, distribuidas en 4 instancias g5.48xlarge. Con Ray no hace falta conseguir A100 para hacer fine-tuning de parámetros completos de estos modelos, y repiten el mismo proceso para cada tarea. En el dataset GSM8k muestran una ejecución de ejemplo con longitud de contexto 512 y 3,7 millones de tokens efectivos por época.
Dicen que entrenaron hasta 10 épocas y eligieron el checkpoint con la perplejidad mínima en el conjunto de validación.
Una dificultad es que, para crear un dataset personalizado lo suficientemente grande, necesitas algo como un pequeño ejército de personas o un modelo existente muy potente.
Al final probablemente haya que usar OpenAI, pero generar material de entrenamiento para otro modelo con OpenAI viola sus términos. Me pregunto si esto alguna vez llegó a una demanda. ¿La gente simplemente lo considera injusto y lo ignora?
Últimamente veo más ejemplos de NER, y me pregunto por qué no usan spaCy para ese tipo de tareas.
Trabajo en Anyscale.
Como este blog parece haber recibido buen interés, planeamos incluirlo en Ray Summit: https://raysummit.anyscale.com/agenda
Si tienen ideas sobre qué tipo de contenido les gustaría ver más en Ray Summit, me gustaría escucharlas.
Dice que, con 3,5 millones de tokens, 7B tarda unos 14 minutos por época y 13B unos 26 minutos por época.
También dice que tanto 7B como 13B requieren al menos 1xg5.16xlarge como nodo head y 15xg5.4xlarge como nodos worker; me pregunto cuánto costaría eso en AWS.
Si lo corres en us-east-1, puedes estimar unos $30 por hora.
https://instances.vantage.sh/?selected=g5.16xlarge,g5.4xlarg...
Me pregunto si se puede hacer fine-tuning local de Llama-2 en un M1 Ultra de 64 GB. Casi todo lo que encuentro es en la nube o usando Nvidia CUDA en Linux, así que sería bueno tener material de referencia.
Para entrenamiento, pienso comprar algunos créditos de RunPod, y creo que se puede hacer por unas decenas de dólares.