Source code for sdofmv2.core.basemodule
from loguru import logger as lgr_logger
import lightning.pytorch as pl
import torch
from transformers import get_cosine_schedule_with_warmup
[docs]
class BaseModule(pl.LightningModule):
"""A foundational PyTorch Lightning module for standardized training.
This base class handles the boilerplate configuration for optimizers
and learning rate schedulers. Other models in the pipeline should inherit
from this class and implement their specific `training_step` and
`validation_step` logic.
Args:
optimizer_dict (dict): Configuration dictionary for the optimizer.
Expected keys include "use" (e.g., "adamw", "sgd", "adam"),
"learning_rate", and "weight_decay".
scheduler_dict (dict): Configuration dictionary for the learning rate
scheduler. Expected keys include "use" (e.g., "cosine", "cosine_warmup",
"plateau", "exp"), "monitor" (metric to track), and any scheduler-specific
hyperparameters.
hyperparam_ignore (list[str], optional): List of parameter names to
exclude from Lightning's automatic hyperparameter saving. Defaults to [].
*args: Variable length argument list passed to `pl.LightningModule`.
**kwargs: Arbitrary keyword arguments passed to `pl.LightningModule`.
"""
def __init__(
self,
optimizer_dict,
scheduler_dict,
hyperparam_ignore=[],
# pass to pl.LightningModule
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.optimizer_dict = optimizer_dict
self.scheduler_dict = scheduler_dict
self.save_hyperparameters(ignore=hyperparam_ignore)
[docs]
def training_step(self, batch, batch_idx):
"""Perform a single training step.
Args:
batch: The training batch data.
batch_idx: The index of the current batch.
Raises:
NotImplementedError: Subclasses must implement this method.
"""
raise NotImplementedError
[docs]
def validation_step(self, batch, batch_idx):
"""Perform a single validation step.
Args:
batch: The validation batch data.
batch_idx: The index of the current batch.
Raises:
NotImplementedError: Subclasses must implement this method.
"""
raise NotImplementedError
[docs]
def configure_optimizers(self):
"""Configure optimizers and learning rate schedulers.
Returns:
Union[torch.optim.Optimizer, Dict]: Either a single optimizer or a dict
containing optimizer and lr_scheduler configuration.
"""
opt_type = self.optimizer_dict.get("use", "adamw")
lr = self.optimizer_dict.get("learning_rate", 1e-4)
weight_decay = self.optimizer_dict.get("weight_decay", 0.01)
lgr_logger.debug(f"Initial/Peak LR: {lr}")
lgr_logger.debug(f"Weight decay: {weight_decay}")
match opt_type:
case "adam":
optimizer = torch.optim.Adam(
self.parameters(),
lr=lr,
weight_decay=weight_decay,
)
case "sgd":
optimizer = torch.optim.SGD(
self.parameters(),
lr=lr,
weight_decay=weight_decay,
)
case "adamw":
optimizer = torch.optim.AdamW(
self.parameters(),
lr=lr,
weight_decay=weight_decay,
)
case _:
raise NameError(f"Unknown optimizer {optimizer}")
sched_use = self.scheduler_dict.get("use", None)
monitor = self.scheduler_dict.get("monitor", "val_loss")
hyper_params = self.scheduler_dict.get(sched_use, {})
# Create scheduler
if sched_use == "cosine":
# Cosine annealing from pytorch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, **hyper_params
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
"frequency": 1,
},
}
elif sched_use == "cosine_warmup":
warmup_ratio = hyper_params.get("warmup_ratio", 0.1)
# SAFEGUARD: Calculate steps
total_steps = self.trainer.estimated_stepping_batches
lgr_logger.debug(f"Scheduler initialized with TOTAL_STEPS = {total_steps}")
# Check for edge cases where Lightning returns infinity or valid steps are unknown
if isinstance(total_steps, (float, int)) and (
total_steps == float("inf") or total_steps == 0
):
lgr_logger.warning(
"Warning: Could not calculate total steps automatically."
)
total_steps = hyper_params.get("total_steps", 3000)
num_warmup_steps = int(total_steps * warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
"frequency": 1,
},
}
elif sched_use == "plateau":
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, **hyper_params
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": monitor, # Metric to monitor
"interval": "epoch",
"frequency": 1,
},
}
elif sched_use == "exp":
scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer, **hyper_params
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
"frequency": 1,
},
}
else:
# No scheduler
return optimizer