import gc
from collections import Counter
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.utilities import grad_norm
from torchmetrics import (
Accuracy,
AUROC,
CohenKappa,
F1Score,
MatthewsCorrCoef,
Precision,
Recall,
)
from sdofmv2.core import BaseModule
from .focal_loss import focal_loss_multiclass
from .head_networks import ClsLinear, SimpleLinear, SkipLinearHead, TransformerHead
[docs]
class SWClassifier(BaseModule):
"""Solar Wind Classifier using a backbone encoder and configurable head.
This module wraps a pretrained backbone encoder and adds a classification head
for solar wind prediction tasks. It supports multiple head types (linear,
transformer, skip_linear), handles coordinate embeddings, and tracks various
classification metrics during training, validation, and testing.
Args:
channels (list[str]): List of data channels to use.
num_classes (int): Number of output classes.
class_names (list[str]): Names of the classes for logging.
max_position_element (int): Maximum power for positional encoding.
backbone: The pretrained backbone model.
freeze_encoder (bool): Whether to freeze the backbone encoder.
plt_style: Plotting style configuration.
head_type (str): Type of classification head ("linear", "transformer", "skip_linear").
hidden_dim (int): Hidden dimension for the head network.
position_size (int): Number of position coordinates.
p_drop (float): Dropout probability.
nhead (int): Number of attention heads for transformer head.
embed_dim (int): Embedding dimension.
skips (list[int]): Layer indices for skip connections.
include_raw_coordinates (bool): Whether to include raw coordinates.
num_hidden_layers (int): Number of hidden layers for skip_linear head.
radial_mean (float): Mean for radial normalization.
radial_std (float): Standard deviation for radial normalization.
loss_dict (dict): Loss function configuration.
optimizer_dict (dict): Optimizer configuration.
scheduler_dict (dict): Scheduler configuration.
"""
def __init__(
self,
# for finetuning
channels=None,
num_classes=None,
class_names=None,
max_position_element=None,
backbone=None,
freeze_encoder=True,
plt_style=None,
head_type="linear",
hidden_dim=64,
position_size=4,
p_drop=0.1,
nhead=8,
embed_dim=512,
skips=[4],
include_raw_coordinates=False,
num_hidden_layers=8,
radial_mean=None,
radial_std=None,
# loss, optimizer, and scheduler
loss_dict=None,
optimizer_dict=None,
scheduler_dict=None,
# all else
*args,
**kwargs,
):
super().__init__(
optimizer_dict=optimizer_dict,
scheduler_dict=scheduler_dict,
*args,
**kwargs,
)
self.freeze_encoder = freeze_encoder
self.backbone = backbone
self.mask_ratio = self.backbone.masking_ratio
self.input_feature_dim = int(
self.backbone.hparams.embed_dim * (550) * self.mask_ratio
)
self.plt_style = plt_style
self.num_classes = num_classes
self.class_names = class_names
self.channels = sorted(channels)
self.id_193 = self.channels.index("193A") if "193A" in self.channels else None
self.max_position_element = max_position_element
self.head_type = head_type
self.hidden_dim = hidden_dim
self.skips = skips
self.include_raw_coordinates = include_raw_coordinates
self.num_hidden_layers = num_hidden_layers
self.p_drop = p_drop
self.nhead = nhead
self.embed_dim = embed_dim
self.position_size = position_size
self.radial_mean = radial_mean
self.radial_std = radial_std
self.position_size = position_size
self.attn_maps = []
self.loss_dict = loss_dict
# evaluation metrics for training
self.train_acc = Accuracy(
task="multiclass", num_classes=self.num_classes, average="macro"
)
self.train_f1 = F1Score(
task="multiclass", num_classes=self.num_classes, average="macro"
)
self.val_f1 = F1Score(
task="multiclass", num_classes=self.num_classes, average="macro"
)
self.test_f1 = F1Score(
task="multiclass", num_classes=self.num_classes, average="macro"
)
# evaluation metrics for validation
self.val_precision = Precision(
task="multiclass", num_classes=self.num_classes, average="macro"
)
self.val_recall = Recall(
task="multiclass", num_classes=self.num_classes, average="macro"
)
self.val_acc = Accuracy(
task="multiclass", num_classes=self.num_classes, average="micro"
) # Single value
self.val_auroc = AUROC(
task="multiclass", num_classes=self.num_classes, average="macro"
) # Single value
self.val_mcc = MatthewsCorrCoef(
task="multiclass", num_classes=self.num_classes
) # Single value
self.val_kappa = CohenKappa(
task="multiclass", num_classes=self.num_classes
) # Single value
# evaluation metrics for validation
self.test_precision = Precision(
task="multiclass", num_classes=self.num_classes, average="macro"
)
self.test_recall = Recall(
task="multiclass", num_classes=self.num_classes, average="macro"
)
self.test_acc = Accuracy(
task="multiclass", num_classes=self.num_classes, average="macro"
) # Single value
self.test_auroc = AUROC(
task="multiclass", num_classes=self.num_classes, average="macro"
) # Single value
self.test_mcc = MatthewsCorrCoef(
task="multiclass", num_classes=self.num_classes
) # Single value
self.test_kappa = CohenKappa(task="multiclass", num_classes=self.num_classes)
# Storage for validation data
self.val_all_imgs = []
self.val_all_timestamps = []
self.val_all_targets = []
self.val_all_preds = []
self.val_all_positions = []
self.val_barplot_targets = []
self.val_barplot_preds = []
# freeze or unfreeze backbone
if freeze_encoder:
for param in self.backbone.parameters():
param.requires_grad = False
# Set the whole backbone to eval mode (dropout, batchnorm, etc.)
self.backbone.eval()
torch.set_grad_enabled(False)
else:
# backbone trainable
for param in self.backbone.parameters():
param.requires_grad = True
self.backbone.train()
# Define head network
match self.head_type:
case "linear":
self.head = SimpleLinear(
# virtual eve
d_output=self.num_classes, # number of classes
input_feature_dim=self.input_feature_dim,
max_position_element=self.max_position_element,
hidden_dim=self.hidden_dim,
position_size=self.position_size,
p_drop=self.p_drop,
)
case "transformer":
self.head = TransformerHead(
d_output=self.num_classes,
input_token_dim=self.embed_dim,
p_drop=self.p_drop,
max_position_element=self.max_position_element,
nhead=self.nhead,
)
case "skip_linear":
self.head = SkipLinearHead(
# virtual eve
d_output=self.num_classes, # number of classes
input_feature_dim=self.input_feature_dim,
max_position_element=self.max_position_element,
position_size=self.position_size,
hidden_dim=self.hidden_dim,
skips=self.skips,
include_raw_coordinates=self.include_raw_coordinates,
num_hidden_layers=self.num_hidden_layers,
)
# TODO: add cls_linear model
case _:
raise ValueError("Invalid head type!")
# define loss
loss_args = self.loss_dict.get(self.loss_dict.use, "focal")
match self.loss_dict.use:
case "cross_entropy":
self.loss_fn = lambda inputs, targets: F.cross_entropy(
inputs,
targets,
weight=loss_args.class_weights,
reduction=loss_args.reduction,
)
case "focal":
self.loss_fn = lambda inputs, targets: focal_loss_multiclass(
inputs,
targets,
alpha=loss_args.alpha,
gamma=loss_args.gamma,
reduction=loss_args.reduction,
)
[docs]
def forward(self, x, position, r_distance):
"""Perform a forward pass through the classifier.
Args:
x (torch.Tensor): Input images of shape (B, C, H, W).
position (torch.Tensor): Position coordinates of shape (B, position_size).
r_distance (torch.Tensor): Radial distance values of shape (B,).
Returns:
torch.Tensor: Class logits of shape (B, num_classes).
"""
latent, mask, ids_restore = self.backbone.forward_encoder(x, self.mask_ratio)
# head layer
y_hat = self.head(latent, position, r_distance)
return y_hat
[docs]
def forward_analysis(self, x):
"""Perform analysis forward pass for reconstruction visualization.
Args:
x (torch.Tensor): Input images of shape (B, C, H, W).
Returns:
torch.Tensor: Reconstructed images.
"""
loss, x_hat, mask = self.backbone.autoencoder(x, self.mask_ratio)
x_hat = self.backbone.autoencoder.unpatchify(x_hat)
return x_hat
[docs]
def predict_step(self, batch, batch_idx):
"""Perform a prediction step for inference.
Args:
batch: A tuple containing (images, timestamps, position, r_distance, target).
batch_idx: The index of the current batch.
Returns:
dict: A dictionary containing predictions, targets, embeddings, etc.
"""
imgs, timestamps, position, r_distance, target = batch
with torch.no_grad():
y_hat = self(imgs, position, r_distance)
latent, mask, ids_restore = self.backbone.forward_encoder(
imgs, self.mask_ratio
)
loss, x_hat, mask = self.backbone.autoencoder(imgs, self.mask_ratio)
x_hat = self.backbone.autoencoder.unpatchify(x_hat)
preds = torch.argmax(y_hat, dim=1)
return {
"timestamps": timestamps,
"imgs": imgs,
"x_hat": x_hat,
"predictions": preds,
"targets": target,
"logits": y_hat,
"probabilities": torch.softmax(y_hat, dim=1),
"embeddings": latent,
}
[docs]
def training_step(self, batch, batch_idx):
"""Perform a single training step.
Args:
batch: A tuple containing (images, timestamps, position, r_distance, target).
batch_idx: The index of the current batch.
Returns:
torch.Tensor: The training loss value.
"""
imgs, timestamps, position, r_distance, target = batch # [batch, c, 512, 512]
y_hat = self(imgs, position, r_distance)
# Calculate loss
loss = self.loss_fn(y_hat, target)
# Update metrics
preds = torch.argmax(y_hat, dim=1)
self.train_f1.update(preds, target)
self.train_acc.update(preds, target)
self.log(
"train_loss",
loss,
on_step=True, # Log every step
on_epoch=True, # Log at end of epoch
prog_bar=True, # Show in progress bar
logger=True,
sync_dist=True,
)
# Log current learning rate from optimizer
lr = self.trainer.optimizers[0].param_groups[0]["lr"]
self.log("lr", lr, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
self.log(
"train_f1",
self.train_f1,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)
self.log(
"train_acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=False,
logger=True,
sync_dist=True,
)
return loss
[docs]
def validation_step(self, batch, batch_idx):
"""Perform a single validation step.
Args:
batch: A tuple containing (images, timestamps, position, r_distance, target).
batch_idx: The index of the current batch.
Returns:
torch.Tensor: The validation loss value.
"""
imgs, timestamps, position, r_distance, target = batch
y_hat = self(imgs, position, r_distance)
# calculate loss
val_loss = self.loss_fn(y_hat, target)
# Update metrics
preds = torch.argmax(y_hat, dim=1)
self.val_f1.update(preds, target)
self.val_precision.update(preds, target)
self.val_recall.update(preds, target)
self.val_acc.update(preds, target)
self.val_auroc.update(y_hat, target)
self.val_mcc.update(preds, target)
self.val_kappa.update(preds, target)
self.log(
"val_loss",
val_loss,
on_step=False, # Log every step
on_epoch=True, # Log at end of epoch
prog_bar=True, # Show in progress bar
logger=True,
sync_dist=True,
)
# Store data for epoch-end analysis
if batch_idx < 5: # Only store first few batches
# Store data for epoch-end analysis
self.val_all_imgs.append(imgs.cpu())
self.val_all_targets.append(target.cpu())
self.val_all_preds.append(preds.cpu())
self.val_all_timestamps.append(timestamps.cpu())
self.val_all_positions.append(position.cpu())
self.val_barplot_targets.append(target.cpu())
self.val_barplot_preds.append(preds.cpu())
return val_loss
[docs]
def test_step(self, batch, batch_idx):
"""Perform a single test step.
Args:
batch: A tuple containing (images, timestamps, position, r_distance, target).
batch_idx: The index of the current batch.
Returns:
dict: A dictionary containing predictions, targets, logits, and test loss.
"""
imgs, timestamps, position, r_distance, target = batch # [batch, c, 512, 512]
y_hat = self(imgs, position, r_distance)
test_loss = self.loss_fn(y_hat, target)
preds = torch.argmax(y_hat, dim=1)
# Update the F1 metric with predictions and targets
self.test_f1.update(preds, target)
self.test_precision.update(preds, target)
self.test_recall.update(preds, target)
self.test_acc.update(preds, target)
self.test_auroc.update(y_hat, target)
self.test_mcc.update(preds, target)
self.test_kappa.update(preds, target)
self.log(
"test_loss",
test_loss,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)
self.log(
"test_f1",
self.test_f1,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)
return {
"predictions": preds,
"targets": target,
"logits": y_hat,
"probabilities": torch.softmax(y_hat, dim=1),
"attn_maps": self.attn_maps,
"test_loss": test_loss,
}
[docs]
def on_validation_epoch_end(self):
"""Called at the end of the validation epoch.
Computes and logs all validation metrics, generates WandB plots
(confusion matrix, class distributions), and clears stored buffers.
"""
# Compute and log all accumulated metrics
self.log(
"val_f1",
self.val_f1.compute(),
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
)
self.log(
"val_precision",
self.val_precision.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"val_recall",
self.val_recall.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"val_acc",
self.val_acc.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"val_auroc",
self.val_auroc.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"val_mcc",
self.val_mcc.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"val_kappa",
self.val_kappa.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
# WandB logging (only at epoch end to avoid slowdown)
if len(self.val_all_preds) > 0 and len(self.val_all_targets) > 0:
# Concatenate all stored data
# all_imgs = torch.cat(self.val_all_imgs, dim=0)
# all_timestamps = torch.cat(self.val_all_timestamps, dim=0)
# all_targets = torch.cat(self.val_all_targets, dim=0)
# all_preds = torch.cat(self.val_all_preds, dim=0)
# all_positions = torch.cat(self.val_all_positions, dim=0)
all_barplot_targets = torch.cat(self.val_barplot_targets, dim=0)
all_barplot_preds = torch.cat(self.val_barplot_preds, dim=0)
targets_list = all_barplot_targets.detach().cpu().tolist()
preds_list = all_barplot_preds.detach().cpu().tolist()
# confusion matrix
self.logger.experiment.log(
{
f"conf_mat_epoch_{self.current_epoch}": wandb.plot.confusion_matrix(
probs=None,
y_true=targets_list,
preds=preds_list,
class_names=self.class_names,
title=f"confusion_matrix_epoch_{self.current_epoch}",
)
}
)
# Count occurrences
true_counts = Counter(targets_list)
pred_counts = Counter(preds_list)
# All unique classes
all_classes = sorted(set(true_counts.keys()) | set(pred_counts.keys()))
# Create table with separate columns for true/predicted
data = []
for cls in all_classes:
data.append(
[str(cls), true_counts.get(cls, 0), pred_counts.get(cls, 0)]
)
table = wandb.Table(
data=data, columns=["class", "true_count", "pred_count"]
)
# You can only plot one value column at a time, so make two plots
self.logger.experiment.log(
{
f"true_class_distribution_epoch_{self.current_epoch}": wandb.plot.bar(
table,
"class",
"true_count",
title=f"True Class Distribution_epoch_{self.current_epoch}",
),
f"predicted_class_distribution_epoch_{self.current_epoch}": wandb.plot.bar(
table,
"class",
"pred_count",
title=f"Predicted Class Distribution_epoch_{self.current_epoch}",
),
}
)
# Clear stored data to free memory
self.val_all_imgs.clear()
self.val_all_timestamps.clear()
self.val_all_targets.clear()
self.val_all_preds.clear()
self.val_all_positions.clear()
self.val_barplot_targets.clear()
self.val_barplot_preds.clear()
for metric in [
self.val_f1,
self.val_precision,
self.val_recall,
self.val_acc,
self.val_auroc,
self.val_mcc,
self.val_kappa,
]:
metric.reset()
gc.collect()
torch.cuda.empty_cache()
[docs]
def on_train_epoch_end(self):
"""Called at the end of the training epoch.
Performs garbage collection and clears CUDA cache to free memory.
"""
gc.collect()
torch.cuda.empty_cache()
[docs]
def on_test_epoch_end(self):
"""Called at the end of the test epoch.
Computes and logs all test metrics including F1, precision, recall,
accuracy, AUROC, MCC, and Cohen's Kappa.
"""
# Log all test metrics
self.log(
"test_f1",
self.test_f1.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"test_precision",
self.test_precision.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"test_recall",
self.test_recall.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"test_acc",
self.test_acc.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"test_auroc",
self.test_auroc.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"test_mcc",
self.test_mcc.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
self.log(
"test_kappa",
self.test_kappa.compute(),
on_epoch=True,
logger=True,
sync_dist=True,
)
[docs]
def on_before_optimizer_step(self, optimizer):
"""Called before each optimizer step to log gradient norms.
Args:
optimizer: The optimizer about to perform an update step.
"""
# Compute the 2-norm for each layer
# If using mixed precision, the gradients are already unscaled here
norms = grad_norm(self.head, norm_type=2)
self.log_dict(norms)