Logo de CursoIA
CursoIA Curso de Inteligencia Artificial

Redes neuronales desde cero, Unidad 9 — Overfitting en redes neuronales, 9.1 — Alta capacidad y riesgos ·

Señales de sobreajuste

Señales de sobreajuste

Introducción

El sobreajuste, también conocido como overfitting, es uno de los problemas más comunes y perjudiciales que pueden surgir durante el entrenamiento de modelos de machine learning. Cuando un modelo se adapta demasiado a su conjunto de datos de entrenamiento, capturando incluso el ruido o la aleatoriedad en este, no sólo pierde su capacidad para generalizar a nuevos datos, sino que también puede tener un rendimiento inferior cuando se aplica al mundo real. En el caso de las redes neuronales (NN), esto significa que, si bien pueden alcanzar una precisión impresionante en los datos de entrenamiento, podrían fallar miserablemente en predicciones nuevas o no vistos.

Explicación principal con ejemplos

El sobreajuste ocurre cuando un modelo es demasiado complejo y ajusta a la perfección al ruido en lugar de a las verdaderas características del conjunto de datos. Esto puede llevar a una reducción significativa en el rendimiento del modelo en nuevos datos, ya que no se ha aprendido a generalizar.

Ejemplo con código

Vamos a ilustrar esto con un ejemplo simple usando una red neuronal para clasificar imágenes de dígitos escritos a mano (MNIST).

import tensorflow as tf
from tensorflow.keras import datasets, layers, models

# Cargar los datos MNIST
(train_images, train_labels), (_, _) = datasets.mnist.load_data()

# Preprocesar los datos
train_images = train_images / 255.0

# Crear una red neuronal con muchas capas ocultas para overfitting
model = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(1024, activation='relu'),
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(10)
])

# Compilar el modelo
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

# Entrenar el modelo (código cortado para brevedad)

Cuando se entrena este modelo con los datos de entrenamiento, puede alcanzar una precisión del 98% en los datos de entrenamiento. Sin embargo, cuando se evalúa sobre un conjunto de validación no visto, la precisión caerá significativamente.

Errores típicos / trampas

  1. Capas ocultas excesivamente profundas o anchas: Una arquitectura con demasiadas capas y neuronas puede capturar ruido en lugar de patrones útiles.
  2. Entrenamiento por demasiado tiempo: A menudo, los modelos overfitting a menudo se entrenan durante más épocas del necesario, adaptándose al ruido en los datos de entrenamiento.
  3. Parámetros inadecuados: La elección incorrecta de parámetros como el tipo de función de activación o la tasa de aprendizaje puede llevar a overfitting.

Checklist accionable

A continuación, se presentan algunas medidas que puedes tomar para prevenir el sobreajuste en tus modelos de red neuronal:

  1. Usar regularización: Aplica regularización L1 y/o L2 para penalizar la complejidad del modelo.
  2. Dropout: Introduce dropout durante el entrenamiento para evitar que las neuronas se vuelvan demasiado dependientes entre sí.
  3. Validación cruzada: Divide tu conjunto de datos en partes y valida cada parte con diferentes conjuntos de validación.
  4. Regularización en tiempo de inferencia (Early Stopping): Detén el entrenamiento cuando la precisión del conjunto de validación deja de mejorar.
  5. Usar un conjunto de validación adecuado: Asegúrate de que tu conjunto de validación sea representativo de los datos reales a los que se aplicará el modelo.

Cierre con "Siguientes pasos"

  • Explorar más técnicas de regularización: Incluye L1, L2 y otras variantes.
  • Prueba diferentes arquitecturas de red neuronal: Experimenta con menos capas o neuronas para ver si mejoran los resultados.
  • Analiza el rendimiento en una variedad de conjuntos de datos: Verifica cómo se comporta tu modelo en diferentes contextos.

El sobreajuste es un desafío significativo en la aplicación práctica del machine learning, pero con el uso de las técnicas adecuadas y un entendimiento profundo de las características de los modelos, puedes evitarlo y mejorar significativamente su rendimiento.

Contacto

Indica tu objetivo (ChatGPT, RAG, agentes, automatización) y tu stack (web/backend).