Source code for sdofmv2.tasks.f107.f107_module

import torch
import torch.nn as nn
from sdofmv2.core import BaseModule


class MultiLayerPerceptron(BaseModule):
    """Multi-layer perceptron head for processing backbone features.

    This head sits on top of a (typically pre-trained) backbone. It runs the
    backbone's autoencoder encoder, aggregates the patch tokens with both mean
    and max pooling, normalizes the concatenated features and feeds them
    through a stack of fully connected layers. The module is trained with an
    MSE loss.

    Args:
        backbone (nn.Module): Feature extraction model exposing
            ``autoencoder.forward_encoder(x, mask_ratio=...)``.
        freeze (bool): Whether to freeze the backbone parameters (and keep the
            backbone in eval mode) so that only the head is trained.
        input_dim (int): Dimensionality of the backbone's latent features. The
            internal MLP input dimension is twice this value because the mean
            and max pooled features are concatenated.
        output_dim (int, optional): Number of output units. Defaults to 1.
        hidden_layer_dims (Sequence[int], optional): Dimensions of the hidden
            MLP layers. ``None`` selects the default ``[512, 512, 512]``; an
            empty sequence yields a purely linear head.
        dropout (float, optional): Dropout probability applied before every
            hidden layer. Defaults to 0.0.
        mask_ratio (float, optional): Value forwarded to the encoder as
            ``mask_ratio``. Defaults to 0.0.
        optimizer_dict (dict, optional): Optimizer configuration, forwarded to
            ``BaseModule``. Defaults to None.
        scheduler_dict (dict, optional): Learning-rate scheduler
            configuration, forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(
        self,
        backbone,
        freeze,
        input_dim,
        output_dim=1,
        hidden_layer_dims=None,
        dropout=0.0,
        mask_ratio=0.0,
        optimizer_dict=None,
        scheduler_dict=None,
    ):
        super().__init__(optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict)

        # Resolve the default here instead of using a mutable list as the
        # default argument value; copy so the caller's sequence is never shared.
        if hidden_layer_dims is None:
            hidden_layer_dims = [512, 512, 512]
        hidden_layer_dims = list(hidden_layer_dims)

        self.backbone = backbone
        self.freeze_backbone = freeze
        self.nans = []

        if self.freeze_backbone:
            self.backbone.eval()
            for param in self.backbone.parameters():
                param.requires_grad = False

        self.mask_ratio = mask_ratio

        # Mean- and max-pooled features are concatenated, hence ``2 * input_dim``.
        pooled_dim = input_dim * 2
        self.norm = nn.LayerNorm(pooled_dim)

        # Layer widths of the MLP, starting with the pooled feature size.
        dims = [pooled_dim] + hidden_layer_dims

        self.dropout = nn.Dropout(p=dropout)

        # One linear layer (and one activation) per consecutive pair of widths.
        self.fcs = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])]
        )
        self.acts = nn.ModuleList([nn.LeakyReLU(0.01) for _ in self.fcs])

        self.fc_out = nn.Linear(dims[-1], output_dim)

        self.criterion = nn.MSELoss()

        # Filled by ``test_step``: timestamp -> [label, prediction].
        self.test_preds = {}

    def train(self, mode=True):
        """Set the training mode while keeping a frozen backbone in eval mode.

        Any ``.train()`` call on this module would otherwise switch the frozen
        backbone back to training mode.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for eval mode.

        Returns:
            MultiLayerPerceptron: ``self``.
        """
        super().train(mode)
        # ``getattr`` guards against calls made before ``__init__`` has set
        # the attribute.
        if getattr(self, "freeze_backbone", False):
            self.backbone.eval()
        return self

    def forward(self, x):
        """Process the input through the backbone encoder and the MLP head.

        Args:
            x (torch.Tensor): Input passed unchanged to the backbone encoder.

        Returns:
            torch.Tensor: Output of the final linear layer, with ``output_dim``
            as its last dimension.
        """
        # Disable autograd for a frozen backbone, but never re-enable it when
        # the caller has already turned it off (e.g. during evaluation).
        track_grad = torch.is_grad_enabled() and not self.freeze_backbone
        with torch.set_grad_enabled(track_grad):
            # Per the original author's note, latent is
            # [Batch, Num_Patches + 1, Hidden_Dim].
            latent, _mask, _ids_restore = self.backbone.autoencoder.forward_encoder(
                x, mask_ratio=self.mask_ratio
            )

        # Drop the first token (presumably the class token -- confirm against
        # the backbone) and pool over the remaining patch tokens.
        patch_tokens = latent[:, 1:, :]
        x_avg = patch_tokens.mean(dim=1)
        x_max = patch_tokens.max(dim=1).values

        features = self.norm(torch.cat([x_avg, x_max], dim=-1))

        for fc, act in zip(self.fcs, self.acts):
            features = act(fc(self.dropout(features)))

        return self.fc_out(features)

    def on_train_start(self):
        """Put a frozen backbone back into eval mode when training starts."""
        if self.freeze_backbone:
            self.backbone.eval()

    def _compute_loss(self, batch):
        """Run the model on a batch and compute the loss.

        Args:
            batch: Tuple ``(imgs, timestamps, y)``.

        Returns:
            tuple: ``(loss, logits, timestamps, y)`` where ``logits`` has its
            trailing dimension squeezed.
        """
        imgs, timestamps, y = batch
        logits = self(imgs).squeeze(-1)
        loss = self.criterion(logits, y)
        return loss, logits, timestamps, y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss, _, _, _ = self._compute_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss, _, _, _ = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and record per-sample predictions.

        Every sample is stored in ``self.test_preds`` as
        ``timestamp -> [label, prediction]``.
        """
        loss, logits, timestamps, y = self._compute_loss(batch)
        self.log("test_loss", loss, prog_bar=True)

        preds_real = logits.detach().cpu().numpy()
        labels_real = y.detach().cpu().numpy()

        # Save results per timestamp.
        for t, label, pred in zip(timestamps, labels_real, preds_real):
            self.test_preds[t.item()] = [label.item(), pred.item()]
        return loss

    def on_before_optimizer_step(self, optimizer):
        """Clip gradients and drop them entirely when they are NaN/Inf.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer about to be stepped.
        """
        # Clips in place to a total norm of 1.0 and returns the pre-clip norm.
        grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)

        if not torch.isfinite(grad_norm):
            print(
                "SKIPPING STEP: Gradients are NaN/Inf! Weights saved from corruption."
            )
            # Under AMP, Lightning has already unscaled the gradients before
            # this hook runs; calling ``scaler.unscale_`` a second time before
            # ``scaler.update()`` raises a RuntimeError, so it is not called.
            # Setting the gradients to ``None`` makes the optimizer skip every
            # parameter, so no weight is updated (zero-filled gradients could
            # still move weights through momentum or weight decay).
            optimizer.zero_grad(set_to_none=True)