1 puntos por GN⁺ 2024-07-14 | Aún no hay comentarios. | Compartir por WhatsApp
  • AlphaFold3 busca predecir solo a partir de la secuencia complejos que van más allá de una sola proteína e incluyen proteínas, ácidos nucleicos y moléculas pequeñas, por lo que la representación de entrada y la tokenización son mucho más complejas que en AF2
  • La entrada se divide en representación single/pair a nivel de token, representación a nivel de átomo, MSA y templates; los aminoácidos y nucleótidos estándar se tratan como 1 token, mientras que los residuos no estándar y otras moléculas se procesan como 1 token por átomo
  • El trunk de aprendizaje de representaciones mejora de forma iterativa la representación single s y la representación pair z mediante el módulo de templates, el módulo MSA y Pairformer, con pair-bias attention, operaciones triangle y recycling
  • La predicción estructural usa un modelo de difusión condicional sobre coordenadas atómicas en lugar de la Invariant Point Attention de AF2, y genera actualizaciones de coordenadas para todos los átomos mediante aumentos de rotación/traslación y denoising
  • El entrenamiento combina distogram, diffusion y confidence loss, y con cross-distillation usando resultados de AF2 y AF-Multimer vuelve a aprender incluso representaciones unfolded en regiones de baja confianza

Alcance de entrada y pipeline completo de AlphaFold3

  • El objetivo de AlphaFold3 no es solo predecir secuencias individuales de proteínas como AF2 ni limitarse a complejos proteicos como AF-Multimer, sino predecir solo a partir de la secuencia estructuras donde una proteína está unida opcionalmente a otras proteínas, ácidos nucleicos y moléculas pequeñas
  • El significado de “token” cambia según el tipo de entrada
    • Proteína: 1 aminoácido estándar = 1 token
    • DNA/RNA: 1 nucleótido estándar = 1 token
    • Aminoácidos/nucleótidos no estándar: 1 átomo = 1 token
    • Otras moléculas: 1 átomo = 1 token
  • Una proteína de 35 aminoácidos estándar puede tener en realidad más de 600 átomos, pero se representa con 35 tokens, mientras que un ligand de 35 átomos se representa con 35 tokens
  • El modelo se compone en líneas generales de tres etapas
    • Input Preparation: convierte la secuencia de entrada del usuario y las secuencias/estructuras relacionadas encontradas en tensores numéricos
    • Representation Learning: actualiza la representación single y la representación pair con varias variantes de atención
    • Structure Prediction: predice la estructura mediante difusión condicional
  • El complejo proteico se almacena principalmente en dos representaciones
    • single representation: representa todos los tokens del complejo en sí
    • pair representation: representa relaciones como distancia e interacciones potenciales entre todos los pares de tokens
  • Las principales dimensiones de canal son c_z=128, c_m=64, c_atom=128, c_atompair=16, c_token=768, c_s=384

Preparación de entrada: cómo convertir la secuencia en 6 tensores

  • La entrada proporcionada por el usuario se convierte en 6 tensores que entran al trunk del modelo
    • s: token-level single representation
    • z: token-level pair representation
    • q: atom-level single representation
    • p: atom-level pair representation
    • m: MSA representation
    • t: template representation
  • Búsqueda de MSA y templates

    • AF3 busca secuencias similares para proteínas y RNA, las organiza como MSA e incluye estructuras relacionadas como template
    • Un MSA alinea secuencias de proteínas similares encontradas en múltiples especies para aportar al modelo patrones de conservación en posiciones específicas y correlaciones entre cambios en distintas posiciones
    • Las estructuras conocidas de proteínas similares se usan para estimar la estructura de la proteína query, como en homology modeling
    • La búsqueda no incluye entrenamiento y usa métodos basados en HMM
    • Se consultan varias bases de datos de proteínas y RNA con jackhmmer, HHBlits y nhmmer, y se buscan secuencias similares en el Protein Data Bank con hmmsearch
    • El tamaño del MSA está limitado a N_MSA < 2^14 por la complejidad computacional
    • En cada chain de proteína se seleccionan estructuras de alta calidad y se muestrean hasta 4 como template
    • Frente a AF-Multimer, el nuevo elemento añadido en la búsqueda es que las secuencias de RNA también se incluyen entre los objetivos de búsqueda
  • Forma de representación de templates

    • En la estructura 3D del template se calcula la distancia euclidiana entre cada par de tokens
    • Para tokens con múltiples átomos se usa un “center atom” representativo
      • Aminoácido: átomo
      • Nucleótido estándar: átomo C1'
    • Los valores de distancia no se usan como valores continuos, sino que se discretizan en un distogram
      • 38 bins desde 3.15Å hasta 50.75Å
      • 1 bin adicional para distancias mayores
    • Al distogram se le añaden información de chain, si ese token está resolved en la crystal structure, e información de distancia local dentro de cada aminoácido
    • La matriz de template se enmascara para ver solo distancias dentro de la misma chain, y la selección de templates no busca obtener información de interacciones entre chains

Representación a nivel atómico y Atom Transformer

  • reference conformer y representación a nivel atómico

    • Para crear la representación single a nivel atómico q, se calcula un reference conformer para cada aminoácido, nucleótido y ligando.
    • Un conformer es una disposición atómica 3D de una molécula generada al muestrear rotaciones alrededor de enlaces simples.
    • Los aminoácidos estándar usan conformers de baja energía obtenidos por lookup, y las moléculas pequeñas generan conformers 3D con RDKit’s ETKDGv3.
    • Al combinar la posición relativa del conformer, la carga atómica, el número atómico, los identificadores y otros datos, se crea la representación single a nivel atómico c.
    • c inicializa la representación pair a nivel atómico p, y se usa la máscara v para que solo contenga distancias entre átomos calculadas a partir del reference conformer.
    • q comienza como una copia de c y luego se actualiza en el Atom Transformer.
  • Rol del Atom Transformer

    • Atom Transformer es un módulo que realiza atención a nivel atómico y actualiza q usando p y la representación original c.
    • c no se actualiza y se usa como una especie de conexión residual hacia la representación inicial.
    • La estructura básica incluye LayerNorm, atención y una transición MLP, de forma similar a un transformer, pero cada etapa se ajusta con entradas adicionales de c y p.
  • Adaptive LayerNorm

    • Adaptive LayerNorm, en lugar de aprender gamma y beta fijos, genera gamma y beta a partir de una entrada auxiliar.
    • En Atom Transformer, el objetivo del reescalado es q, y los parámetros de reescalado se predicen a partir de la entrada auxiliar c.
  • Attention with Pair Bias

    • La atención a nivel atómico con pair bias es una extensión de self-attention.
    • Query, key y value salen todos de la representación single q, pero después del producto punto query-key se suma como bias una proyección lineal de la representación pair p.
    • La información fluye de la representación pair hacia q, pero en esta etapa no se actualiza p usando información de q.
    • Un gate creado al pasar una proyección adicional por una sigmoide se multiplica por el resultado de la atención y controla qué información se mantiene en el residual stream.
    • Como la cantidad de átomos puede ser mucho mayor que la cantidad de tokens, se usa Sequence-local atom attention en lugar de atención completa.
    • Un grupo local de 32 átomos puede atender a otros 128 átomos.
  • Conditioned Gating y Transition

    • Conditioned Gating aplica a los datos un gate generado a partir de la matriz single original a nivel atómico c.
    • Conditioned Transition corresponde al MLP del transformer y se llama conditioned porque tanto Adaptive LayerNorm como Conditional Gating dependen de c.
    • AF3 usa SwiGLU en el bloque de transición en lugar de ReLU.
    • La transición basada en ReLU de AF2 tiene una estructura de up-projection 4x, ReLU y down-projection.
    • En el SwiGLU de AF3, se aplica la no linealidad swish a una de dos up-projections, luego se multiplican y después se hace down-project.

Agregación de representaciones atómicas en representaciones de token

  • Como la etapa de aprendizaje de representaciones después opera a nivel de token, la representación a nivel atómico se agrega en una representación a nivel de token.
  • Después de proyectar la representación a nivel atómico a una dimensión mayor, se toma el promedio de los átomos que pertenecen al mismo token.
  • Esta agregación por promedio se aplica cuando varios átomos están conectados a un mismo token, como en aminoácidos estándar y nucleótidos, mientras que las entradas de 1 token por átomo se mantienen tal cual.
  • La entrada single a nivel de token también incorpora estadísticas obtenidas del MSA.
    • tipo de aminoácido
    • distribución de aminoácidos del MSA en esa posición
    • media de deleciones de ese token
  • En tokens sin MSA, como los átomos de ligandos, estos valores quedan en 0.
  • El s_inputs creado de esta forma pasa por una proyección para convertirse en s_init, y luego se actualiza durante la etapa de aprendizaje de representaciones.
  • La representación pair z_init es un tensor tridimensional que almacena relaciones entre pares de tokens, y cada z_i,j es un vector de dimensión c_z=128.
  • La inicialización de z_i,j suma la proyección de s_i y s_j, el relative positional encoding y la información de enlaces entre tokens especificada por el usuario.

Aprendizaje de representaciones: Template, MSA, Pairformer

  • El aprendizaje de representaciones es el trunk que ocupa la mayor parte del cómputo del modelo, y su objetivo es mejorar la representación single a nivel de token s y la representación pair z.
  • La single sequence representation no se refiere solo a una única secuencia de proteína, sino a una secuencia formada al concatenar todos los átomos o tokens de la estructura.
  • Template Module

    • Cada template pasa por una proyección lineal y luego se suma con una proyección lineal de la representación pair z.
    • La matriz combinada pasa por un Pairformer Stack.
    • Los resultados de varios templates se promedian y luego vuelven a pasar por una capa lineal.
    • La última capa lineal usa ReLU, uno de los pocos lugares en AF3 donde ReLU se usa como no linealidad.
  • MSA Module

    • El MSA Module es muy similar al Evoformer de AF2 y mejora al mismo tiempo la representación MSA m y la representación pair z.
    • En lugar de usar todas las filas del MSA, primero se hace subsampling y luego se suma al MSA una proyección de la representación single.
    • Outer Product Mean es la operación que incorpora información del MSA en la representación pair.
      • para cada índice de token i,j, se calcula el outer product de m_s,i y m_s,j para todas las secuencias evolutivas
      • luego se promedia sobre toda la secuencia, se aplana y se proyecta para sumarlo a z_i,j
      • es el único punto del modelo donde se comparte información entre secuencias evolutivas
    • Row-wise gated self-attention using only pair bias actualiza el MSA usando la representación pair.
      • en lugar de crear el score de atención con query y key, se proyecta la representación pair z a una matriz y se usa como score de atención entre tokens
      • como se aplica de forma independiente a cada fila del MSA, en esta etapa no se comparte información entre secuencias evolutivas
    • El final del módulo MSA vuelve a actualizar la representación pair con triangle update y triangle attention.

Pairformer y operaciones triangle

  • Después de actualizar z con Template y MSA, ya no se vuelven a usar template ni MSA, y solo s y z se ingresan a Pairformer
  • Pairformer genera los s_trunk y z_trunk finales mediante la repetición de 48 bloques
  • Intuición de las operaciones triangle

    • triangle update y triangle attention son estructuras diseñadas para reflejar en el modelo la intuición de la desigualdad triangular
    • Aunque z_i,j del pair tensor no es en sí la distancia física, sí contiene la relación entre los tokens i y j, por lo que se actualiza para que las tres relaciones i-j, j-k e i-k sean consistentes entre sí
    • La desigualdad triangular no se impone directamente dentro del modelo, sino que se induce actualizando z_i,j al considerar todos los tripletes (i,j,k)
    • z puede verse como una directed adjacency matrix, por lo que las direcciones de outgoing edge e incoming edge se procesan por separado
  • Triangle Updates

    • En el outgoing update, cada z_i,j se actualiza usando otro elemento de la misma fila, z_i,k, y la tercera arista z_j,k
    • A nivel de implementación, se crean tres proyecciones de z: a, b y g; luego se realiza la multiplicación element-wise entre la fila i y la fila j, se suma sobre k y después se aplica la gate g
    • El incoming update cambia filas por columnas: z_i,j se actualiza a partir de otros elementos de la misma columna, z_k,j y z_k,i
  • Triangle Attention

    • triangle attention es una forma que añade el principio triangle a axial attention, que aplica attention independiente a filas y columnas de una matriz 2D
    • En el caso de “starting node”, a la comparación query-key entre z_i,j y z_i,k se le suma z_j,k como bias
    • En el caso de “ending node”, opera con base en columnas, y el attention score entre z_i,j y z_k,i recibe z_k,j como bias
  • Single Attention with Pair Bias

    • Después del paso triangle y del transition block, la representación single s se actualiza mediante single attention with pair bias usando la representación pair actualizada z
    • Como opera a nivel token, usa full attention en lugar de la block-wise sparse attention que se usaba a nivel átomo

Predicción de estructura: denoising de coordenadas atómicas por difusión

  • Mecanismo básico del modelo de difusión

    • AF3 realiza la predicción final de estructura mediante atom-level diffusion
    • Un diffusion model se entrena agregando random noise a los datos reales de forma gradual y haciendo que el modelo prediga qué noise se añadió
    • En inference, se parte de random noise puro y, en cada step, se elimina el noise predicho por el modelo para generar un datapoint denoised
    • La difusión condicional toma como entrada la generación ruidosa actual, la representación del timestep actual y el vector de condición, para generar un resultado acorde con esa condición
    • En AF3, el objetivo del denoising es la matriz x que contiene las coordenadas x,y,z de todos los átomos
  • Augmentación de rotación y traslación en lugar de IPA de AF2

    • AF3 no usa Invariant Point Attention de AF2, sino que en cada timestep rota y traslada aleatoriamente todo el complejo que está prediciendo
    • Esta augmentación hace que el modelo aprenda que cualquier rotación y traslación sigue representando la misma estructura, y es un enfoque más simple que la IPA de AF2
    • La rotación se aplica alrededor del promedio de las coordenadas de todos los átomos de la generación actual, y la translation se samplea como una Gaussian N(0,1) en cada dimensión
    • También se añade un pequeño noise a las coordenadas para inducir generaciones más diversas
    • En inference, varias generaciones pueden puntuarse con el confidence head y se puede devolver la que tenga la puntuación más alta
  • Cuatro etapas del Diffusion Module

    • Cada step de denoising usa varias representaciones de conditioning
      • salidas del trunk: s_trunk, z_trunk
      • representaciones iniciales creadas por el input embedder: s_inputs, c_inputs
    • El proceso de difusión se compone de cuatro etapas y va y viene entre los espacios de token y átomo
        1. preparación del tensor de conditioning a nivel token
        1. preparación del tensor de conditioning a nivel átomo, aplicación de Atom Transformer y agregación al nivel token
        1. aplicación de attention a nivel token
        1. predicción del noise update por átomo mediante attention a nivel átomo
    • En el conditioning a nivel token, se combinan z_trunk y el relative positional encoding, y luego se pasa por un transition block
    • En la representación single, se combinan s_inputs y s_trunk, y se añade un Fourier embedding según el diffusion timestep
    • En la etapa a nivel átomo, las representaciones iniciales c y p se actualizan con la representación actual a nivel token, y las coordenadas actuales x se escalan con la data variance para crear la coordenada adimensional r
    • En la última etapa a nivel átomo, una linear layer mapea q a R^3 para generar el coordinate update r_update de todos los átomos
    • El update se reescala a x_update considerando la data variance y el noise schedule, y luego se aplica a las coordenadas actuales x_l

Función de pérdida y confidence head

  • La loss total es una suma ponderada de tres términos

L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence

  • L_distogram

    • L_distogram evalúa la precisión del distograma predicho a nivel de token
    • Al crear coordenadas de token a partir de coordenadas atómicas, se usan las coordenadas del átomo central de cada token
    • La distancia del distograma se trata como un valor categórico, y el distograma predicho se compara con el real mediante entropía cruzada
  • L_diffusion

    • L_diffusion es una suma ponderada de varios términos sobre la posición de los átomos
    • L_MSE calcula el error cuadrático medio entre posiciones para todos los átomos, no solo el átomo central, y los átomos de DNA, RNA y ligandos reciben mayor peso
    • L_bond es un término MSE adicional para mejorar la precisión de la longitud de enlace en pares de átomos incluidos en enlaces proteína-ligando
    • En la etapa inicial de entrenamiento, α_bond=0, así que se introduce más adelante
    • L_smooth_LDDT es una loss que vuelve la precisión local de distancia más suave y diferenciable
      • Se usan cuatro umbrales: 4Å, 2Å, 1Å y 0.5Å
      • Los pares de átomos de nucleótidos se ignoran si están a más de 30Å
      • Los pares de átomos de proteína o ligando se ignoran si están a más de 15Å
  • L_confidence

    • L_confidence no busca aumentar directamente la precisión estructural, sino entrenar al modelo para que estime la precisión de sus propias predicciones
    • Está compuesto por losses correspondientes a cuatro métricas de confianza
      • pLDDT: precisión local de distancia para átomos cercanos
      • PAE: error de alineación predicho para pares de tokens
      • PDE: error de distancia predicho entre pares de tokens
      • experimentally resolved prediction: predicción de si cada átomo fue resuelto en la estructura experimental
    • Aunque la estructura predicha sea imprecisa y el PAE sea alto, si el modelo también predice un PAE alto, esa loss de PAE puede disminuir
    • La predicción de confianza se genera en una etapa intermedia de la difusión
    • El gradiente de la loss de confianza actualiza solo el head de predicción de confianza y no afecta al resto del modelo

Técnicas adicionales de entrenamiento y optimización

  • Recycling

    • AF3 usa reciclaje de pesos, igual que AF2
    • En lugar de hacer el modelo más profundo, reutiliza los mismos pesos varias veces para mejorar gradualmente la representación
    • La difusión también usa información de timestep durante la inferencia y reutiliza los mismos pesos en cada timestep, por lo que incorpora reciclaje de forma inherente
  • Cross-distillation

    • AF3 usa no solo datos sintéticos de entrenamiento generados por él mismo, sino también datos sintéticos creados por AF2 y AF-Multimer
    • Tras cambiar a generación basada en difusión, surgió el problema de que desaparecía la forma tipo “spaghetti” que en AF2 permitía distinguir visualmente regiones de baja confianza o desordenadas
    • Al incluir en los datos de entrenamiento de AF3 generaciones de AF2 y AF-Multimer, AF3 aprende a producir regiones desplegadas en zonas donde AF2 no estaba seguro
    • En el dataset de destilación se eliminan los ácidos nucleicos y las moléculas pequeñas que AF2 y AF-Multimer no pueden manejar
    • Después de que el modelo anterior genera la estructura predicha y se alinea con la original, las moléculas eliminadas se vuelven a añadir
    • Si las moléculas reinsertadas generan colisiones atómicas, se excluye toda la estructura para evitar que el modelo aprenda a permitir clashes
  • Cropping y etapas de entrenamiento

    • El modelo en sí no tiene un límite explícito para la longitud de la secuencia de entrada, pero varias operaciones crecen como N_tokens^3, aumentando los requisitos de memoria y cómputo
    • Para mayor eficiencia, las proteínas se recortan aleatoriamente
    • Como hay que modelar interacciones entre múltiples cadenas, el recorte debe incluir las cadenas en conjunto
    • Se usan tres métodos de cropping
      • contiguous cropping: selección de una secuencia continua de aminoácidos en cada cadena
      • spatial cropping: selección de aminoácidos según la distancia al átomo de referencia
      • spatial interface cropping: selección basada en la distancia a los átomos de la interfaz de unión
    • Un modelo entrenado con random crop de 384 también puede aplicarse a secuencias más largas, pero para mejorar su capacidad de procesarlas se hace fine-tuning repetido con longitudes de secuencia mayores
  • Clashing y batch size

    • La loss de AF3 no incluye una penalización por clash para átomos superpuestos
    • En teoría, el módulo estructural basado en difusión podría predecir dos átomos en la misma posición, pero después del entrenamiento ese problema es pequeño
    • Para el ranking de estructuras generadas sí se usa una penalización por clashing
    • El proceso de difusión parece complejo, pero su costo computacional es menor que el del trunk
    • Para mejorar la eficiencia del entrenamiento, el batch size se amplía después del trunk
    • Cada estructura de entrada pasa una vez por el embedding y el trunk, y luego se entrenan en paralelo 48 estructuras independientes aumentadas con datos

Diseño de AF3 desde la perspectiva de ML

  • Una estructura similar a Retrieval-Augmented Generation

    • La búsqueda de MSA y templates en AF3 tiene características similares al RAG de los modelos de lenguaje
    • En el campo de AlphaFold, el uso de templates estructurales ya se utilizaba desde mucho antes que el término RAG bajo el nombre de homology modeling
    • AF3 redujo el peso del procesamiento de MSA frente a AF2, pero sigue incluyendo MSA y templates
    • Algunos modelos de predicción de proteínas como ESMFold eliminan el retrieval y usan fully parametric inference
  • Pair-Bias Attention

    • Pair-Bias Attention, que era un componente principal de AF2, se usa de manera más amplia en AF3
    • query, key y value provienen de la misma source, pero al attention map se le suma un bias term que viene de otra source
    • Esta es una forma de compartir información más ligera que full cross-attention
    • Como la pair representation se parece de forma natural al attention map, esta estructura puede encajar bien en el modelado de proteínas
  • Reducción del self-supervised training

    • Los modelos de la familia ESM mostraron fortalezas al reemplazar el embedding de MSA con self-supervised pre-training
    • AF2 tenía una tarea adicional de predecir masked tokens del MSA, pero en AF3 fue eliminada
    • AF3 redujo el compute dedicado al procesamiento de MSA y no usa self-supervised language modeling pre-training para MSA
    • Una posible razón es que el massive pre-training fuera ineficiente en términos de uso de compute, que un módulo de MSA pequeño funcionara mejor que un pre-trained embedding, o que la combinación entre una estructura híbrida de atom-token mezclando aminoácidos, DNA/RNA y ligands con embeddings preentrenados no encajara bien
  • Mezcla de classification y regression

    • AF3, al igual que AF2, usa tanto MSE como binned classification loss
    • Una característica del classification loss es que, aunque solo se falle por un bin en el distogram, no se recibe crédito de manera distinta que cuando el error es mucho mayor
    • La base de esta decisión de diseño no está clara, pero es posible que el gradient fuera más estable que con varias pérdidas MSE
  • Elementos parecidos a una recurrent architecture

    • En AF3 hay muchos elementos que recuerdan más a una recurrent network que a un transformer convencional
    • El gating controla el flujo de información en el residual stream y es similar a las compuertas de LSTM o GRU
    • El recycling y la diffusion aplican repetidamente el mismo weight para mejorar gradualmente la predicción
    • De forma similar a adaptive compute time, las actualizaciones iterativas se relacionan con una estructura que puede aplicar más procesamiento a entradas difíciles
    • En las ablation de AF2 se mostró la importancia del recycling, pero casi no hubo discusión sobre la importancia del gating

Aún no hay comentarios.

Aún no hay comentarios.