Downstream App: Solar Wind
Data Module
- class sdofmv2.tasks.solar_wind.datamodule.SWDataset(*args: Any, **kwargs: Any)[source]
Bases: SDOMLDataset
Solar Wind dataset for SDOML (Solar Dynamics Observatory Machine Learning).
This class extends the base SDOMLDataset to handle solar wind-specific features, including radial and latitudinal/longitudinal parameters. It supports temporal filtering by year/month, class-based undersampling for training sets, and automated column index mapping for coordinate features.
- Parameters:
aligndata (pd.DataFrame) – DataFrame containing aligned timestamps, input data paths, and target labels.
hmi_data (zarr.Group) – SDO/HMI (magnetic field) data source.
aia_data (zarr.Group) – SDO/AIA (extreme ultraviolet) data source.
eve_data (zarr.Group) – SDO/EVE (irradiance) data source.
components (list[str]) – List of HMI magnetic field components to include.
wavelengths (list[int]) – List of AIA wavelengths (e.g., 171, 193) to include.
ions (list[str]) – List of EVE ions or spectral lines to include.
freq (str) – Data cadence string (e.g., ‘12min’ or ‘1h’).
months (list[int]) – List of months (1-12) to include in this data split.
years (list[int]) – List of years (e.g., 2010-2018) to include in this data split.
mask (torch.Tensor, optional) – A precomputed spatial mask to be applied to the AIA and HMI image data. Defaults to None.
num_frames (int) – Number of consecutive temporal frames per sample. Defaults to 1.
drop_frame_dim (bool) – If True, removes the temporal dimension (F) for single-frame samples. Defaults to False.
min_date (str, optional) – Minimum timestamp boundary for filtering data.
max_date (str, optional) – Maximum timestamp boundary for filtering data.
get_header (bool) – Whether to retrieve and return FITS headers with the data. Defaults to False.
normalization (dict) – Mapping of instrument keys to their respective normalization methods.
normalization_stat (dict) – Precomputed statistics (e.g., mean and standard deviation) used for data scaling.
label_type (str) – The column name in aligndata used as the prediction target.
radial_parameters (list[str], optional) – Column names representing radial distance features.
latlon_parameters (list[str], optional) – Column names representing spatial coordinates (latitude/longitude).
sampling_ratio (list[float], optional) – Fractions used to undersample specific classes (only applied when datasplit is “train”).
random_state (int, optional) – Seed for the random number generator to ensure reproducible undersampling.
datasplit (str) – The intended use of this instance; one of “train”, “val”, or “test”. Defaults to “train”.
- radial_parameters
List of column names representing radial distance features.
- Type:
list[str]
- latlon_parameters
List of column names representing latitude and longitude features.
- Type:
list[str]
- aligndata
The filtered and potentially undersampled DataFrame containing alignment and label information.
- Type:
pd.DataFrame
- id_label
The integer column index of the target label within aligndata.
- Type:
int
- position_list
List of integer column indices corresponding to the latitudinal and longitudinal parameters.
- Type:
list[int]
- r_dist_list
List of integer column indices corresponding to the normalized radial distance parameters.
- Type:
list[int]
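The filtering, undersampling, and column-index mapping described for SWDataset can be sketched roughly as follows. This is a NumPy-only illustration, not the class's actual code: the helper names filter_by_year_month and undersample are hypothetical, and it assumes that sampling_ratio[c] is the fraction of class c to keep.

```python
import numpy as np


def filter_by_year_month(timestamps, years, months):
    """Boolean mask selecting timestamps that fall in the given years and months."""
    ts = np.asarray(timestamps, dtype="datetime64[s]")
    ts_year = ts.astype("datetime64[Y]").astype(int) + 1970
    ts_month = ts.astype("datetime64[M]").astype(int) % 12 + 1
    return np.isin(ts_year, years) & np.isin(ts_month, months)


def undersample(labels, sampling_ratio, random_state=0):
    """Sorted row indices keeping a fraction sampling_ratio[c] of each class c.

    Assumption: classes are the integers 0..len(sampling_ratio)-1.
    """
    rng = np.random.default_rng(random_state)
    labels = np.asarray(labels)
    keep = []
    for cls, ratio in enumerate(sampling_ratio):
        idx = np.flatnonzero(labels == cls)
        n_keep = int(round(ratio * idx.size))
        keep.append(rng.choice(idx, size=n_keep, replace=False))
    return np.sort(np.concatenate(keep))


# Column-index mapping, analogous to id_label / position_list / r_dist_list.
# The column names here are made up for the example.
columns = ["lat", "lon", "r_norm", "label"]
id_label = columns.index("label")
position_list = [columns.index(name) for name in ["lat", "lon"]]
r_dist_list = [columns.index(name) for name in ["r_norm"]]
```

Fixing random_state makes the undersampled subset reproducible across runs, which is the stated purpose of that parameter.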
- class sdofmv2.tasks.solar_wind.datamodule.SWDataModule(*args: Any, **kwargs: Any)[source]
Bases: SDOMLDataModule
PyTorch Lightning DataModule for Solar Wind (SW) prediction.
This module handles the end-to-end data pipeline for SDOML datasets, including loading alignment indices from Zarr, filtering by temporal boundaries (years/months), applying spatial longitude cutoffs for solar footpoints, and managing normalization for radial distance parameters.
- Parameters:
hmi_path (str) – Path to HMI Zarr data.
aia_path (str) – Path to AIA Zarr data.
eve_path (str) – Path to EVE Zarr data.
components (list[str]) – HMI magnetic components to use.
wavelengths (list[int]) – AIA wavelengths to use.
ions (list[str]) – EVE spectral lines to use.
frequency (str) – The sampling frequency of the instruments.
batch_size – Number of samples per batch. Defaults to 32.
num_workers – Number of subprocesses for data loading.
val_months – Months assigned to validation. Defaults to [10, 1].
test_months – Months assigned to testing. Defaults to [11, 12].
holdout_months – Months to be completely excluded from all splits.
radial_norm – Whether to apply Z-score normalization to radial features.
cache_dir – Directory for storing temporary or cached data.
apply_mask – If True, applies the limb mask to spatial data.
num_frames – Number of temporal frames per sample.
drop_frame_dim – Whether to squeeze the temporal dimension if it is 1.
min_date – Global start date boundary.
max_date – Global end date boundary.
precision – Numerical precision for tensors (“16”, “32”, or “64”).
normalization – Dictionary of normalization settings.
cfg – Configuration object for data cutoffs and model settings.
train_months – Months assigned to training.
train_years – Year(s) assigned to training.
val_years – Year(s) assigned to validation.
test_years – Year(s) assigned to testing.
alignment_indices_path (str) – Path to the Zarr file containing alignment indices and metadata.
radial_parameters (list[str]) – Column names for radial features.
latlon_parameters (list[str]) – Column names for coordinate features.
cadence – Temporal resolution string (e.g., “1min”).
label_type – The target column name for the model to predict.
sampling_ratio – Fraction of instances to sample per class.
random_state – Seed for reproducible sampling and shuffling.
- aligndata
The central alignment table indexed by SDO observation time, containing indices of data and target labels.
- Type:
pd.DataFrame
- radial_mean
Mean value of the radial parameters used for normalization.
- Type:
float
- radial_std
Standard deviation of the radial parameters.
- Type:
float
- train_years
Years allocated for the training set.
- Type:
int | list[int]
- val_years
Years allocated for the validation set.
- Type:
int | list[int]
- test_years
Years allocated for the test set.
- Type:
int | list[int]
- cfg
Configuration object containing hyperparameters and data cutoffs (e.g., cfg.data.in_situ.lon_cutoff).
- Type:
DictConfig | Any
Methods
predict_dataloader()
setup([stage])
test_dataloader()
train_dataloader()
val_dataloader()
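A minimal sketch of a month-based split and of Z-score normalization for radial features, under stated assumptions: the helper names assign_split and zscore are hypothetical, held-out months are assumed to take precedence, and every month not listed elsewhere is assumed to fall into training (the real module also accepts train_months and per-split years).

```python
import numpy as np


def assign_split(month, val_months=(10, 1), test_months=(11, 12), holdout_months=()):
    """Map a calendar month (1-12) to 'train', 'val', 'test', or None if held out."""
    if month in holdout_months:
        return None
    if month in val_months:
        return "val"
    if month in test_months:
        return "test"
    return "train"


def zscore(r, mean, std):
    """Standardize radial distances with precomputed statistics."""
    return (np.asarray(r, dtype=np.float64) - mean) / std


splits = {month: assign_split(month) for month in range(1, 13)}
```

Computing the mean and standard deviation once and reusing them for every split keeps validation and test data on the same scale as the training data.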
Head Module
- class sdofmv2.tasks.solar_wind.head_networks.TransformerHead(d_output, input_token_dim=512, p_drop=0.1, max_position_element=4, num_pos_token=10, nhead=8)[source]
Bases: Module
Transformer-based classification head with coordinate-to-token projection.
This module converts physical position and radial distance into a set of learned positional tokens. These tokens are prepended to the input sequence (alongside the CLS token) and processed through a Transformer Encoder block.
- Parameters:
d_output (int) – Dimension of the final output (e.g., number of classes).
input_token_dim (int) – Embedding dimension (D). Defaults to 512.
p_drop (float) – Dropout probability. Defaults to 0.1.
max_position_element (int) – Highest power used in sine/cosine transform. Defaults to 4.
num_pos_token (int) – Number of positional tokens to inject. Defaults to 10.
nhead (int) – Number of attention heads. Defaults to 8.
- num_pos_token
The number of latent tokens generated from the positional encoding.
- Type:
int
- input_token_dim
The dimensionality of the latent tokens (D).
- Type:
int
- pos_encoder
A function applying sine/cosine transformations to raw coordinates.
- Type:
callable
- projection
Linear layer mapping encoded positions to the token space.
- Type:
nn.Linear
- transformer_block
A single transformer layer for cross-token communication.
- Type:
nn.TransformerEncoderLayer
- classifier
Final mapping to output dimension.
- Type:
nn.Linear
Methods
forward(x, position, r_distance) – Perform a forward pass through the transformer head.
- forward(x, position, r_distance)[source]
Perform a forward pass through the transformer head.
- Parameters:
x (torch.Tensor) – Input features of shape (batch, num_tokens, embed_dim).
position (torch.Tensor) – Position coordinates of shape (batch, position_size).
r_distance (torch.Tensor) – Radial distance values of shape (batch,).
- Returns:
Logits of shape (batch, d_output) for classification/regression.
- Return type:
torch.Tensor
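The coordinate-to-token projection described for TransformerHead can be illustrated with a NumPy shape sketch. Several details are assumptions rather than facts from this page: the sine/cosine frequencies are taken to be 2**k for k below max_position_element, the radial distance is assumed to be encoded together with the position, and the tokens are placed in front of the whole sequence. The helper names harmonic_encode and position_tokens are hypothetical.

```python
import numpy as np


def harmonic_encode(coords, max_position_element=4):
    """Encode (B, P) coordinates as sin/cos features of shape (B, P * 2 * K).

    Assumption: frequencies are 2**k for k = 0..K-1, K = max_position_element.
    """
    coords = np.asarray(coords, dtype=np.float64)
    freqs = 2.0 ** np.arange(max_position_element)   # (K,)
    angles = coords[:, :, None] * freqs              # (B, P, K)
    enc = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return enc.reshape(coords.shape[0], -1)


def position_tokens(position, r_distance, weight, bias, num_pos_token, token_dim):
    """Project encoded coordinates to (B, num_pos_token, token_dim) tokens."""
    coords = np.concatenate([position, r_distance[:, None]], axis=1)  # (B, P + 1)
    flat = harmonic_encode(coords) @ weight + bias
    return flat.reshape(-1, num_pos_token, token_dim)


rng = np.random.default_rng(0)
batch, num_tokens, token_dim, num_pos_token = 2, 5, 8, 10
x = rng.normal(size=(batch, num_tokens, token_dim))   # stand-in for backbone tokens
position = rng.uniform(-1.0, 1.0, size=(batch, 4))
r_distance = rng.normal(size=batch)

enc_dim = (4 + 1) * 2 * 4                              # (P + 1) * 2 * K
weight = rng.normal(scale=0.1, size=(enc_dim, num_pos_token * token_dim))
bias = np.zeros(num_pos_token * token_dim)

tokens = position_tokens(position, r_distance, weight, bias, num_pos_token, token_dim)
sequence = np.concatenate([tokens, x], axis=1)         # (B, num_pos_token + num_tokens, D)
```

In the documented head, a sequence of this kind would then pass through the transformer layer and the final linear classifier; those steps are omitted here.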
- class sdofmv2.tasks.solar_wind.head_networks.SimpleLinear(d_output, input_feature_dim, max_position_element=4, position_size=4, hidden_dim=16, p_drop=0.1)[source]
Bases: Module
Multi-Layer Perceptron head for flattened feature processing.
This head flattens the input feature map and concatenates it with harmonic positional encodings (sine/cosine) and radial distance before passing the combined vector through a non-linear MLP.
- Parameters:
d_output (int) – Dimension of the final output.
input_feature_dim (int) – Dimension of the input features before flattening.
max_position_element (int) – Highest power used in sine/cosine transform. Defaults to 4.
position_size (int) – Number of raw coordinate variables. Defaults to 4.
hidden_dim (int) – Width of the first hidden layer. Defaults to 16.
p_drop (float) – Dropout probability. Defaults to 0.1.
- d_output
Output dimensionality.
- Type:
int
- input_feature_dim
Dimension of the input features.
- Type:
int
- hidden_dim
Hidden layer width.
- Type:
int
- p_drop
Dropout rate used in the network.
- Type:
float
- max_position_element
Complexity of the harmonic encoding.
- Type:
int
- position_size
Number of coordinate inputs.
- Type:
int
- network
The MLP architecture (Linear -> ReLU -> Dropout).
- Type:
nn.Sequential
Methods
forward(x, position, r_distance) – Perform a forward pass through the MLP head.
- forward(x, position, r_distance)[source]
Perform a forward pass through the MLP head.
- Parameters:
x (torch.Tensor) – Input features of shape (batch, …).
position (torch.Tensor) – Position coordinates of shape (batch, position_size).
r_distance (torch.Tensor) – Radial distance values of shape (batch,).
- Returns:
Output logits of shape (batch, d_output).
- Return type:
torch.Tensor
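A rough NumPy sketch of how the SimpleLinear input vector could be assembled and passed through a small MLP. It assumes 2 * max_position_element sine/cosine features per coordinate (frequencies 2**k), a raw radial distance appended as one extra column, and a two-layer network with dropout omitted (inference mode). The function names encode, build_mlp_input, and mlp_forward are hypothetical.

```python
import numpy as np


def encode(position, max_position_element=4):
    """Sin/cos features of shape (B, P * 2 * K); frequencies 2**k are an assumption."""
    position = np.asarray(position, dtype=np.float64)
    freqs = 2.0 ** np.arange(max_position_element)
    angles = position[:, :, None] * freqs
    enc = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return enc.reshape(position.shape[0], -1)


def build_mlp_input(x, position, r_distance, max_position_element=4):
    """Flatten features and append encoded position plus raw radial distance."""
    flat = x.reshape(x.shape[0], -1)
    return np.concatenate(
        [flat, encode(position, max_position_element), r_distance[:, None]], axis=1
    )


def mlp_forward(v, w1, b1, w2, b2):
    """Linear -> ReLU -> Linear; dropout acts as the identity at inference."""
    hidden = np.maximum(v @ w1 + b1, 0.0)
    return hidden @ w2 + b2


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2, 6))                 # (batch, ...) feature map
position = rng.uniform(-1.0, 1.0, size=(3, 4))
r_distance = np.array([0.1, 0.2, 0.3])

v = build_mlp_input(x, position, r_distance)   # width 12 + 4 * 2 * 4 + 1 = 45
w1, b1 = rng.normal(scale=0.1, size=(45, 16)), np.zeros(16)
w2, b2 = rng.normal(scale=0.1, size=(16, 4)), np.zeros(4)
logits = mlp_forward(v, w1, b1, w2, b2)        # (batch, d_output)
```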
- class sdofmv2.tasks.solar_wind.head_networks.SkipLinearHead(d_output, input_feature_dim, max_position_element=4, position_size=4, hidden_dim=16, skips=[4], include_raw_coordinates=False, num_hidden_layers=8, number_of_frames=1)[source]
Bases: Module
Deep MLP head with skip connections for coordinate-aware regression.
This module implements a deep architecture where the initial concatenated input (features + encodings) is re-injected at specified layers. This residual-style connection helps maintain high-frequency coordinate information throughout the depth of the network.
- Parameters:
d_output (int) – Dimension of the final output.
input_feature_dim (int) – Dimension of a single input frame.
max_position_element (int) – Highest power for harmonic encoding. Defaults to 4.
position_size (int) – Number of raw coordinate variables. Defaults to 4.
hidden_dim (int) – Latent width of the hidden layers. Defaults to 16.
skips (list[int]) – Layer indices where the initial input is concatenated back into the hidden state. Defaults to [4].
include_raw_coordinates (bool) – If True, appends non-encoded coordinates to the input vector. Defaults to False.
num_hidden_layers (int) – Total number of linear layers in the backbone. Defaults to 8.
number_of_frames (int) – Number of temporal frames to flatten. Defaults to 1.
- d_output
Output dimensionality.
- Type:
int
- hidden_dim
Latent width of the network.
- Type:
int
- max_position_element
Harmonic encoding complexity.
- Type:
int
- position_size
Number of coordinate inputs.
- Type:
int
- skips
Indices of layers performing skip-connections.
- Type:
list[int]
- num_hidden_layers
Total depth of the MLP.
- Type:
int
- pts_linears
Collection of linear layers including skip-connection logic.
- Type:
nn.ModuleList
- output_linear
Final layer mapping to output dimension.
- Type:
nn.Linear
Methods
forward(x, position, r_distance) – Perform a forward pass through the skip-connected MLP head.
- forward(x, position, r_distance)[source]
Perform a forward pass through the skip-connected MLP head.
- Parameters:
x (torch.Tensor) – Input features of shape (batch, …).
position (torch.Tensor) – Position coordinates of shape (batch, position_size).
r_distance (torch.Tensor) – Radial distance values of shape (batch,).
- Returns:
Output logits of shape (batch, d_output).
- Return type:
torch.Tensor
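The re-injection of the initial input at the layers listed in skips can be sketched in NumPy as below. This is one plausible reading, not the module's actual code: it assumes the original input is concatenated in front of the hidden state right after the layer whose index is in skips, so that the following layer has a wider input. The names init_skip_mlp and skip_mlp_forward are hypothetical.

```python
import numpy as np


def init_skip_mlp(input_dim, hidden_dim=16, num_hidden_layers=8, skips=(4,),
                  d_output=4, seed=0):
    """Create (weight, bias) pairs for the backbone and the output layer."""
    rng = np.random.default_rng(seed)
    layers = []
    in_dim = input_dim
    for i in range(num_hidden_layers):
        layers.append((rng.normal(scale=0.1, size=(in_dim, hidden_dim)),
                       np.zeros(hidden_dim)))
        # The layer after a skip index sees hidden state plus the original input.
        in_dim = hidden_dim + input_dim if i in skips else hidden_dim
    out = (rng.normal(scale=0.1, size=(in_dim, d_output)), np.zeros(d_output))
    return layers, out


def skip_mlp_forward(v, layers, out, skips=(4,)):
    """Run the backbone, concatenating the input v back in after each skip layer."""
    h = v
    for i, (w, b) in enumerate(layers):
        h = np.maximum(h @ w + b, 0.0)           # Linear -> ReLU
        if i in skips:
            h = np.concatenate([v, h], axis=1)   # re-inject the initial input
    w_out, b_out = out
    return h @ w_out + b_out


v = np.random.default_rng(1).normal(size=(5, 20))   # features + encodings
layers, out = init_skip_mlp(input_dim=20)
y = skip_mlp_forward(v, layers, out)                # (batch, d_output)
```

With the defaults (skips=[4], hidden_dim=16) and a 20-wide input, layer 5 in this sketch takes a 36-wide input while all other hidden layers take 16.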
- class sdofmv2.tasks.solar_wind.head_networks.ClsLinear(d_output, embedding_dim, max_position_element=4, position_size=4, hidden_dim=16, p_drop=0.1)[source]
Bases: Module
MLP head designed for Transformer [CLS] token representations.
This module specifically extracts the class (CLS) token from the first index of the input sequence and combines it with physical metadata (positional encodings and radial distance) for final prediction.
- Parameters:
d_output (int) – Dimension of the final output.
embedding_dim (int) – Dimensionality of the tokens in the input sequence.
max_position_element (int) – Highest power for harmonic encoding. Defaults to 4.
position_size (int) – Number of raw coordinate variables. Defaults to 4.
hidden_dim (int) – Hidden width of the MLP. Defaults to 16.
p_drop (float) – Dropout probability. Defaults to 0.1.
- d_output
Output dimensionality.
- Type:
int
- embedding_dim
Dimension of the input tokens (D).
- Type:
int
- hidden_dim
Hidden layer width.
- Type:
int
- p_drop
Dropout rate.
- Type:
float
- max_position_element
Complexity of the harmonic encoding.
- Type:
int
- position_size
Number of coordinate inputs.
- Type:
int
- network
MLP layers processing the combined CLS and metadata vector.
- Type:
nn.Sequential
Methods
forward(x, position, r_distance) – Perform a forward pass using the CLS token.
- forward(x, position, r_distance)[source]
Perform a forward pass using the CLS token.
- Parameters:
x (torch.Tensor) – Input sequence with CLS token at index 0, shape (batch, num_tokens, embed_dim).
position (torch.Tensor) – Position coordinates of shape (batch, position_size).
r_distance (torch.Tensor) – Radial distance values of shape (batch,).
- Returns:
Output logits of shape (batch, d_output).
- Return type:
torch.Tensor
Model Module
- class sdofmv2.tasks.solar_wind.model.SWClassifier(*args: Any, **kwargs: Any)[source]
Bases: BaseModule
Solar Wind Classifier using a backbone encoder and configurable head.
This module wraps a pretrained backbone encoder and adds a classification head for solar wind prediction tasks. It supports multiple head types (linear, transformer, skip_linear), handles coordinate embeddings, and tracks various classification metrics during training, validation, and testing.
- Parameters:
channels (list[str]) – List of data channels to use.
num_classes (int) – Number of output classes.
class_names (list[str]) – Names of the classes for logging.
max_position_element (int) – Maximum power for positional encoding.
backbone – The pretrained backbone model.
freeze_encoder (bool) – Whether to freeze the backbone encoder.
plt_style – Plotting style configuration.
head_type (str) – Type of classification head (“linear”, “transformer”, “skip_linear”).
hidden_dim (int) – Hidden dimension for the head network.
position_size (int) – Number of position coordinates.
p_drop (float) – Dropout probability.
nhead (int) – Number of attention heads for transformer head.
embed_dim (int) – Embedding dimension.
skips (list[int]) – Layer indices for skip connections.
include_raw_coordinates (bool) – Whether to include raw coordinates.
num_hidden_layers (int) – Number of hidden layers for skip_linear head.
radial_mean (float) – Mean for radial normalization.
radial_std (float) – Standard deviation for radial normalization.
loss_dict (dict) – Loss function configuration.
optimizer_dict (dict) – Optimizer configuration.
scheduler_dict (dict) – Scheduler configuration.
Methods
forward(x, position, r_distance) – Perform a forward pass through the classifier.
forward_analysis(x) – Perform analysis forward pass for reconstruction visualization.
on_before_optimizer_step(optimizer) – Called before each optimizer step to log gradient norms.
on_test_epoch_end() – Called at the end of the test epoch.
on_train_epoch_end() – Called at the end of the training epoch.
on_validation_epoch_end() – Called at the end of the validation epoch.
predict_step(batch, batch_idx) – Perform a prediction step for inference.
test_step(batch, batch_idx) – Perform a single test step.
training_step(batch, batch_idx) – Perform a single training step.
validation_step(batch, batch_idx) – Perform a single validation step.
- forward(x, position, r_distance)[source]
Perform a forward pass through the classifier.
- Parameters:
x (torch.Tensor) – Input images of shape (B, C, H, W).
position (torch.Tensor) – Position coordinates of shape (B, position_size).
r_distance (torch.Tensor) – Radial distance values of shape (B,).
- Returns:
Class logits of shape (B, num_classes).
- Return type:
torch.Tensor
- forward_analysis(x)[source]
Perform analysis forward pass for reconstruction visualization.
- Parameters:
x (torch.Tensor) – Input images of shape (B, C, H, W).
- Returns:
Reconstructed images.
- Return type:
torch.Tensor
- on_before_optimizer_step(optimizer)[source]
Called before each optimizer step to log gradient norms.
- Parameters:
optimizer – The optimizer about to perform an update step.
- on_test_epoch_end()[source]
Called at the end of the test epoch.
Computes and logs all test metrics including F1, precision, recall, accuracy, AUROC, MCC, and Cohen’s Kappa.
- on_train_epoch_end()[source]
Called at the end of the training epoch.
Performs garbage collection and clears CUDA cache to free memory.
- on_validation_epoch_end()[source]
Called at the end of the validation epoch.
Computes and logs all validation metrics, generates WandB plots (confusion matrix, class distributions), and clears stored buffers.
- predict_step(batch, batch_idx)[source]
Perform a prediction step for inference.
- Parameters:
batch – A tuple containing (images, timestamps, position, r_distance, target).
batch_idx – The index of the current batch.
- Returns:
A dictionary containing predictions, targets, embeddings, etc.
- Return type:
dict
- test_step(batch, batch_idx)[source]
Perform a single test step.
- Parameters:
batch – A tuple containing (images, timestamps, position, r_distance, target).
batch_idx – The index of the current batch.
- Returns:
A dictionary containing predictions, targets, logits, and test loss.
- Return type:
dict
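Several of the metrics logged at the end of the validation and test epochs (accuracy, F1, MCC, Cohen's Kappa) can be derived from a confusion matrix. The following NumPy sketch shows the standard definitions only; it is not how SWClassifier computes them, and the function names confusion_matrix and summary_metrics are hypothetical. Macro averaging for F1 is an assumption.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Count matrix with true classes on rows and predicted classes on columns."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def summary_metrics(cm):
    """Accuracy, macro F1, Cohen's kappa, and multi-class MCC from a confusion matrix."""
    cm = cm.astype(np.float64)
    total = cm.sum()
    correct = np.trace(cm)
    true_counts = cm.sum(axis=1)
    pred_counts = cm.sum(axis=0)

    accuracy = correct / total
    # Per-class F1 = 2 TP / (2 TP + FP + FN) = 2 TP / (true_count + pred_count).
    denom = true_counts + pred_counts
    f1 = np.where(denom > 0, 2.0 * np.diag(cm) / np.maximum(denom, 1.0), 0.0)
    # Agreement expected by chance, from the row and column marginals.
    chance = (true_counts * pred_counts).sum() / total**2
    kappa = (accuracy - chance) / (1.0 - chance) if chance < 1.0 else float("nan")
    mcc_num = correct * total - (true_counts * pred_counts).sum()
    mcc_den = np.sqrt((total**2 - (pred_counts**2).sum())
                      * (total**2 - (true_counts**2).sum()))
    mcc = mcc_num / mcc_den if mcc_den > 0 else 0.0
    return {"accuracy": accuracy, "macro_f1": f1.mean(), "kappa": kappa, "mcc": mcc}


cm = confusion_matrix([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3)
metrics = summary_metrics(cm)
```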
Focal Loss
- sdofmv2.tasks.solar_wind.focal_loss.focal_loss_multiclass(inputs: Tensor, targets: Tensor, alpha: float | Tensor = 0.25, gamma: float = 2.0, reduction: str = 'none') Tensor[source]
Multi-class focal loss based on torchvision.ops.sigmoid_focal_loss.
- Parameters:
inputs (Tensor[N, C]) – Logits for each class.
targets (Tensor[N]) – Class indices (0 ≤ targets < C).
alpha (float or Tensor[C]) – Balance factor(s). Scalar or per-class.
gamma (float) – Modulating factor exponent.
reduction (str) – ‘none’, ‘mean’, or ‘sum’.
- Returns:
Loss per sample, or reduced loss.
- Return type:
Tensor
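For orientation, a common softmax-based form of multi-class focal loss is loss = -alpha_t * (1 - p_t)**gamma * log(p_t), where p_t is the predicted probability of the true class. The NumPy sketch below implements that form. Whether focal_loss_multiclass uses exactly this formulation is an assumption, and the name focal_loss_multiclass_np is hypothetical.

```python
import numpy as np


def focal_loss_multiclass_np(inputs, targets, alpha=0.25, gamma=2.0, reduction="none"):
    """Softmax focal loss for logits of shape (N, C) and class indices of shape (N,)."""
    inputs = np.asarray(inputs, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)

    # Numerically stable log-softmax.
    shifted = inputs - inputs.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(targets)), targets]
    pt = np.exp(log_pt)

    # Scalar alpha scales every sample; a per-class vector is indexed by target.
    alpha_t = np.asarray(alpha, dtype=np.float64)
    if alpha_t.ndim > 0:
        alpha_t = alpha_t[targets]

    loss = -alpha_t * (1.0 - pt) ** gamma * log_pt
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With gamma=0 and alpha=1 this reduces to ordinary cross-entropy; larger gamma down-weights samples the model already classifies confidently.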
Visualization
- sdofmv2.tasks.solar_wind.visualization.find_images_labels(imgs, timestamps, targets, preds, position, class_id, ch_id, max_samples=3)[source]
- Parameters:
imgs – Input tensors
targets – Input tensors
preds – Input tensors
position – Input tensors
class_id – Class to filter for
max_samples – Maximum number of samples to return (default: 3)
- sdofmv2.tasks.solar_wind.visualization.find_images_labels_embed(imgs, timestamps, targets, preds, position, class_id, max_samples=3)[source]
- Parameters:
imgs – Input tensors
targets – Input tensors
preds – Input tensors
position – Input tensors
class_id – Class to filter for
max_samples – Maximum number of samples to return (default: 3)
- sdofmv2.tasks.solar_wind.visualization.plot_disk_distribution(img, model, time, wsa_lat=array([-75., -69.82758621, -64.65517241, -59.48275862, -54.31034483, -49.13793103, -43.96551724, -38.79310345, -33.62068966, -28.44827586, -23.27586207, -18.10344828, -12.93103448, -7.75862069, -2.5862069, 2.5862069, 7.75862069, 12.93103448, 18.10344828, 23.27586207, 28.44827586, 33.62068966, 38.79310345, 43.96551724, 49.13793103, 54.31034483, 59.48275862, 64.65517241, 69.82758621, 75.]), wsa_lon=array([-90., -83.79310345, -77.5862069, -71.37931034, -65.17241379, -58.96551724, -52.75862069, -46.55172414, -40.34482759, -34.13793103, -27.93103448, -21.72413793, -15.51724138, -9.31034483, -3.10344828, 3.10344828, 9.31034483, 15.51724138, 21.72413793, 27.93103448, 34.13793103, 40.34482759, 46.55172414, 52.75862069, 58.96551724, 65.17241379, 71.37931034, 77.5862069, 83.79310345, 90.]), labels=['Ejecta', 'Coronal Hole', 'Sector Reversal', 'Streamer Belt'], colors=['#191923', '#0E79B2', '#bf1363', '#F39237'])[source]
- sdofmv2.tasks.solar_wind.visualization.plot_ecliptic(img, model, time, x_range=array([-213., -198.31034483, -183.62068966, -168.93103448, -154.24137931, -139.55172414, -124.86206897, -110.17241379, -95.48275862, -80.79310345, -66.10344828, -51.4137931, -36.72413793, -22.03448276, -7.34482759, 7.34482759, 22.03448276, 36.72413793, 51.4137931, 66.10344828, 80.79310345, 95.48275862, 110.17241379, 124.86206897, 139.55172414, 154.24137931, 168.93103448, 183.62068966, 198.31034483, 213.]), y_range=array([-213., -198.31034483, -183.62068966, -168.93103448, -154.24137931, -139.55172414, -124.86206897, -110.17241379, -95.48275862, -80.79310345, -66.10344828, -51.4137931, -36.72413793, -22.03448276, -7.34482759, 7.34482759, 22.03448276, 36.72413793, 51.4137931, 66.10344828, 80.79310345, 95.48275862, 110.17241379, 124.86206897, 139.55172414, 154.24137931, 168.93103448, 183.62068966, 198.31034483, 213.]), z_range=array([0.]), r_mean=np.float32(8.918359e+07), r_std=np.float32(3.0130634e+07), inner_mask=10, outer_mask=220)[source]
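The long default arrays in the plot_disk_distribution and plot_ecliptic signatures appear to be evenly spaced 30-point grids. The sketch below reproduces them with np.linspace to the printed precision, which may be a more convenient way to build custom grids of the same kind; that the library itself constructs them this way is an assumption.

```python
import numpy as np

# 30 evenly spaced values, matching the printed defaults to 8 decimal places.
wsa_lat = np.linspace(-75.0, 75.0, 30)     # latitude grid
wsa_lon = np.linspace(-90.0, 90.0, 30)     # longitude grid
x_range = np.linspace(-213.0, 213.0, 30)   # also matches the y_range default
```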
- sdofmv2.tasks.solar_wind.visualization.plot_images_grid(correct_data, plt_style, step=None)[source]
- sdofmv2.tasks.solar_wind.visualization.plot_sdoml(datamodule, condition=None, dataset_idx=None, times=None, n_samples=None, wavelengths=['131A', '1600A', '1700A', '171A', '193A', '211A', '304A', '335A', '94A'], components=['Bx', 'By', 'Bz'], wsa_footpoint=False, title=None)[source]
- sdofmv2.tasks.solar_wind.visualization.plot_tsne(embs, position, r_distance, labels, perplexities=[5, 30, 50, 100, 300], classes=['Ejecta', 'Coronal Hole', 'Sector Reversal', 'Streamer Belt'], colors=['#191923', '#0E79B2', '#bf1363', '#F39237'])[source]
- sdofmv2.tasks.solar_wind.visualization.wsa_to_image(lat_sh, lon_sh, r_sun=200, ctr_pxl=256)[source]
Returns the position on the SDOML image of the PSP footpoint from WSA. lat_sh and lon_sh are the Stonyhurst latitude and longitude in degrees. r_sun is the radius of the Sun on the image in pixels. ctr_pxl is the pixel coordinate of the image center along each axis.
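One way to map Stonyhurst coordinates to image pixels is a simple orthographic projection, sketched below. This is an illustration under strong assumptions and may differ from what wsa_to_image does: it ignores the solar B0 tilt, assumes solar north points up, and assumes the row index grows downward. The function name stonyhurst_to_pixel is hypothetical.

```python
import numpy as np


def stonyhurst_to_pixel(lat_deg, lon_deg, r_sun=200, ctr_pxl=256):
    """Orthographic projection of Stonyhurst (lat, lon) in degrees to pixel (x, y)."""
    lat = np.deg2rad(lat_deg)
    lon = np.deg2rad(lon_deg)
    x = ctr_pxl + r_sun * np.cos(lat) * np.sin(lon)   # column, increasing toward the west limb
    y = ctr_pxl - r_sun * np.sin(lat)                 # row, assumed to grow downward
    return x, y
```

Under these assumptions the disk center (0°, 0°) lands on (ctr_pxl, ctr_pxl), and a point on the equator at 90° longitude lands r_sun pixels to the right of it.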