# Source code for sdofmv2.core.losses

import torch
import torch.nn.functional as F
from einops import rearrange
from typing import Optional, Literal


def _get_group_mean(loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Compute the mean loss over the elements selected by a mask.

    Args:
        loss: Element-wise loss tensor.
        mask: Boolean mask used to index ``loss``. It is applied with boolean
            indexing, so it must have the same shape as ``loss`` or match its
            leading dimensions.

    Returns:
        Scalar tensor holding the mean of the selected elements. When the
        mask selects nothing, a zero scalar with the dtype and device of
        ``loss`` is returned instead of the NaN that ``mean()`` produces on
        an empty selection.
    """
    if not mask.any():
        # Follow the input's dtype as well as its device so that an empty
        # group does not inject a float32 scalar into a float64/half loss.
        return loss.new_zeros(())
    return loss[mask].mean()


def mae_loss(pred, target) -> torch.Tensor:
    """Return the mean absolute error between predictions and targets.

    Args:
        pred (torch.Tensor): Values predicted by the model.
        target (torch.Tensor): Ground-truth values; subtracted element-wise
            from ``pred``.

    Returns:
        torch.Tensor: Scalar tensor with the mean of ``|pred - target|``.
    """
    return (pred - target).abs().mean()
def vector_aware_loss(pred, target, base_loss, cos_weight: float = 0.1) -> torch.Tensor:
    """Combine a magnitude loss with a directional (cosine) penalty.

    A base loss (MSE or MAE) is computed over all elements, then a cosine
    term ``1 - mean(cosine_similarity)`` taken along ``dim=1`` is added with
    weight ``cos_weight``. The cosine term rewards predicted vectors that
    point in the same direction as the target vectors.

    Args:
        pred (torch.Tensor): Predicted vector field of shape (B, 3, F, H, W).
        target (torch.Tensor): Ground-truth vector field of shape
            (B, 3, F, H, W).
        base_loss (str): Base loss to use, either ``"mse"`` or ``"mae"``.
        cos_weight (float): Multiplier applied to the cosine term. Defaults
            to 0.1, the value that was previously hard-coded.

    Returns:
        torch.Tensor: Scalar tensor with the combined loss.

    Raises:
        ValueError: If ``base_loss`` is not ``"mse"`` or ``"mae"``.
    """
    if base_loss == "mse":
        magnitude_term = ((pred - target) ** 2).mean()
    elif base_loss == "mae":
        magnitude_term = torch.abs(pred - target).mean()
    else:
        raise ValueError(f"Not supported loss type: {base_loss}")

    # Vector components live on dim=1, so similarity is measured per
    # (batch, frame, pixel) location before averaging.
    direction_term = 1 - F.cosine_similarity(pred, target, dim=1).mean()
    return magnitude_term + cos_weight * direction_term
def pixel_weight_loss(
    pred,
    target_norm,
    target,
    base_loss,
    threshold,
    ar_weight_ratio: float,
):
    """Blend the mean losses of strong-field and weak-field pixels.

    The element-wise base loss is computed between ``pred`` and
    ``target_norm``. Pixels are split by whether ``|target|`` exceeds
    ``threshold``; each group is averaged on its own and the two means are
    combined with weights ``r / (r + 1)`` and ``1 / (r + 1)``, where ``r`` is
    ``ar_weight_ratio``.

    Args:
        pred (4d tensor): Output from the model.
        target_norm (4d tensor): Target re-normalized by norm_pix_loss; this
            is what the base loss compares ``pred`` against.
        target (4d tensor): Normalized target; only used to decide which
            pixels count as strong-field.
        base_loss (str): Name of the base loss, forwarded to
            ``_get_base_loss``.
        threshold (float): Absolute-value cutoff above which a pixel is
            treated as having a strong magnetic field.
        ar_weight_ratio (float): Relative weight of the strong-field group
            with respect to the remaining pixels.

    Returns:
        torch.Tensor: Scalar weighted loss.
    """
    elementwise = _get_base_loss(pred, target_norm, base_loss)

    # Strong-field ("ar") pixels versus everything else ("noise").
    strong_field = torch.abs(target) > threshold
    ar_mean = _get_group_mean(elementwise, strong_field)
    noise_mean = _get_group_mean(elementwise, ~strong_field)

    # Normalize so the two group weights sum to one.
    ar_share = ar_weight_ratio / (ar_weight_ratio + 1)
    noise_share = 1 / (ar_weight_ratio + 1)
    return (ar_share * ar_mean) + (noise_share * noise_mean)
def patch_weight_loss(pred, target, loss_dict, mask_hidden, mask_off_limb):
    """Calculate a three-tier weighted reconstruction loss for solar data.

    Patches are separated into three categories (masked inner disk, visible
    inner disk, and off-limb space) and each group's mean loss gets its own
    weight. This prevents the large population of space pixels or masked
    patches from disproportionately biasing the gradients.

    Args:
        pred (torch.Tensor): Predicted patch values [B, L, D].
        target (torch.Tensor): Ground truth (potentially normalized) patches
            [B, L, D].
        loss_dict (dict or object): Config object containing:

            * base_loss (dict): May provide 'type' ('mse', 'mae', or 'huber';
              defaults to 'mse') and 'delta' (for huber; defaults to 1.0).
            * weight_on_patches (list[float]): A three-element list:
              [weight_masked_inner, weight_visible_inner, weight_off_limb].
              Defaults to [0.7, 0.2, 0.1]. The weights are normalized by
              their sum before use.
        mask_hidden (torch.Tensor): Binary/bool mask from encoder [B, L].
            1 (True) indicates a masked/hidden patch.
        mask_off_limb (torch.Tensor): Binary/bool spatial mask [B, L].
            1 (True) indicates a patch outside the solar disk.

    Returns:
        torch.Tensor: Scalar weighted mean loss.

    Raises:
        ValueError: If an unsupported loss type is provided or
            ``mask_off_limb`` is None.
        IndexError: If weight_on_patches does not contain exactly three
            elements.
    """
    base_loss_type = loss_dict.base_loss.get("type", "mse")
    huber_delta = loss_dict.base_loss.get("delta", 1.0)

    # Validate inputs before doing any tensor work. The count must be exactly
    # three: a longer list would otherwise fail later with an unrelated
    # "too many values to unpack" error.
    weights_raw = loss_dict.get("weight_on_patches", [0.7, 0.2, 0.1])
    if len(weights_raw) != 3:
        raise IndexError("weight_on_patches must have 3 elements for the 3-tier strategy.")
    if mask_off_limb is None:
        raise ValueError("off limb mask is None!")

    loss = _get_base_loss(pred, target, base_loss_type, huber_delta)

    hidden = mask_hidden.bool()
    off_limb = mask_off_limb.bool()
    on_disk = ~off_limb

    # Tier 1: hidden patches inside the solar disk.
    mean_masked = _get_group_mean(loss, hidden & on_disk)
    # Tier 2: visible patches inside the solar disk.
    mean_visible = _get_group_mean(loss, ~hidden & on_disk)
    # Tier 3: all patches outside the solar disk (space).
    mean_space = _get_group_mean(loss, off_limb)

    # Normalize so the three weights sum to one.
    w_sum = sum(weights_raw)
    w_m, w_v, w_l = [w / w_sum for w in weights_raw]

    return (w_m * mean_masked) + (w_v * mean_visible) + (w_l * mean_space)
# =============================================================================
# Patch-level loss functions for non-zero vs all-zero patches
# =============================================================================


def _get_base_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    base_type: Literal["mse", "mae", "huber"],
    huber_delta: float = 1.0,
) -> torch.Tensor:
    """Return the unreduced (element-wise) base loss.

    Args:
        pred: Predicted tensor [B, L, D]
        target: Target tensor [B, L, D]
        base_type: Which loss to compute - "mse", "mae", or "huber"
        huber_delta: Delta parameter, only used when ``base_type`` is "huber"

    Returns:
        Element-wise loss tensor of same shape as pred/target

    Raises:
        ValueError: If ``base_type`` is not one of the supported names.
    """
    if base_type == "huber":
        return F.huber_loss(pred, target, reduction="none", delta=huber_delta)
    if base_type == "mse":
        return (pred - target).pow(2)
    if base_type == "mae":
        return (pred - target).abs()
    raise ValueError(f"Not supported base loss type: {base_type}")


def _get_zero_pixel_mask_from_target(imgs, patch_size=16, corner_size=4, corner_ratio=0.25):
    """Build a patchified mask of pixels that fall below a corner-based cutoff.

    The input is averaged over dim 2, then the four ``corner_size`` x
    ``corner_size`` image corners are pooled per (batch, channel) to obtain a
    reference mean. A pixel is flagged when its averaged value is strictly
    below ``corner_ratio`` times that reference.

    Args:
        imgs: 5-D tensor unpacked as (B, C, T, H, W).
        patch_size: Side length of the square patches used for the output
            layout; H and W are split into blocks of this size.
        corner_size: Side length of each corner window.
        corner_ratio: Fraction of the corner mean used as the cutoff.

    Returns:
        Boolean tensor laid out as "b (h w) (p q c)", i.e. one row per patch
        with the patch pixels and channels flattened into the last dim.
    """
    batch, channels, _, _, _ = imgs.shape
    frame_mean = imgs.mean(dim=2)

    near = slice(None, corner_size)
    far = slice(-corner_size, None)
    # Same order as top-left, top-right, bottom-left, bottom-right.
    corner_windows = ((near, near), (near, far), (far, near), (far, far))
    corner_pixels = torch.cat(
        [
            frame_mean[:, :, rows, cols].reshape(batch, channels, -1)
            for rows, cols in corner_windows
        ],
        dim=-1,
    )

    # One cutoff per (batch, channel), expanded to compare against H x W.
    cutoff = (corner_ratio * corner_pixels.mean(dim=-1))[:, :, None, None]
    below_cutoff = frame_mean < cutoff

    return rearrange(
        below_cutoff,
        "b c (h p) (w q) -> b (h w) (p q c)",
        p=patch_size,
        q=patch_size,
    )
[docs] def split_pixel_loss( pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, beta: float = 1.0, base_type: Literal["mse", "mae", "huber"] = "mse", huber_delta: float = 1.0, imgs: torch.Tensor | None = None, patch_size: int = 16, corner_size: int = 4, corner_ratio: float = 0.25, ) -> torch.Tensor: element_loss = _get_base_loss(pred, target, base_type, huber_delta) is_zero_pixel = _get_zero_pixel_mask_from_target( imgs, patch_size=patch_size, corner_size=corner_size, corner_ratio=corner_ratio ) is_nonzero_pixel = ~is_zero_pixel loss_nonzero = _get_group_mean(element_loss, is_nonzero_pixel) loss_zero = _get_group_mean(element_loss, is_zero_pixel) total_loss = alpha * loss_nonzero + beta * loss_zero return total_loss
[docs] def sparse_dense_loss( pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, beta: float = 1.0, base_type: Literal["mse", "mae", "huber"] = "mse", huber_delta: float = 1.0, imgs: torch.Tensor | None = None, patch_size: int = 16, corner_size: int = 4, corner_ratio: float = 0.25, ) -> torch.Tensor: element_loss = _get_base_loss(pred, target, base_type, huber_delta) is_zero_pixel = _get_zero_pixel_mask_from_target( imgs, patch_size=patch_size, corner_size=corner_size, corner_ratio=corner_ratio ) is_nonzero_pixel = ~is_zero_pixel recon_loss = _get_group_mean(element_loss, is_nonzero_pixel) zero_pred = pred[is_zero_pixel] embedding_size = ( (zero_pred**2).sum(dim=-1).mean() / zero_pred.shape[-1] if zero_pred.numel() > 0 else torch.tensor(0.0, device=pred.device) ) total_loss = alpha * recon_loss + beta * embedding_size return total_loss