Logo de CursoIA
CursoIA Curso de Inteligencia Artificial

Modelos de lenguaje, Unidad 11 — Límites de los modelos pre-transformer, 11.2 — Por qué surge el Transformer ·

Cambio de paradigma

Cambio de paradigma: Por qué surge el Transformer

Introducción

Los modelos pre-transformer, como los n-gram y los modelos recurrentes (RNNs), han sido la base para muchas aplicaciones de procesamiento del lenguaje natural (NLP) durante décadas. Sin embargo, estos modelos tienen limitaciones inherentes que los hacen inadecuados para tareas más complejas o escalables. Es en este contexto donde el Transformer, introducido por Vaswani et al. en 2017, revoluciona la manera de abordar problemas de NLP, ofreciendo una nueva arquitectura que supera las fallas del paradigma RNN.

Explicación principal con ejemplos

Procesamiento secuencial vs. paralelismo

Una de las limitaciones más notables de los modelos RNN es su dependencia del procesamiento secuencial. Cada paso en una RNN se basa en el estado oculto generado por el paso anterior, lo que resulta en un coste computacional exponencial para largas secuencias.

El Transformer, en contraste, utiliza mecanismos de atención (attention mechanisms) y arquitectura paralela. Esto permite procesar toda la secuencia de entrada simultáneamente sin sacrificar precisión. Este cambio es crucial, especialmente cuando se trabaja con grandes corpora de texto o datos largos.

Ejemplo: Procesamiento paralelo en Transformer

def multi_head_attention(q, k, v):
    """
    Implementación simplificada de la atención multicanal (multi-head attention).
    
    :param q: Tensor de queries, shape = [batch_size, num_heads, seq_length, d_k]
    :param k: Tensor de keys, shape = [batch_size, num_heads, seq_length, d_k]
    :param v: Tensor de values, shape = [batch_size, num_heads, seq_length, d_v]
    
    :return: Salida del mecanismo de atención multicanal.
    """
    # Calcular la atención
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, v)
    
    return output

# Ejemplo de uso
q = torch.randn(batch_size, num_heads, seq_length, d_k)
k = torch.randn(batch_size, num_heads, seq_length, d_k)
v = torch.randn(batch_size, num_heads, seq_length, d_v)

output = multi_head_attention(q, k, v)

Atención al contexto completo

Los modelos RNNs limitan su capacidad para capturar dependencias largas en las secuencias. La memoria limitada y la necesidad de procesar cada elemento de la secuencia afecta directamente a su capacidad para entender el contexto global.

El Transformer resuelve esto mediante mecanismos de atención que permiten a un token considerar todos los demás tokens simultáneamente, no solo aquellos cercanos en la secuencia. Esto es especialmente útil para tareas que requieren una comprensión profunda del contexto, como traducción o generación condicional.

Ejemplo: Atención global en Transformer

def positional_encoding(position, d_model):
    """
    Codificación posicional.
    
    :param position: Posición en la secuencia (0 <= position < max_seq_length)
    :param d_model: Dimensión del modelo
    
    :return: Tensor de codificación posicional con shape = [1, position, d_model]
    """
    angle_rads = torch.arange(position)[:, None] * (1 / d_model)**(2 * (torch.arange(d_model)[::2]) / d_model)
    angle_rads[0::2] = torch.sin(angle_rads[0::2])
    angle_rads[1::2] = torch.cos(angle_rads[1::2])

    return angle_rads

# Ejemplo de uso
max_seq_length, d_model = 512, 512
positional_encodings = positional_encoding(max_seq_length, d_model)

Errores típicos / trampas

Problema del gradiente en RNNs

Un problema conocido en los modelos RNN es el "problema de la desaparición del gradiente" o "explosión del gradiente", donde las señales se debilitan con el paso del tiempo, lo que dificulta el entrenamiento de capas profundas.

Falta de paralelismo en RNNs

Los modelos RNN son intrínsecamente secuenciales y no pueden aprovechar la potencia de cálculo moderna. Esto limita su escalabilidad para tareas con datos largos o complejos.

Problema de "token de padding" en RNNs

En los modelos basados en RNN, el uso del token de padding (como &lt;pad&gt;) puede introducir sesgos y distorsiones en la predicción, especialmente si no se manejan adecuadamente durante el entrenamiento o la inferencia.

Checklist accionable

  1. Entender las limitaciones de los modelos RNN: Familiarízate con la desaparición del gradiente y cómo afecta a la capacidad de estos modelos para aprender dependencias largas.
  2. Implementar atención multicanal en tu modelo: Utiliza mecanismos de atención para permitir que tus modelos consideren el contexto completo sin sacrificar el procesamiento paralelo.
  3. Incorporar codificación posicional: Asegúrate de que tus tokens tengan una representación espacial que refleje su posición en la secuencia.
  4. Manejar adecuadamente los tokens de padding: Establece estrategias claras para manejar estos tokens durante el entrenamiento y la inferencia.
  5. Experimentar con arquitecturas Transformer: Explora cómo puedes adaptar o integrar arquitecturas Transformer en tus soluciones actuales.

Siguientes pasos

  • Explorar modelos RNN avanzados: Aprende sobre variantes de RNN como LSTM y GRU, que intentan abordar algunos de los problemas inherentes a las RNN simples.
  • Estudiar la implementación del Transformer en detalle: Familiarízate con el código y la lógica detrás de la arquitectura Transformer para entender mejor cómo funciona.
  • Probar modelos pre-transformer vs. Transformer en tareas específicas: Implementa ambos tipos de modelos en una tarea que te interese y evalúa sus desempeños.

El cambio a un paradigma basado en la atención y el procesamiento paralelo, representado por arquitecturas como el Transformer, ha abierto nuevas posibilidades para el NLP. Aprovechar estas oportunidades requiere una comprensión profunda no solo de las limitaciones de los modelos RNNs, sino también de cómo se pueden abordar y superar ellas.

Contacto

Indica tu objetivo (ChatGPT, RAG, agentes, automatización) y tu stack (web/backend).