Source code for sdofmv2.tasks.solar_wind.focal_loss

from typing import Union

import torch
import torch.nn.functional as F


def focal_loss_multiclass(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: Union[float, torch.Tensor] = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Multi-class focal loss based on torchvision.ops.sigmoid_focal_loss.

    Unlike the torchvision (binary, sigmoid-based) version, this uses softmax
    cross-entropy, so ``p_t = exp(-CE)`` is the softmax probability assigned
    to the target class, and the loss is::

        FL = alpha_t * (1 - p_t) ** gamma * CE

    Args:
        inputs (Tensor[N, C]): Logits for each class.
        targets (Tensor[N]): Class indices (0 ≤ targets < C).
        alpha (float or Tensor[C]): Balance factor(s). A Python scalar or a
            0-d tensor weights every sample uniformly; a 1-D tensor of
            length C gives a per-class weight looked up by each sample's
            target. A tensor alpha is moved to the loss's device and dtype.
        gamma (float): Modulating factor exponent. ``gamma=0`` reduces the
            loss to alpha-weighted cross-entropy.
        reduction (str): 'none', 'mean', or 'sum'.

    Returns:
        Tensor: Loss per sample for 'none', or the reduced (scalar) loss for
        'mean' / 'sum'.

    Raises:
        ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
    """
    # Validate up front so a bad argument fails before any computation.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid reduction mode: {reduction}. Choose 'none', 'mean', or 'sum'."
        )

    ce_loss = F.cross_entropy(inputs, targets, reduction="none")
    # CE = -log(p_target), so exp(-CE) recovers the target-class probability.
    p_t = torch.exp(-ce_loss)

    if isinstance(alpha, torch.Tensor):
        # Align with the loss so indexing/multiplication work when alpha was
        # built on a different device (e.g. CPU alpha, CUDA model).
        alpha = alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)
        # A 0-d tensor has no class axis to index; treat it as a uniform weight.
        alpha_t = alpha if alpha.ndim == 0 else alpha[targets]
    else:
        # Scalar alpha is a uniform weight. Wrapping it in a 0-d tensor and
        # indexing that by `targets` would raise an IndexError.
        alpha_t = float(alpha)

    loss = alpha_t * (1 - p_t) ** gamma * ce_loss

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss