Downstream App: Solar Wind

Data Module

class sdofmv2.tasks.solar_wind.datamodule.SWDataset(*args: Any, **kwargs: Any)

Bases: SDOMLDataset

Solar Wind dataset for SDOML (Solar Dynamics Observatory Machine Learning).

This class extends the base SDOMLDataset to handle solar wind-specific features, including radial and latitudinal/longitudinal parameters. It supports temporal filtering by year/month, class-based undersampling for training sets, and automated column index mapping for coordinate features.

Parameters:
  • aligndata (pd.DataFrame) – DataFrame containing aligned timestamps, input data paths, and target labels.

  • hmi_data (zarr.Group) – SDO/HMI (magnetic field) data source.

  • aia_data (zarr.Group) – SDO/AIA (extreme ultraviolet) data source.

  • eve_data (zarr.Group) – SDO/EVE (irradiance) data source.

  • components (list[str]) – List of HMI magnetic field components to include.

  • wavelengths (list[int]) – List of AIA wavelengths (e.g., 171, 193) to include.

  • ions (list[str]) – List of EVE ions or spectral lines to include.

  • freq (str) – Data cadence string (e.g., ‘12min’ or ‘1h’).

  • months (list[int]) – List of months (1-12) to include in this data split.

  • years (list[int]) – List of years (e.g., 2010-2018) to include in this data split.

  • mask (torch.Tensor, optional) – A precomputed spatial mask to be applied to the AIA and HMI image data. Defaults to None.

  • num_frames (int) – Number of consecutive temporal frames per sample. Defaults to 1.

  • drop_frame_dim (bool) – If True, removes the temporal dimension (F) for single-frame samples. Defaults to False.

  • min_date (str, optional) – Minimum timestamp boundary for filtering data.

  • max_date (str, optional) – Maximum timestamp boundary for filtering data.

  • get_header (bool) – Whether to retrieve and return FITS headers with the data. Defaults to False.

  • normalization (dict) – Mapping of instrument keys to their respective normalization methods.

  • normalization_stat (dict) – Precomputed statistics (e.g., mean and standard deviation) used for data scaling.

  • label_type (str) – The column name in aligndata used as the prediction target.

  • radial_parameters (list[str], optional) – Column names representing radial distance features.

  • latlon_parameters (list[str], optional) – Column names representing spatial coordinates (latitude/longitude).

  • sampling_ratio (list[float], optional) – Fractions used to undersample specific classes (only applied when datasplit is “train”).

  • random_state (int, optional) – Seed for the random number generator to ensure reproducible undersampling.

  • datasplit (str) – The intended use of this instance; one of “train”, “val”, or “test”. Defaults to “train”.

radial_parameters

List of column names representing radial distance features.

Type:

list[str]

latlon_parameters

List of column names representing latitude and longitude features.

Type:

list[str]

aligndata

The filtered and potentially undersampled DataFrame containing alignment and label information.

Type:

pd.DataFrame

id_label

The integer column index of the target label within aligndata.

Type:

int

position_list

List of integer column indices corresponding to the latitudinal and longitudinal parameters.

Type:

list[int]

r_dist_list

List of integer column indices corresponding to the normalized radial distance parameters.

Type:

list[int]
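The year/month filtering, training-only undersampling, and column-index mapping described for SWDataset can be sketched with pandas. This is a minimal illustration under stated assumptions, not the class's implementation: aligndata is assumed to have a DatetimeIndex, label_type is assumed to name an integer class column, and sampling_ratio[k] is read as the fraction of class-k rows to keep. The helper names filter_and_undersample and column_indices are hypothetical.

```python
import pandas as pd


def filter_and_undersample(aligndata, years, months, label_type,
                           sampling_ratio=None, random_state=None,
                           datasplit="train"):
    """Hypothetical helper mirroring the SWDataset filtering steps."""
    idx = aligndata.index  # assumed to be a DatetimeIndex
    df = aligndata[idx.year.isin(years) & idx.month.isin(months)]

    # Undersample only the training split; sampling_ratio[k] is assumed to be
    # the fraction of rows of class k to keep.
    if datasplit == "train" and sampling_ratio is not None:
        parts = [
            group.sample(frac=sampling_ratio[int(cls)], random_state=random_state)
            for cls, group in df.groupby(label_type)
        ]
        df = pd.concat(parts).sort_index()
    return df


def column_indices(df, names):
    """Map column names to integer positions (the role of id_label / position_list)."""
    return [df.columns.get_loc(name) for name in names]
```

Because undersampling runs only when datasplit is "train", the validation and test splits keep their natural class balance, so their metrics reflect the real label distribution.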

class sdofmv2.tasks.solar_wind.datamodule.SWDataModule(*args: Any, **kwargs: Any)

Bases: SDOMLDataModule

PyTorch Lightning DataModule for Solar Wind (SW) prediction.

This module handles the end-to-end data pipeline for SDOML datasets, including loading alignment indices from Zarr, filtering by temporal boundaries (years/months), applying spatial longitude cutoffs for solar footpoints, and managing normalization for radial distance parameters.

Parameters:
  • hmi_path (str) – Path to HMI Zarr data.

  • aia_path (str) – Path to AIA Zarr data.

  • eve_path (str) – Path to EVE Zarr data.

  • components (list[str]) – HMI magnetic components to use.

  • wavelengths (list[int]) – AIA wavelengths to use.

  • ions (list[str]) – EVE spectral lines to use.

  • frequency (str) – The sampling frequency of the instruments.

  • batch_size – Number of samples per batch. Defaults to 32.

  • num_workers – Number of subprocesses for data loading.

  • val_months – Months assigned to validation. Defaults to [10, 1].

  • test_months – Months assigned to testing. Defaults to [11, 12].

  • holdout_months – Months to be completely excluded from all splits.

  • radial_norm – Whether to apply Z-score normalization to radial features.

  • cache_dir – Directory for storing temporary or cached data.

  • apply_mask – If True, applies the limb mask to spatial data.

  • num_frames – Number of temporal frames per sample.

  • drop_frame_dim – Whether to squeeze the temporal dimension when its size is 1.

  • min_date – Global start date boundary.

  • max_date – Global end date boundary.

  • precision – Numerical precision for tensors (“16”, “32”, or “64”).

  • normalization – Dictionary of normalization settings.

  • cfg – Configuration object for data cutoffs and model settings.

  • train_months – Months assigned to training.

  • train_years – Year(s) assigned to training.

  • val_years – Year(s) assigned to validation.

  • test_years – Year(s) assigned to testing.

  • alignment_indices_path (str) – Path to the Zarr file containing alignment indices and metadata.

  • radial_parameters (list[str]) – Column names for radial features.

  • latlon_parameters (list[str]) – Column names for coordinate features.

  • cadence – Temporal resolution string (e.g., “1min”).

  • label_type – The target column name for the model to predict.

  • sampling_ratio – Fraction of instances to sample per class.

  • random_state – Seed for reproducible sampling and shuffling.

aligndata

The central alignment table, indexed by SDO observation time, containing data indices and target labels.

Type:

pd.DataFrame

radial_mean

Mean value of the radial parameters used for normalization.

Type:

float

radial_std

Standard deviation of the radial parameters.

Type:

float

train_years

Years allocated for the training set.

Type:

int | list[int]

val_years

Years allocated for the validation set.

Type:

int | list[int]

test_years

Years allocated for the test set.

Type:

int | list[int]

cfg

Configuration object containing hyperparameters and data cutoffs (e.g., cfg.data.in_situ.lon_cutoff).

Type:

DictConfig | Any
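The longitude cutoff and radial Z-score normalization handled by SWDataModule might look like the following. The column names, the reading of lon_cutoff as a symmetric bound on footpoint longitude (|lon| <= lon_cutoff), and fitting a single mean/std on the training rows are all assumptions for illustration; apply_lon_cutoff and zscore_radial are hypothetical helpers.

```python
import numpy as np
import pandas as pd


def apply_lon_cutoff(df, lon_col, lon_cutoff):
    """Keep rows whose footpoint longitude lies within +/- lon_cutoff degrees (assumed semantics)."""
    return df[df[lon_col].abs() <= lon_cutoff]


def zscore_radial(train_df, dfs, radial_cols):
    """Fit one scalar mean/std on the training rows and apply it to every split."""
    values = train_df[radial_cols].to_numpy(dtype=np.float64)
    radial_mean = float(values.mean())
    radial_std = float(values.std())  # population std (ddof=0)
    scaled = []
    for df in dfs:
        out = df.copy()  # avoid mutating the caller's frame or a slice of it
        out[radial_cols] = (out[radial_cols] - radial_mean) / radial_std
        scaled.append(out)
    return radial_mean, radial_std, scaled
```

Fitting the statistics on training rows only keeps validation and test data out of the normalization; the returned scalars play the role of the radial_mean and radial_std attributes.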

Methods

predict_dataloader()

setup([stage])

test_dataloader()

train_dataloader()

val_dataloader()

Head Module

class sdofmv2.tasks.solar_wind.head_networks.TransformerHead(d_output, input_token_dim=512, p_drop=0.1, max_position_element=4, num_pos_token=10, nhead=8)

Bases: Module

Transformer-based classification head with coordinate-to-token projection.

This module converts physical position and radial distance into a set of learned positional tokens. These tokens are prepended to the input sequence (alongside the CLS token) and processed through a Transformer Encoder block.

Parameters:
  • d_output (int) – Dimension of the final output (e.g., number of classes).

  • input_token_dim (int) – Embedding dimension (D). Defaults to 512.

  • p_drop (float) – Dropout probability. Defaults to 0.1.

  • max_position_element (int) – Highest power used in the sine/cosine transform. Defaults to 4.

  • num_pos_token (int) – Number of positional tokens to inject. Defaults to 10.

  • nhead (int) – Number of attention heads. Defaults to 8.

num_pos_token

The number of latent tokens generated from the positional encoding.

Type:

int

input_token_dim

The dimensionality of the latent tokens (D).

Type:

int

pos_encoder

A function applying sine/cosine transformations to raw coordinates.

Type:

callable

projection

Linear layer mapping encoded positions to the token space.

Type:

nn.Linear

transformer_block

A single transformer layer for cross-token communication.

Type:

nn.TransformerEncoderLayer

classifier

Final mapping to output dimension.

Type:

nn.Linear

Methods

forward(x, position, r_distance)

Perform a forward pass through the transformer head.

forward(x, position, r_distance)

Perform a forward pass through the transformer head.

Parameters:
  • x (torch.Tensor) – Input features of shape (batch, num_tokens, embed_dim).

  • position (torch.Tensor) – Position coordinates of shape (batch, position_size).

  • r_distance (torch.Tensor) – Radial distance values of shape (batch,).

Returns:

Logits of shape (batch, d_output) for classification/regression.

Return type:

torch.Tensor
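A compact sketch of the coordinate-to-token idea follows. It assumes the harmonic encoding uses frequencies 2**k for k < max_position_element, that the radial distance is appended to the encoded position before projection, and that the positional tokens are inserted directly after a CLS token at index 0 so the classifier can read index 0. TransformerHeadSketch, harmonic_encoding, and the extra position_size argument are hypothetical; the real head's token ordering may differ.

```python
import torch
from torch import nn


def harmonic_encoding(coords, max_position_element):
    """sin/cos of coords * 2**k for k < max_position_element (assumed frequency ladder)."""
    freqs = 2.0 ** torch.arange(max_position_element, dtype=coords.dtype, device=coords.device)
    scaled = coords.unsqueeze(-1) * freqs                              # (B, P, K)
    return torch.cat([scaled.sin(), scaled.cos()], dim=-1).flatten(1)  # (B, 2*P*K)


class TransformerHeadSketch(nn.Module):
    def __init__(self, d_output, input_token_dim=512, p_drop=0.1,
                 max_position_element=4, num_pos_token=10, nhead=8, position_size=4):
        super().__init__()
        self.num_pos_token = num_pos_token
        self.input_token_dim = input_token_dim
        self.max_position_element = max_position_element
        enc_dim = 2 * position_size * max_position_element + 1  # +1 for radial distance
        self.projection = nn.Linear(enc_dim, num_pos_token * input_token_dim)
        self.transformer_block = nn.TransformerEncoderLayer(
            d_model=input_token_dim, nhead=nhead, dropout=p_drop, batch_first=True)
        self.classifier = nn.Linear(input_token_dim, d_output)

    def forward(self, x, position, r_distance):
        enc = torch.cat([harmonic_encoding(position, self.max_position_element),
                         r_distance.unsqueeze(-1)], dim=-1)
        pos_tokens = self.projection(enc).view(-1, self.num_pos_token, self.input_token_dim)
        # Keep the CLS token at index 0 and insert the positional tokens right after it.
        tokens = torch.cat([x[:, :1], pos_tokens, x[:, 1:]], dim=1)
        tokens = self.transformer_block(tokens)
        return self.classifier(tokens[:, 0])  # logits from the CLS position
```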

class sdofmv2.tasks.solar_wind.head_networks.SimpleLinear(d_output, input_feature_dim, max_position_element=4, position_size=4, hidden_dim=16, p_drop=0.1)

Bases: Module

Multi-Layer Perceptron head for flattened feature processing.

This head flattens the input feature map and concatenates it with harmonic positional encodings (sine/cosine) and radial distance before passing the combined vector through a non-linear MLP.

Parameters:
  • d_output (int) – Dimension of the final output.

  • input_feature_dim (int) – Dimension of the input features before flattening.

  • max_position_element (int) – Highest power used in the sine/cosine transform. Defaults to 4.

  • position_size (int) – Number of raw coordinate variables. Defaults to 4.

  • hidden_dim (int) – Width of the first hidden layer. Defaults to 16.

  • p_drop (float) – Dropout probability. Defaults to 0.1.

d_output

Output dimensionality.

Type:

int

input_feature_dim

Dimension of the input features.

Type:

int

hidden_dim

Hidden layer width.

Type:

int

p_drop

Dropout rate used in the network.

Type:

float

max_position_element

Highest power used in the harmonic (sine/cosine) encoding.

Type:

int

position_size

Number of coordinate inputs.

Type:

int

network

The MLP architecture (Linear -> ReLU -> Dropout).

Type:

nn.Sequential

Methods

forward(x, position, r_distance)

Perform a forward pass through the MLP head.

forward(x, position, r_distance)

Perform a forward pass through the MLP head.

Parameters:
  • x (torch.Tensor) – Input features of shape (batch, …).

  • position (torch.Tensor) – Position coordinates of shape (batch, position_size).

  • r_distance (torch.Tensor) – Radial distance values of shape (batch,).

Returns:

Output logits of shape (batch, d_output).

Return type:

torch.Tensor
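The flatten-and-concatenate pattern of SimpleLinear can be sketched as below. It assumes input_feature_dim equals the number of elements per sample after flattening, a sine/cosine encoding at frequencies 2**k, and a single hidden layer (Linear -> ReLU -> Dropout) before the output layer; SimpleLinearSketch and harmonic_encoding are hypothetical names.

```python
import torch
from torch import nn


def harmonic_encoding(coords, max_position_element):
    """sin/cos of coords * 2**k for k < max_position_element (assumed frequencies)."""
    freqs = 2.0 ** torch.arange(max_position_element, dtype=coords.dtype, device=coords.device)
    scaled = coords.unsqueeze(-1) * freqs
    return torch.cat([scaled.sin(), scaled.cos()], dim=-1).flatten(1)


class SimpleLinearSketch(nn.Module):
    def __init__(self, d_output, input_feature_dim, max_position_element=4,
                 position_size=4, hidden_dim=16, p_drop=0.1):
        super().__init__()
        self.max_position_element = max_position_element
        in_dim = input_feature_dim + 2 * position_size * max_position_element + 1
        self.network = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(hidden_dim, d_output),
        )

    def forward(self, x, position, r_distance):
        feats = x.flatten(start_dim=1)  # (B, input_feature_dim)
        enc = harmonic_encoding(position, self.max_position_element)
        return self.network(torch.cat([feats, enc, r_distance.unsqueeze(-1)], dim=-1))
```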

class sdofmv2.tasks.solar_wind.head_networks.SkipLinearHead(d_output, input_feature_dim, max_position_element=4, position_size=4, hidden_dim=16, skips=[4], include_raw_coordinates=False, num_hidden_layers=8, number_of_frames=1)

Bases: Module

Deep MLP head with skip connections for coordinate-aware regression.

This module implements a deep architecture in which the initial concatenated input (features + encodings) is re-injected at specified layers by concatenating it onto the hidden state. These skip connections help preserve high-frequency coordinate information throughout the depth of the network.

Parameters:
  • d_output (int) – Dimension of the final output.

  • input_feature_dim (int) – Dimension of a single input frame.

  • max_position_element (int) – Highest power for harmonic encoding. Defaults to 4.

  • position_size (int) – Number of raw coordinate variables. Defaults to 4.

  • hidden_dim (int) – Latent width of the hidden layers. Defaults to 16.

  • skips (list[int]) – Layer indices where the initial input is concatenated back into the hidden state. Defaults to [4].

  • include_raw_coordinates (bool) – If True, appends non-encoded coordinates to the input vector. Defaults to False.

  • num_hidden_layers (int) – Total number of hidden linear layers in the MLP. Defaults to 8.

  • number_of_frames (int) – Number of temporal frames to flatten. Defaults to 1.

d_output

Output dimensionality.

Type:

int

hidden_dim

Latent width of the network.

Type:

int

max_position_element

Highest power used in the harmonic (sine/cosine) encoding.

Type:

int

position_size

Number of coordinate inputs.

Type:

int

skips

Indices of layers performing skip-connections.

Type:

list[int]

num_hidden_layers

Total depth of the MLP.

Type:

int

pts_linears

Collection of linear layers including skip-connection logic.

Type:

nn.ModuleList

output_linear

Final layer mapping to output dimension.

Type:

nn.Linear

Methods

forward(x, position, r_distance)

Perform a forward pass through the skip-connected MLP head.

forward(x, position, r_distance)

Perform a forward pass through the skip-connected MLP head.

Parameters:
  • x (torch.Tensor) – Input features of shape (batch, …).

  • position (torch.Tensor) – Position coordinates of shape (batch, position_size).

  • r_distance (torch.Tensor) – Radial distance values of shape (batch,).

Returns:

Output logits of shape (batch, d_output).

Return type:

torch.Tensor
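The skip mechanism can be illustrated with a NeRF-style MLP in which the initial input vector is concatenated onto the hidden state before each layer listed in skips. The convention that skip index i means "concatenate before layer i", the ReLU activations, and the 2**k encoding frequencies are assumptions; SkipLinearSketch is a hypothetical name.

```python
import torch
from torch import nn


def harmonic_encoding(coords, max_position_element):
    """sin/cos of coords * 2**k for k < max_position_element (assumed frequencies)."""
    freqs = 2.0 ** torch.arange(max_position_element, dtype=coords.dtype, device=coords.device)
    scaled = coords.unsqueeze(-1) * freqs
    return torch.cat([scaled.sin(), scaled.cos()], dim=-1).flatten(1)


class SkipLinearSketch(nn.Module):
    def __init__(self, d_output, input_feature_dim, max_position_element=4,
                 position_size=4, hidden_dim=16, skips=(4,),
                 include_raw_coordinates=False, num_hidden_layers=8, number_of_frames=1):
        super().__init__()
        self.max_position_element = max_position_element
        self.include_raw_coordinates = include_raw_coordinates
        self.skips = set(skips)
        in_dim = number_of_frames * input_feature_dim + 2 * position_size * max_position_element + 1
        if include_raw_coordinates:
            in_dim += position_size
        layers = []
        for i in range(num_hidden_layers):
            if i == 0:
                d_in = in_dim
            elif i in self.skips:
                d_in = hidden_dim + in_dim  # hidden state + re-injected input
            else:
                d_in = hidden_dim
            layers.append(nn.Linear(d_in, hidden_dim))
        self.pts_linears = nn.ModuleList(layers)
        self.output_linear = nn.Linear(hidden_dim, d_output)

    def forward(self, x, position, r_distance):
        parts = [x.flatten(start_dim=1),  # flattens all frames together
                 harmonic_encoding(position, self.max_position_element),
                 r_distance.unsqueeze(-1)]
        if self.include_raw_coordinates:
            parts.append(position)
        inp = torch.cat(parts, dim=-1)
        h = inp
        for i, layer in enumerate(self.pts_linears):
            if i > 0 and i in self.skips:
                h = torch.cat([h, inp], dim=-1)  # skip connection by concatenation
            h = torch.relu(layer(h))
        return self.output_linear(h)
```

Re-injecting the encoded coordinates deep in the network gives later layers direct access to them, instead of relying on whatever survived several nonlinear layers.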

class sdofmv2.tasks.solar_wind.head_networks.ClsLinear(d_output, embedding_dim, max_position_element=4, position_size=4, hidden_dim=16, p_drop=0.1)

Bases: Module

MLP head designed for Transformer [CLS] token representations.

This module specifically extracts the class (CLS) token from the first index of the input sequence and combines it with physical metadata (positional encodings and radial distance) for final prediction.

Parameters:
  • d_output (int) – Dimension of the final output.

  • embedding_dim (int) – Dimensionality of the tokens in the input sequence.

  • max_position_element (int) – Highest power for harmonic encoding. Defaults to 4.

  • position_size (int) – Number of raw coordinate variables. Defaults to 4.

  • hidden_dim (int) – Hidden width of the MLP. Defaults to 16.

  • p_drop (float) – Dropout probability. Defaults to 0.1.

d_output

Output dimensionality.

Type:

int

embedding_dim

Dimension of the input tokens (D).

Type:

int

hidden_dim

Hidden layer width.

Type:

int

p_drop

Dropout rate.

Type:

float

max_position_element

Highest power used in the harmonic (sine/cosine) encoding.

Type:

int

position_size

Number of coordinate inputs.

Type:

int

network

MLP layers processing the combined CLS and metadata vector.

Type:

nn.Sequential

Methods

forward(x, position, r_distance)

Perform a forward pass using the CLS token.

forward(x, position, r_distance)

Perform a forward pass using the CLS token.

Parameters:
  • x (torch.Tensor) – Input sequence with CLS token at index 0, shape (batch, num_tokens, embed_dim).

  • position (torch.Tensor) – Position coordinates of shape (batch, position_size).

  • r_distance (torch.Tensor) – Radial distance values of shape (batch,).

Returns:

Output logits of shape (batch, d_output).

Return type:

torch.Tensor
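The CLS-plus-metadata pattern reduces to indexing the first token and concatenating the encoded coordinates and radial distance. In this sketch the encoding frequencies (2**k), the single hidden layer, and the name ClsLinearSketch are assumptions.

```python
import torch
from torch import nn


def harmonic_encoding(coords, max_position_element):
    """sin/cos of coords * 2**k for k < max_position_element (assumed frequencies)."""
    freqs = 2.0 ** torch.arange(max_position_element, dtype=coords.dtype, device=coords.device)
    scaled = coords.unsqueeze(-1) * freqs
    return torch.cat([scaled.sin(), scaled.cos()], dim=-1).flatten(1)


class ClsLinearSketch(nn.Module):
    def __init__(self, d_output, embedding_dim, max_position_element=4,
                 position_size=4, hidden_dim=16, p_drop=0.1):
        super().__init__()
        self.max_position_element = max_position_element
        in_dim = embedding_dim + 2 * position_size * max_position_element + 1
        self.network = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(hidden_dim, d_output),
        )

    def forward(self, x, position, r_distance):
        cls = x[:, 0]  # (B, D): CLS token at index 0
        enc = harmonic_encoding(position, self.max_position_element)
        return self.network(torch.cat([cls, enc, r_distance.unsqueeze(-1)], dim=-1))
```

Because only index 0 is read, the other tokens influence the prediction solely through the encoder that produced the CLS token.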

Model Module

class sdofmv2.tasks.solar_wind.model.SWClassifier(*args: Any, **kwargs: Any)

Bases: BaseModule

Solar Wind Classifier using a backbone encoder and configurable head.

This module wraps a pretrained backbone encoder and adds a classification head for solar wind prediction tasks. It supports multiple head types (linear, transformer, skip_linear), handles coordinate embeddings, and tracks various classification metrics during training, validation, and testing.

Parameters:
  • channels (list[str]) – List of data channels to use.

  • num_classes (int) – Number of output classes.

  • class_names (list[str]) – Names of the classes for logging.

  • max_position_element (int) – Maximum power for positional encoding.

  • backbone – The pretrained backbone model.

  • freeze_encoder (bool) – Whether to freeze the backbone encoder.

  • plt_style – Plotting style configuration.

  • head_type (str) – Type of classification head (“linear”, “transformer”, “skip_linear”).

  • hidden_dim (int) – Hidden dimension for the head network.

  • position_size (int) – Number of position coordinates.

  • p_drop (float) – Dropout probability.

  • nhead (int) – Number of attention heads for transformer head.

  • embed_dim (int) – Embedding dimension.

  • skips (list[int]) – Layer indices for skip connections.

  • include_raw_coordinates (bool) – Whether to include raw coordinates.

  • num_hidden_layers (int) – Number of hidden layers for skip_linear head.

  • radial_mean (float) – Mean for radial normalization.

  • radial_std (float) – Standard deviation for radial normalization.

  • loss_dict (dict) – Loss function configuration.

  • optimizer_dict (dict) – Optimizer configuration.

  • scheduler_dict (dict) – Scheduler configuration.

Methods

forward(x, position, r_distance)

Perform a forward pass through the classifier.

forward_analysis(x)

Perform an analysis forward pass for reconstruction visualization.

on_before_optimizer_step(optimizer)

Called before each optimizer step to log gradient norms.

on_test_epoch_end()

Called at the end of the test epoch.

on_train_epoch_end()

Called at the end of the training epoch.

on_validation_epoch_end()

Called at the end of the validation epoch.

predict_step(batch, batch_idx)

Perform a prediction step for inference.

test_step(batch, batch_idx)

Perform a single test step.

training_step(batch, batch_idx)

Perform a single training step.

validation_step(batch, batch_idx)

Perform a single validation step.

forward(x, position, r_distance)

Perform a forward pass through the classifier.

Parameters:
  • x (torch.Tensor) – Input images of shape (B, C, H, W).

  • position (torch.Tensor) – Position coordinates of shape (B, position_size).

  • r_distance (torch.Tensor) – Radial distance values of shape (B,).

Returns:

Class logits of shape (B, num_classes).

Return type:

torch.Tensor
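How forward() chains a backbone and a head, and how freeze_encoder can be honored, is sketched below with toy modules. ToyBackbone, ToyClsHead, and ClassifierSketch are hypothetical stand-ins: the sketch assumes the backbone maps (B, C, H, W) images to a token sequence with a CLS token at index 0, and that r_distance is already normalized before it reaches the head.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in: patchify (B, C, H, W) into tokens and prepend a CLS token."""
    def __init__(self, in_chans, embed_dim, patch=8):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch, stride=patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat([cls, tokens], dim=1)                  # (B, N + 1, D)


class ToyClsHead(nn.Module):
    """Hypothetical head: CLS token + position + radial distance -> logits."""
    def __init__(self, embed_dim, position_size, num_classes):
        super().__init__()
        self.fc = nn.Linear(embed_dim + position_size + 1, num_classes)

    def forward(self, tokens, position, r_distance):
        return self.fc(torch.cat([tokens[:, 0], position, r_distance.unsqueeze(-1)], dim=-1))


class ClassifierSketch(nn.Module):
    def __init__(self, backbone, head, freeze_encoder=True):
        super().__init__()
        self.backbone, self.head = backbone, head
        if freeze_encoder:
            self.backbone.requires_grad_(False)  # pretrained weights stay fixed

    def forward(self, x, position, r_distance):
        return self.head(self.backbone(x), position, r_distance)  # (B, num_classes)
```

With the backbone frozen, gradients reach only the head, so an optimizer needs only the head's parameters.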

forward_analysis(x)

Perform an analysis forward pass for reconstruction visualization.

Parameters:

x (torch.Tensor) – Input images of shape (B, C, H, W).

Returns:

Reconstructed images.

Return type:

torch.Tensor

on_before_optimizer_step(optimizer)

Called before each optimizer step to log gradient norms.

Parameters:

optimizer – The optimizer about to perform an update step.

on_test_epoch_end()

Called at the end of the test epoch.

Computes and logs all test metrics including F1, precision, recall, accuracy, AUROC, MCC, and Cohen’s Kappa.

on_train_epoch_end()

Called at the end of the training epoch.

Performs garbage collection and clears the CUDA cache to free memory.

on_validation_epoch_end()

Called at the end of the validation epoch.

Computes and logs all validation metrics, generates WandB plots (confusion matrix, class distributions), and clears stored buffers.

predict_step(batch, batch_idx)

Perform a prediction step for inference.

Parameters:
  • batch – A tuple containing (images, timestamps, position, r_distance, target).

  • batch_idx – The index of the current batch.

Returns:

A dictionary containing predictions, targets, embeddings, etc.

Return type:

dict

test_step(batch, batch_idx)

Perform a single test step.

Parameters:
  • batch – A tuple containing (images, timestamps, position, r_distance, target).

  • batch_idx – The index of the current batch.

Returns:

A dictionary containing predictions, targets, logits, and test loss.

Return type:

dict

training_step(batch, batch_idx)

Perform a single training step.

Parameters:
  • batch – A tuple containing (images, timestamps, position, r_distance, target).

  • batch_idx – The index of the current batch.

Returns:

The training loss value.

Return type:

torch.Tensor

validation_step(batch, batch_idx)

Perform a single validation step.

Parameters:
  • batch – A tuple containing (images, timestamps, position, r_distance, target).

  • batch_idx – The index of the current batch.

Returns:

The validation loss value.

Return type:

torch.Tensor
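All four step methods consume the same five-element batch. A minimal shared step might look like this; shared_step is a hypothetical helper, and F.cross_entropy stands in for whatever loss loss_dict configures (for example, focal_loss_multiclass).

```python
import torch
import torch.nn.functional as F


def shared_step(model, batch, loss_fn=F.cross_entropy):
    """Hypothetical helper; the batch layout follows the documented 5-tuple."""
    images, timestamps, position, r_distance, target = batch  # timestamps unused here
    logits = model(images, position, r_distance)              # (B, num_classes)
    loss = loss_fn(logits, target)
    preds = logits.argmax(dim=-1)                             # predicted class indices
    return loss, preds, logits
```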

Focal Loss

sdofmv2.tasks.solar_wind.focal_loss.focal_loss_multiclass(inputs: Tensor, targets: Tensor, alpha: float | Tensor = 0.25, gamma: float = 2.0, reduction: str = 'none') → Tensor

Multi-class focal loss based on torchvision.ops.sigmoid_focal_loss.

Parameters:
  • inputs (Tensor[N, C]) – Logits for each class.

  • targets (Tensor[N]) – Class indices (0 ≤ targets < C).

  • alpha (float or Tensor[C]) – Balance factor(s). Scalar or per-class.

  • gamma (float) – Modulating factor exponent.

  • reduction (str) – ‘none’, ‘mean’, or ‘sum’.

Returns:

Loss per sample, or reduced loss.

Return type:

Tensor
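For intuition, the sketch below implements the softmax-based multi-class focal loss of Lin et al., FL = -alpha_t (1 - p_t)^gamma log p_t. This is a swapped-in formulation: focal_loss_multiclass is described as based on torchvision's sigmoid focal loss, so its exact form (sigmoid versus softmax, how a per-class alpha is applied) may differ. The name focal_loss_softmax_sketch is hypothetical.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor


def focal_loss_softmax_sketch(inputs: Tensor, targets: Tensor,
                              alpha: float | Tensor = 0.25, gamma: float = 2.0,
                              reduction: str = "none") -> Tensor:
    log_probs = F.log_softmax(inputs, dim=-1)                      # [N, C]
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p of true class, [N]
    pt = log_pt.exp()
    if isinstance(alpha, Tensor):
        # Per-class weights, looked up by each sample's target class.
        alpha_t = alpha.to(device=inputs.device, dtype=inputs.dtype)[targets]
    else:
        alpha_t = torch.full_like(pt, float(alpha))
    loss = -alpha_t * (1.0 - pt) ** gamma * log_pt                 # [N]
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Invalid reduction: {reduction!r}")
```

With gamma = 0 and alpha = 1 the expression reduces to ordinary cross-entropy; larger gamma shrinks the contribution of samples the model already classifies confidently.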

Visualization

sdofmv2.tasks.solar_wind.visualization.colorbar_list(AIA_wavelengths, HMI_components)
sdofmv2.tasks.solar_wind.visualization.find_images_labels(imgs, timestamps, targets, preds, position, class_id, ch_id, max_samples=3)
Parameters:
  • imgs – Input image tensor.

  • targets – Tensor of target class labels.

  • preds – Tensor of predicted class labels.

  • position – Tensor of position coordinates.

  • class_id – Class to filter for.

  • max_samples – Maximum number of samples to return (default: 3).

sdofmv2.tasks.solar_wind.visualization.find_images_labels_embed(imgs, timestamps, targets, preds, position, class_id, max_samples=3)
Parameters:
  • imgs – Input image tensor.

  • targets – Tensor of target class labels.

  • preds – Tensor of predicted class labels.

  • position – Tensor of position coordinates.

  • class_id – Class to filter for.

  • max_samples – Maximum number of samples to return (default: 3).

sdofmv2.tasks.solar_wind.visualization.plot_disk_distribution(img, model, time, wsa_lat=np.linspace(-75, 75, 30), wsa_lon=np.linspace(-90, 90, 30), labels=['Ejecta', 'Coronal Hole', 'Sector Reversal', 'Streamer Belt'], colors=['#191923', '#0E79B2', '#bf1363', '#F39237'])
sdofmv2.tasks.solar_wind.visualization.plot_ecliptic(img, model, time, x_range=np.linspace(-213, 213, 30), y_range=np.linspace(-213, 213, 30), z_range=np.array([0.0]), r_mean=np.float32(8.918359e+07), r_std=np.float32(3.0130634e+07), inner_mask=10, outer_mask=220)
sdofmv2.tasks.solar_wind.visualization.plot_images_grid(correct_data, plt_style, step=None)
sdofmv2.tasks.solar_wind.visualization.plot_sdoml(datamodule, condition=None, dataset_idx=None, times=None, n_samples=None, wavelengths=['131A', '1600A', '1700A', '171A', '193A', '211A', '304A', '335A', '94A'], components=['Bx', 'By', 'Bz'], wsa_footpoint=False, title=None)
sdofmv2.tasks.solar_wind.visualization.plot_tsne(embs, position, r_distance, labels, perplexities=[5, 30, 50, 100, 300], classes=['Ejecta', 'Coronal Hole', 'Sector Reversal', 'Streamer Belt'], colors=['#191923', '#0E79B2', '#bf1363', '#F39237'])
sdofmv2.tasks.solar_wind.visualization.wsa_to_image(lat_sh, lon_sh, r_sun=200, ctr_pxl=256)

Returns the position, on an SDOML image, of the Parker Solar Probe (PSP) footpoint derived from the Wang-Sheeley-Arge (WSA) model. lat_sh and lon_sh are the Stonyhurst latitude and longitude in degrees. r_sun is the radius of the Sun on the image, in pixels. ctr_pxl is the pixel coordinate of the center of each image axis.
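A plausible form of this mapping is an orthographic projection of the Stonyhurst coordinates onto the solar disk. The sketch below assumes an observer at disk center (solar B0 angle of zero), solar north along +y, and pixel offsets measured from ctr_pxl; the function name stonyhurst_to_pixel is hypothetical, and the library's own convention (for example, a flipped y axis) may differ.

```python
import numpy as np


def stonyhurst_to_pixel(lat_sh, lon_sh, r_sun=200, ctr_pxl=256):
    """Orthographic projection of Stonyhurst (lat, lon) in degrees onto image pixels.

    Assumes the observer sits at disk center (B0 = 0) and solar north points along +y.
    """
    lat = np.deg2rad(np.asarray(lat_sh, dtype=float))
    lon = np.deg2rad(np.asarray(lon_sh, dtype=float))
    x = ctr_pxl + r_sun * np.cos(lat) * np.sin(lon)
    y = ctr_pxl + r_sun * np.sin(lat)
    return x, y
```

Under these assumptions a footpoint at (0, 0) lands at the image center, and one on the west limb (longitude 90) lands r_sun pixels to the right of it.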