Logo de CursoIA
CursoIA Curso de Inteligencia Artificial

PyTorch desde cero, Unidad 9 — Regularización y control del sobreajuste, 9.2 — Técnicas de regularización ·

Weight decay

Weight decay: una técnica efectiva para controlar el sobreajuste

Introducción

El overfitting es uno de los desafíos más comunes que enfrentan los desarrolladores y científicos de datos en la construcción de modelos de aprendizaje profundo. Es particularmente problemático porque un modelo que sobrepasa su conjunto de entrenamiento a menudo no generaliza bien a datos nuevos o desconocidos. Weight decay, también conocido como regularización L2, es una técnica efectiva para mitigar este problema.

Weight decay funciona aplicando un término adicional en la función de pérdida del modelo. Este término penaliza a las capas ocultas con grandes pesos, lo que ayuda a prevenir el overfitting al reducir la complejidad del modelo y hacerlo más simple.

Explicación principal

La regularización L2 se aplica sumando un termo al coste total de la función de pérdida. Este término penaliza las capas ocultas con grandes pesos, haciendo que estos valores tiendan a ser menores durante el entrenamiento. La fórmula para el costo regularizado es:

\[ \text{Cost} = \text{Cost}_{\text{original}} + \lambda \sum_{i=1}^{n} w_i^2 \]

Donde:

  • $\text{Cost}_{\text{original}}$ es la función de pérdida original.
  • $\lambda$ (o weight_decay en PyTorch) controla el grado de regularización. Cuanto mayor sea $\lambda$, más se penalizarán los pesos grandes.
  • $w_i^2$ son los cuadrados de los pesos.

A continuación, un ejemplo sencillo de cómo implementar la regularización L2 usando PyTorch:

import torch
from torch import nn

# Definir el modelo
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

# Crear un optimizador con weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# Ejemplo de entrenamiento (solo muestra)
X = torch.randn(20, 10)  # Datos de entrada
y = torch.randn(20, 1)   # Etiquetas

for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = nn.MSELoss()(outputs, y)
    loss.backward()
    optimizer.step()

Errores típicos / trampas

A pesar de su efectividad, la regularización L2 puede llevar a varios errores comunes durante el uso:

  1. Selección inadecuada del valor de $\lambda$: Un valor muy pequeño no tendrá ningún efecto en la regularización, mientras que un valor muy grande puede hacer que los pesos se converjan hacia cero rápidamente.
  • Solución: Probar varios valores de $\lambda$ y evaluar el rendimiento en datos de validación.
  1. Influencia incorrecta en las capas ocultas: Si no se aplica regularización a todas las capas, pueden surgir imprevistos en la complejidad del modelo.
  • Solución: Aplicar regularización L2 a todas las capas que contengan pesos.
  1. Interacción con otros métodos de regularización: Si se combinan múltiples técnicas de regularización, es posible que surjan interacciones no deseadas entre ellas.
  • Solución: Verificar el comportamiento del modelo y ajustar cuidadosamente las configuraciones.

Checklist accionable

Aquí tienes una lista de puntos a considerar al implementar la regularización L2:

  1. Determina $\lambda$ adecuadamente: Prueba diferentes valores para encontrar uno que funcione bien.
  1. Aplica a todas las capas ocultas: Asegúrate de que todos los pesos están siendo penalizados.
  1. Monitorea el rendimiento en datos de validación: Verifica que la regularización no esté perjudicando la capacidad del modelo para generalizar.
  1. Ajusta otras configuraciones: Considera ajustar otros hiperparámetros como la tasa de aprendizaje y el número de épocas.
  1. Evalúa interacciones con otros métodos: Si se usan técnicas adicionales de regularización, evalúa su interacción con L2.

Siguientes pasos

  • Explorar más profundamente en PyTorch: Consulta la documentación oficial para entender mejor cómo ajustar los parámetros y aplicar regularización adecuadamente.
  • Prueba en diferentes conjuntos de datos: Verifica cómo se comporta la regularización L2 con varios conjuntos de datos y problemas.
  • Aprende sobre otras técnicas de regularización: Explora métodos como dropout y early stopping, que pueden complementar la regularización L2.

Implementar correctamente la regularización L2 puede mejorar significativamente el rendimiento de tus modelos. Con un enfoque cuidadoso y una comprensión sólida de cómo funciona esta técnica, podrás construir modelos más robustos y generalizables.

Contacto

Indica tu objetivo (ChatGPT, RAG, agentes, automatización) y tu stack (web/backend).