Source code for sdofmv2.utils.utils

# Assortment of variously useful functions
import collections.abc
import warnings

# Third-party libraries
import numpy as np
import torch
from einops import rearrange

# Astronomy / SunPy libraries
import astropy.units as u
from astropy.coordinates import SkyCoord
import sunpy.data.sample
import sunpy.map
from sunpy.coordinates.frames import HeliographicStonyhurst


# GENERAL
def days_hours_mins_secs_str(total_seconds):
    """Format a duration given in seconds as a compact d/h/m/s string.

    Args:
        total_seconds (int or float): Duration in seconds.

    Returns:
        str: Text of the form ``"<days>d:<HH>:<MM>:<SS>"``. The day count is
        not padded; hours, minutes and seconds are zero-padded to two digits.
        Every field is truncated to an integer, so fractional seconds are
        dropped.
    """
    days, remainder = divmod(total_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    fields = (int(days), int(hours), int(minutes), int(seconds))
    return "{0}d:{1:02}:{2:02}:{3:02}".format(*fields)
def flatten_dict(d, parent_key="", sep="_"):
    """Collapse a nested mapping into a single-level dictionary.

    Keys of nested mappings are joined to their parent's key with ``sep``.
    Any value that is a ``collections.abc.MutableMapping`` is descended into;
    every other value is copied across unchanged.

    Args:
        d (dict): The mapping to flatten.
        parent_key (str, optional): Prefix applied to every key of ``d``.
            An empty prefix (the default) leaves top-level keys as they are.
        sep (str, optional): Text placed between a parent key and a child
            key. Defaults to "_".

    Returns:
        dict: A flat dictionary whose keys are the joined key paths.
    """
    flat = {}
    for key, value in d.items():
        full_key = key if not parent_key else parent_key + sep + key
        if isinstance(value, collections.abc.MutableMapping):
            # Recurse with the accumulated path as the new prefix.
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat
def unflatten_dict(dictionary, sep="_", wandb_mode=True):
    """Rebuild a nested ``AttributeDict`` from a flat, separator-joined dict.

    Each key is split on ``sep``; all parts but the last name nested
    ``AttributeDict`` levels (created on demand) and the last part is the
    leaf under which the value is stored.

    Args:
        dictionary (dict): Flat dictionary whose keys encode a path.
        sep (str, optional): Separator the keys are split on. Defaults to "_".
        wandb_mode (bool, optional): If True, every nested dict holding both
            a ``"desc"`` and a ``"value"`` key is replaced by its ``"value"``
            entry once the tree has been built. Defaults to True.

    Returns:
        AttributeDict: The nested structure.
    """

    def _collapse_value_entries(node):
        # Replace {"desc": ..., "value": ...} dicts by their "value",
        # recursing into every other dict.
        collapsed = AttributeDict()
        for name, entry in node.items():
            if not isinstance(entry, dict):
                collapsed[name] = entry
            elif "desc" in entry.keys() and "value" in entry.keys():
                collapsed[name] = entry["value"]
            else:
                collapsed[name] = _collapse_value_entries(entry)
        return collapsed

    nested = AttributeDict()
    for flat_key, value in dictionary.items():
        *branch, leaf = flat_key.split(sep)
        cursor = nested
        for part in branch:
            if part not in cursor:
                cursor[part] = AttributeDict()
            cursor = cursor[part]
        cursor[leaf] = value

    if wandb_mode:
        return _collapse_value_entries(nested)
    return nested
#### MAE FUNCTIONS
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Generate 1D sine-cosine positional embeddings.

    Args:
        embed_dim (int): The output dimension for each position (must be even).
        pos (array_like): Positions to encode. Any shape is accepted; the
            input is flattened to ``(M,)`` before encoding, so an N-D grid of
            coordinates yields one embedding row per grid element in
            row-major order.

    Returns:
        ndarray: Positional embeddings of shape (M, embed_dim). The first
        half of each row holds the sines, the second half the cosines.

    Raises:
        AssertionError: If ``embed_dim`` is odd.
    """
    assert embed_dim % 2 == 0, "embed_dim must be even"

    # The einsum below only accepts a 1-D operand. Callers such as the 2D
    # helper pass (1, H, W) coordinate grids, so flatten here; a 1-D input
    # is unchanged by this.
    pos = np.asarray(pos).reshape(-1)  # (M,)

    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build 2D sine-cosine positional embeddings for a square grid.

    Args:
        embed_dim (int): Embedding dimension for each grid position.
        grid_size (int): Number of grid cells along each side (the grid is
            square).
        cls_token (bool, optional): If True, a row of zeros is prepended as
            the embedding of a CLS token. Defaults to False.

    Returns:
        ndarray: Array of shape [grid_size*grid_size, embed_dim], or
        [1+grid_size*grid_size, embed_dim] when ``cls_token`` is True.
    """
    axis_h = np.arange(grid_size, dtype=np.float32)
    axis_w = np.arange(grid_size, dtype=np.float32)

    # meshgrid is called with the width axis first, so mesh[0] holds the
    # w coordinates and mesh[1] the h coordinates.
    mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return pos_embed

    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Generate 2D sine-cosine positional embeddings from a grid.

    Half of the embedding dimension encodes ``grid[0]`` and the other half
    encodes ``grid[1]``; the two halves are concatenated per position.

    Args:
        embed_dim (int): The embedding dimension (must be even).
        grid (ndarray): Coordinate grids stacked along the first axis; only
            ``grid[0]`` and ``grid[1]`` are read. Each may have any shape
            (e.g. (H, W) or (1, H, W)) and is flattened in row-major order.

    Returns:
        ndarray: The positional embeddings of shape (H*W, embed_dim).

    Raises:
        AssertionError: If ``embed_dim`` is odd.
    """
    assert embed_dim % 2 == 0

    grid = np.asarray(grid)
    # The 1D helper contracts with einsum("m,d->md"), which needs a 1-D
    # position vector, so flatten each coordinate plane before handing it on.
    first_axis = grid[0].reshape(-1)  # (H*W,)
    second_axis = grid[1].reshape(-1)  # (H*W,)

    # use half of dimensions to encode each coordinate
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, first_axis)  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, second_axis)  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build 3D (time, height, width) sine-cosine positional embeddings.

    The embedding dimension is split in sixteenths: 6/16 for the width
    axis, 6/16 for the height axis and 4/16 for the time axis, concatenated
    in that order. Rows are ordered with time varying slowest, then height,
    then width.

    Args:
        embed_dim (int): The embedding dimension (must be divisible by 16).
        grid_size (tuple): A 3-tuple (t, h, w) with the grid dimensions.
        cls_token (bool, optional): If True, a row of zeros is prepended as
            the embedding of a CLS token. Defaults to False.

    Returns:
        ndarray: Array of shape (L, embed_dim) with L = t * h * w, or
        L = 1 + t * h * w when ``cls_token`` is True.
    """
    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    sixteenth = embed_dim // 16
    w_dim = sixteenth * 6
    h_dim = sixteenth * 6
    t_dim = sixteenth * 4

    w_table = get_1d_sincos_pos_embed_from_grid(w_dim, np.arange(w_size))
    h_table = get_1d_sincos_pos_embed_from_grid(h_dim, np.arange(h_size))
    t_table = get_1d_sincos_pos_embed_from_grid(t_dim, np.arange(t_size))

    # Expand every per-axis table to one row per (t, h, w) cell.
    w_rows = np.tile(w_table, (t_size * h_size, 1))
    h_rows = np.tile(np.repeat(h_table, w_size, axis=0), (t_size, 1))
    t_rows = np.repeat(t_table, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_rows, h_rows, t_rows), axis=1)
    if not cls_token:
        return pos_embed

    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
# HMI MASKING
# From various FDL prior works (sdolatent, solar-vae, etc.)
def hmi_mask(hmi_data):
    """Build a 0/1 mask marking the non-zero pixels of HMI data.

    A pixel is marked 1 when the absolute value of the input is strictly
    greater than zero, and 0 otherwise. The intended use is to separate
    on-disk pixels (non-zero field values) from off-disk pixels.

    Args:
        hmi_data (torch.Tensor): The HMI magnetogram data tensor.

    Returns:
        torch.Tensor: A ``torch.uint8`` tensor with the same shape as the
        input.
    """
    is_nonzero = hmi_data.abs().gt(0.0)
    return is_nonzero.to(dtype=torch.uint8)
def apply_hmi_mask(data, hmi_mask, value):
    """Apply an HMI mask to solar image data.

    Replaces pixels outside the solar disk (where ``hmi_mask`` is 0) with a
    scalar value. Only the first three channels, treated as the HMI
    channels, are masked; the remaining channels (AIA) pass through
    unchanged.

    Args:
        data (torch.Tensor): Input of shape (B, C, H, W) or (C, H, W).
        hmi_mask (torch.Tensor): Binary mask, broadcastable against the HMI
            channels, where 1 marks pixels inside the solar disk and 0 marks
            pixels outside. It is moved to the device of ``data``.
        value (float): The scalar written into the masked (off-disk) pixels.
            Non-finite values such as NaN or inf are supported.

    Returns:
        torch.Tensor: The masked data with the same shape as the input.

    Raises:
        ValueError: If ``data`` is neither 3- nor 4-dimensional.
    """
    if data.ndim == 3:
        # Add a batch axis, reuse the 4D path, then drop the axis again.
        return apply_hmi_mask(data.unsqueeze(0), hmi_mask, value).squeeze(0)
    if data.ndim != 4:
        raise ValueError("Expecting 3d or 4d data")

    hmi = data[:, :3]
    aia = data[:, 3:]

    inside_disk = hmi_mask.to(device=data.device, dtype=torch.bool)

    # Select rather than multiply-and-add: `hmi * mask + value * ~mask`
    # yields NaN over the whole disk when `value` is NaN/inf (nan * 0 is
    # nan) and keeps NaN/inf already present in the off-disk pixels.
    out_dtype = torch.result_type(hmi, value)
    fill = torch.as_tensor(value, dtype=out_dtype, device=data.device)
    hmi = torch.where(inside_disk, hmi.to(out_dtype), fill)

    return torch.cat([hmi, aia], dim=1)
class AttributeDict(dict):
    """A dictionary subclass that allows attribute-style access to its keys.

    ``obj.key`` reads ``obj["key"]``, ``obj.key = v`` stores ``obj["key"] = v``
    and ``del obj.key`` removes the item. All standard dictionary methods are
    kept. ``__slots__`` is empty, so instances carry no per-instance
    ``__dict__`` and every attribute assignment lands in the mapping itself.

    Looking up a name that is not a key raises ``AttributeError`` (not
    ``KeyError``), as the attribute protocol requires. This is what makes
    ``hasattr``, ``getattr(obj, name, default)`` and ``copy.deepcopy`` work
    on instances.

    Args:
        *args: Positional arguments passed to the dict constructor.
        **kwargs: Keyword arguments passed to the dict constructor.
    """

    __slots__ = ()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so dict methods
        # keep priority over keys of the same name.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None
def stonyhurst_to_patch_index(lat, lon, patch_size, img_w=512, img_h=512):
    """Convert Heliographic Stonyhurst coordinates to patch indices.

    The coordinate is projected to pixel space using the WCS of SunPy's
    sample AIA 171 image (a 1024x1024 reference), rescaled to the requested
    image size and floor-divided by the patch size.

    Args:
        lat (float): Latitude in degrees.
        lon (float): Longitude in degrees.
        patch_size (int): The size of each patch in pixels.
        img_w (int, optional): Image width in pixels. Defaults to 512.
        img_h (int, optional): Image height in pixels. Defaults to 512.

    Returns:
        torch.Tensor: A tensor of shape (2,) containing the patch indices
        [x, y].

    Warns:
        UserWarning: If an image dimension exceeds 1024, i.e. the target is
            finer than the reference image the projection is computed on.
    """
    # NOTE(review): the sample image is reloaded on every call, and
    # sunpy.data.sample may need to fetch it first -- consider caching the
    # WCS if this is called in a loop.
    aiamap = sunpy.map.Map(sunpy.data.sample.AIA_171_IMAGE)
    # example image is loaded at 1024x1024

    # HeliographicStonyhurst orders its components (lon, lat), so positional
    # arguments `lat, lon` would be read swapped; pass them by keyword.
    coord = SkyCoord(lon=lon * u.deg, lat=lat * u.deg, frame=HeliographicStonyhurst)
    x, y = aiamap.wcs.world_to_pixel(coord)  # (x, y) in pixels

    # Rescale from the 1024-pixel reference to the requested image size,
    # then bucket into patches.
    scale_x = 1024 / img_w
    scale_y = 1024 / img_h
    x, y = x / scale_x // patch_size, y / scale_y // patch_size

    if img_w > 1024 or img_h > 1024:
        warnings.warn(
            "Loss of precision when over 1024 on coordinate converstion, consider upgrading reference image."
        )

    return torch.Tensor([x, y])
def patchify(imgs, patch_size, tubelet_size):
    """Split a batch of image sequences into a sequence of flat patches.

    Each patch covers ``tubelet_size`` frames and a ``patch_size`` x
    ``patch_size`` spatial window. Patches are ordered by (time, row,
    column) and each one is flattened with the channel axis varying fastest.

    Args:
        imgs (torch.Tensor): Input images of shape (B, C, T, H, W).
        patch_size (int): The spatial size of each square patch.
        tubelet_size (int): The temporal size of each tubelet.

    Returns:
        torch.Tensor: Tensor of shape (B, L, D), where L is the number of
        patches and D is the flattened patch dimension.
    """
    pattern = "b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)"
    return rearrange(imgs, pattern, tub=tubelet_size, p=patch_size, q=patch_size)
def spatial_to_patch_mask(
    mask_2d: torch.Tensor, patch_size: int, num_frames: int = 1
) -> torch.Tensor:
    """Convert a 2D spatial mask to a patch-level mask.

    The mask is cut into non-overlapping ``patch_size`` x ``patch_size``
    tiles. A tile is flagged True when every pixel in it is zero. The
    per-tile flags are flattened row-major (row, then column) and repeated
    ``num_frames`` times, which matches the ``(t h w)`` token order produced
    by ``patchify``.

    Args:
        mask_2d: 2D binary mask of shape (H, W).
        patch_size: Spatial size of each patch. Must divide both H and W.
        num_frames: Number of temporal repetitions of the spatial pattern.
            Default: 1.

    Returns:
        torch.Tensor: 1D boolean tensor of shape (L,) with
        L = num_frames * (H // patch_size) * (W // patch_size), where
        True = off-limb patch (all pixels zero).

    Raises:
        ValueError: If H or W is not a multiple of ``patch_size``.
    """
    H, W = mask_2d.shape
    p = patch_size
    if H % p != 0 or W % p != 0:
        raise ValueError(
            "Mask shape ({}, {}) is not divisible by patch_size={}".format(H, W, p)
        )
    h = H // p
    w = W // p

    # (H, W) -> (h, p, w, p): axes 1 and 3 index the pixels inside one tile.
    tiles = mask_2d.reshape(h, p, w, p)
    # Count non-zero pixels per tile; comparing to zero first avoids
    # positive and negative values cancelling in a plain sum.
    patch_is_zero = (tiles != 0).sum(dim=(1, 3)) == 0  # (h, w)

    # The same spatial pattern applies to every frame: (h*w,) -> (t*h*w,).
    return patch_is_zero.reshape(-1).repeat(num_frames)
def unpatchify(x, img_size, patch_size, tubelet_size):
    """Reassemble image sequences from a sequence of flat patches.

    This inverts the patch layout "b (t h w) (tub p q c)": patches are
    ordered by (time, row, column) and each is flattened with channels
    varying fastest. The number of time steps and channels is inferred from
    the shape of ``x``.

    Args:
        x (torch.Tensor): Patched tensor of shape (B, L, D).
        img_size (int): Spatial size of the original images (assumed square).
        patch_size (int): The spatial size of each patch.
        tubelet_size (int): The temporal size of each tubelet.

    Returns:
        torch.Tensor: Reconstructed images of shape (B, C, T, H, W).
    """
    patches_per_side = img_size // patch_size
    axis_sizes = dict(
        h=patches_per_side,
        w=patches_per_side,
        tub=tubelet_size,
        p=patch_size,
        q=patch_size,
    )
    return rearrange(
        x, "b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)", **axis_sizes
    )
def norm_target(target) -> torch.Tensor:
    """Standardize a tensor along its last dimension (z-score).

    For every slice along the last dimension, the slice mean is subtracted
    and the result is divided by the square root of the slice variance plus
    1e-6. The variance uses ``Tensor.var`` with its default settings.

    Args:
        target (torch.Tensor): The input tensor to normalize.

    Returns:
        torch.Tensor: The normalized tensor with the same shape as input.
    """
    eps = 1.0e-6  # keeps the division finite for constant slices
    centered = target - target.mean(dim=-1, keepdim=True)
    variance = target.var(dim=-1, keepdim=True)
    return centered / (variance + eps) ** 0.5