# Adapted to be general from https://github.com/FrontierDevelopmentLab/2023-FDL-X-ARD-EVE/blob/main/src/irradiance/utilities/data_loader.py
import os
from pathlib import Path
import re
import time
from pathlib import Path
from loguru import logger
import torch
import yaml
import dask.array as da
from dask.diagnostics import ProgressBar
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import zarr
from ..utils import ALL_COMPONENTS, ALL_IONS, ALL_WAVELENGTHS
def get_dtype_from_precision(precision):
    """Map a Lightning-style precision flag to the matching torch dtype.

    Args:
        precision: Precision flag, compared through ``str(precision)`` so both
            ``16`` and ``"16"`` are accepted. Recognised values are
            ``16``/``16-true``/``16-mixed``, ``bf16``/``bf16-true``/``bf16-mixed``
            and ``64``/``64-true``.

    Returns:
        torch.dtype: ``torch.float16``, ``torch.bfloat16`` or ``torch.float64``
        for the recognised flags; ``torch.float32`` for anything else.
    """
    key = str(precision)
    # The "-true" half-precision variants are handled alongside "-mixed" so they
    # do not silently fall through to float32 (as "64-true" already is).
    if key in ("16", "16-true", "16-mixed"):
        return torch.float16
    if key in ("bf16", "bf16-true", "bf16-mixed"):
        return torch.bfloat16
    if key in ("64", "64-true"):
        return torch.float64
    return torch.float32
def zscore_norm(data, channel, normalization_stat, clip_value):
    """Standardise ``data`` with the mean/std stored for ``channel``.

    The input is never modified: a new array is returned. (The previous
    in-place ``-=`` / ``/=`` mutated the caller's array whenever no clipping was
    requested and failed outright on integer arrays.)

    Args:
        data: Array-like values to normalise.
        channel: Key into ``normalization_stat``.
        normalization_stat: Mapping ``channel -> {"mean": ..., "std": ...}``.
        clip_value: Optional ``(low, high)`` pair; when given, ``data`` is
            clipped to that range before standardisation.

    Returns:
        ``(clipped_data - mean) / std`` as a new array.
    """
    if clip_value is not None:
        low, high = clip_value
        data = np.clip(data, low, high)
    stat = normalization_stat[channel]
    return (data - stat["mean"]) / stat["std"]
def min_max_norm(data, channel, normalization_stat):
    """Rescale ``data`` using the min/max stored for ``channel``.

    The input is never modified: a new array is returned. (The previous
    in-place ``-=`` / ``/=`` mutated the caller's array and failed on integer
    arrays.)

    Args:
        data: Array-like values to normalise.
        channel: Key into ``normalization_stat``.
        normalization_stat: Mapping ``channel -> {"min": ..., "max": ...}``.

    Returns:
        ``(data - min) / (max - min)`` as a new array.
    """
    stat = normalization_stat[channel]
    lower = stat["min"]
    return (data - lower) / (stat["max"] - lower)
def log_norm(data, normalization_stat, channel, scaler_factor):
    """Apply an optional pre-scale, a signed log transform and a z-score.

    Args:
        data: Array-like values to normalise.
        normalization_stat: Mapping ``channel -> {"mean": ..., "std": ...}``;
            the statistics are applied to the log-transformed values.
        channel: Key into ``normalization_stat``.
        scaler_factor: Multiplier applied to ``data`` before the log transform,
            or ``None`` to skip the pre-scaling.

    Returns:
        ``(sign(x) * log1p(|x|) - mean) / (std + 1e-8)`` where ``x`` is the
        (optionally scaled) input.
    """
    scaled = data if scaler_factor is None else data * scaler_factor
    # Signed log: keeps the sign of each value while compressing its magnitude.
    signed_log = np.sign(scaled) * np.log1p(np.abs(scaled))
    stat = normalization_stat[channel]
    centred = signed_log - stat["mean"]
    # The small epsilon guards against a zero standard deviation.
    return centred / (stat["std"] + 1e-8)
def inverse_zscore_norm(data, instrument, channel, normalization_stat):
    """Undo a z-score normalisation: ``data * std + mean``.

    Two layouts of ``normalization_stat`` are accepted:

    * nested: ``normalization_stat[instrument][channel]`` (the layout this
      function has always read), and
    * flat: ``normalization_stat[channel]`` (the layout read by the forward
      ``zscore_norm`` in this module).

    The nested layout is tried first; the flat one is the fallback.

    Args:
        data: Normalised values.
        instrument: Instrument key for the nested layout.
        channel: Channel key.
        normalization_stat: Statistics mapping in either layout; the selected
            entry must provide ``"mean"`` and ``"std"``.

    Returns:
        The de-normalised values.

    Raises:
        KeyError: If ``channel`` cannot be found in either layout.
    """
    try:
        stat = normalization_stat[instrument][channel]
    except (KeyError, TypeError):
        # Flat layout: statistics are keyed by channel only.
        stat = normalization_stat[channel]
    # Reverse the division, then the subtraction.
    return data * stat["std"] + stat["mean"]
def inverse_log_norm(
    data_transformed,
    normalization_stat,
    channel,
    scaler_factor=None,
):
    """Invert the signed-log + z-score normalisation.

    Args:
        data_transformed: Normalised values.
        normalization_stat: Mapping ``channel -> {"mean": ..., "std": ...}``
            holding the log-domain statistics.
        channel: Key into ``normalization_stat``.
        scaler_factor: Pre-scaling factor to divide out at the end, or ``None``
            when no pre-scaling is to be undone.

    Returns:
        The values mapped back to the original domain.
    """
    stat = normalization_stat[channel]
    mu, sigma = stat["mean"], stat["std"]

    # Undo the z-score, using the same epsilon as the forward transform.
    signed_log = data_transformed * (sigma + 1e-8) + mu

    # Undo the signed log: y = sign(x) * log(1 + |x|)  =>  x = sign(y) * (exp(|y|) - 1)
    restored = np.sign(signed_log) * np.expm1(np.abs(signed_log))

    # Undo the pre-scaling, if any.
    if scaler_factor is None:
        return restored
    return restored / scaler_factor
class SDOMLDataset(Dataset):
    """A PyTorch Dataset for Solar Dynamics Observatory (SDO) Machine Learning data.

    Loads temporally aligned observations from the AIA, HMI and EVE stores.
    Each sample is a stack of ``num_frames`` consecutive rows of ``aligndata``,
    optionally masked and normalised on the fly.

    Args:
        aligndata (pd.DataFrame): Alignment table indexed by timestamp. It must
            provide an ``idx_<channel>`` column for every requested AIA
            wavelength / HMI component, and ``idx_eve`` when EVE is used.
        hmi_data: HMI store indexed as ``hmi_data[year][component][i, :, :]``,
            or None to skip HMI.
        aia_data: AIA store indexed as ``aia_data[year][wavelength][i, :, :]``,
            or None to skip AIA.
        eve_data: EVE store indexed as ``eve_data[ion][i]``, or None to skip EVE.
        components (list): HMI components to load; None selects ``ALL_COMPONENTS``.
        wavelengths (list): AIA wavelengths to load; None selects ``ALL_WAVELENGTHS``.
        ions (list): EVE ions to load; None selects ``ALL_IONS``.
        normalization: Configuration object read through attribute access
            (``.enabled``, ``.type`` and, depending on the type,
            ``.scaler_factor`` or ``.clipping``). None disables normalisation.
        normalization_stat (dict, optional): Per-channel statistics used by the
            selected normalisation. Defaults to None.
        mask (array-like, optional): Array multiplied into every AIA/HMI image.
            Defaults to None (no masking).
        num_frames (int, optional): Number of consecutive rows per sample.
            Defaults to 1.
        drop_frame_dim (bool, optional): If True (requires ``num_frames == 1``),
            the frame axis is dropped from the image arrays. Defaults to False.
        get_header (bool or list, optional): When truthy, header metadata is
            returned alongside the images. Defaults to False.
        precision (str, optional): Precision flag used to pick the dtype of the
            returned image tensor. Defaults to "32".
    """

    def __init__(
        self,
        aligndata,
        hmi_data,
        aia_data,
        eve_data,
        components,
        wavelengths,
        ions,
        normalization=None,
        normalization_stat=None,
        mask=None,
        num_frames=1,
        drop_frame_dim=False,
        get_header=False,  # Optional[list] = [],
        precision="32",
    ):
        super().__init__()

        self.aligndata = aligndata
        self.aia_data = aia_data
        self.eve_data = eve_data
        self.hmi_data = hmi_data

        self.mask = mask
        self.get_header = get_header
        self.precision = precision

        self.components = components
        self.wavelengths = wavelengths
        self.ions = ions

        # Channel lists are sorted for every instrument that is present.
        # sorted() builds a new list, so neither the caller's list nor the
        # shared ALL_* module constants are reordered in place.
        if self.hmi_data is not None:
            self.components = sorted(ALL_COMPONENTS if components is None else components)
        if self.aia_data is not None:
            self.wavelengths = sorted(ALL_WAVELENGTHS if wavelengths is None else wavelengths)
        if self.eve_data is not None:
            self.ions = sorted(ALL_IONS if ions is None else ions)

        self.normalization = normalization
        self.normalization_stat = normalization_stat

        # number of frames to return per sample
        self.num_frames = num_frames
        self.drop_frame_dim = drop_frame_dim  # for backwards compat
        if self.drop_frame_dim:
            assert self.num_frames == 1, "drop_frame_dim requires num_frames == 1"

    def __len__(self):
        """Return the number of samples for which all requested frames exist."""
        # Report slightly fewer samples so every frame window is complete;
        # clamp at zero because len() rejects negative values.
        return max(0, len(self.aligndata) - (self.num_frames - 1))

    def __getitem__(self, idx):
        """Assemble one sample.

        Returns:
            tuple: ``(image_stack, timestamps)``, followed by ``header_stack``
            when ``get_header`` is truthy, followed by the EVE values when EVE
            data is configured (flattened when ``get_header`` is truthy).
            ``image_stack`` is a torch tensor with AIA channels first and HMI
            channels after; ``timestamps`` is a single integer when
            ``num_frames <= 1`` and an index of integers otherwise.

        Raises:
            ValueError: If neither AIA nor HMI data is configured.
        """
        image_stack = None
        header_stack = {}

        if self.aia_data is not None:
            aia_images, aia_headers = self.get_aia_image(idx)
            image_stack = aia_images
            header_stack.update(aia_headers)

        if self.hmi_data is not None:
            hmi_images, hmi_headers = self.get_hmi_image(idx)
            if image_stack is None:
                image_stack = hmi_images
            else:
                image_stack = np.concatenate((image_stack, hmi_images), axis=0)
            header_stack.update(hmi_headers)

        if image_stack is None:
            raise ValueError("SDOMLDataset needs AIA and/or HMI data to build an image stack.")

        image_stack = torch.from_numpy(image_stack)
        image_stack = image_stack.to(get_dtype_from_precision(self.precision))

        # Explicit 64-bit width so the conversion does not depend on the
        # platform's default integer size.
        timestamps = self.aligndata.index[idx : idx + self.num_frames].astype("int64")
        if self.num_frames <= 1:
            timestamps = timestamps[0]

        sample = [image_stack, timestamps]
        if self.get_header:
            sample.append(header_stack)
        if self.eve_data is not None:
            eve_data = self.get_eve(idx)
            sample.append(eve_data.reshape(-1) if self.get_header else eve_data)
        return tuple(sample)

    def _normalization_enabled(self):
        """Return True when a normalisation config is present and enabled."""
        # getattr keeps ``normalization=None`` (the default) from raising.
        return bool(getattr(self.normalization, "enabled", False))

    def _data_norm(self, data, instrument, channel):
        """Normalise ``data`` according to ``self.normalization.type``.

        Args:
            data: numpy array of shape H W
            instrument: Instrument label; currently not used by any branch.
            channel: Key into ``self.normalization_stat``.

        Raises:
            ValueError: If the configured type is not "log", "zscore" or
                "min-max" (previously this silently returned None).
        """
        norm_type = self.normalization.type
        if norm_type == "log":
            return log_norm(
                data,
                self.normalization_stat,
                channel,
                self.normalization.scaler_factor,
            )
        if norm_type == "zscore":
            clipping = self.normalization.clipping
            clip_value = clipping[channel] if clipping.enabled else None
            return zscore_norm(data, channel, self.normalization_stat, clip_value)
        if norm_type == "min-max":
            return min_max_norm(data, channel, self.normalization_stat)
        raise ValueError(f"Unsupported normalization type: {norm_type!r}")

    def loading_data_retry(
        self,
        data,
        year,
        wavelength,
        id_of_img,
        num_try: int = 10,
        sleep_time: float = 0.5,
    ):
        """Load one image, retrying on failure.

        Reads ``data[year][wavelength][id_of_img, :, :]`` up to ``num_try``
        times (at least once), sleeping ``sleep_time`` seconds between attempts.

        Returns:
            The loaded image.

        Raises:
            Exception: The error of the last attempt, once every attempt failed.
        """
        attempts = max(1, num_try)  # always try at least once
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                # Attempt to slice the data (triggering decompression)
                return data[year][wavelength][id_of_img, :, :]
            except Exception as e:
                last_error = e
                logger.warning(
                    f"Corrupted load (Attempt {attempt}/{attempts}) - "
                    f"channel: {wavelength}, year: {year}, idx: {id_of_img}. Error: {e}"
                )
                # No point in waiting after the final attempt.
                if attempt < attempts:
                    time.sleep(sleep_time)

        logger.error(f"PERMANENT FAILURE: Could not load data after {attempts} attempts.")
        raise last_error

    def _read_header(self, data, year, channel, channel_idx):
        """Return the attrs of ``data[year][channel]`` at ``channel_idx``, or None.

        Header retrieval is best-effort: any failure yields None so that a
        missing or malformed header does not abort the sample.
        """
        try:
            return {
                key: values[channel_idx]
                for key, values in data[year][channel].attrs.items()
            }
        except Exception:
            return None

    def _load_channel_frames(self, data, channels, instrument, idx):
        """Load, mask and normalise ``num_frames`` images for each channel.

        Shared implementation behind ``get_aia_image`` and ``get_hmi_image``.

        Returns:
            tuple: ``(images, headers)``. ``images`` has one entry per channel
            and frame; the frame axis is removed when ``drop_frame_dim`` is
            set. ``headers`` maps channel to a list of per-frame headers and is
            empty unless ``get_header`` is truthy.
        """
        # The alignment rows do not depend on the channel: look them up once.
        rows = [self.aligndata.iloc[idx + frame] for frame in range(self.num_frames)]
        normalise = self._normalization_enabled()

        images = {}
        headers = {}
        for channel in channels:
            frames = []
            images[channel] = frames
            if self.get_header:
                headers[channel] = []

            for row in rows:
                # int() accepts numpy scalars as well as plain Python numbers.
                channel_idx = int(row[f"idx_{channel}"])
                year = str(row.name.year)

                img = self.loading_data_retry(data, year, channel, channel_idx, 10, 0.5)
                if self.mask is not None:
                    img = img * self.mask

                if self.get_header:
                    headers[channel].append(
                        self._read_header(data, year, channel, channel_idx)
                    )

                if normalise:
                    img = self._data_norm(img, instrument, channel)
                frames.append(img)

        stack = np.array(list(images.values()))
        if self.drop_frame_dim:
            stack = stack[:, 0, :, :]
        return stack, headers

    def get_aia_image(self, idx):
        """Get AIA image for a given index.

        Returns:
            tuple: ``(images, headers)`` where ``images`` is a numpy array of
            shape (num_wavelengths, num_frames, height, width), without the
            frame axis when ``drop_frame_dim`` is set. A header that cannot be
            read is reported as None.
        """
        return self._load_channel_frames(self.aia_data, self.wavelengths, "AIA", idx)

    def get_hmi_image(self, idx):
        """Get HMI image for a given index.

        Returns:
            tuple: ``(images, headers)`` where ``images`` is a numpy array of
            shape (num_channels, num_frames, height, width), without the frame
            axis when ``drop_frame_dim`` is set. A header that cannot be read
            is reported as None, matching the AIA behaviour.
        """
        return self._load_channel_frames(self.hmi_data, self.components, "HMI", idx)

    def get_eve(self, idx):
        """Get EVE data for a given index.

        Returns a float32 numpy array of shape (num_ions, num_frames, ...).
        """
        # The EVE row index is the same for every ion: resolve it once per frame.
        eve_indices = [
            int(self.aligndata.iloc[idx + frame]["idx_eve"])
            for frame in range(self.num_frames)
        ]
        normalise = self._normalization_enabled()

        eve_ion_dict = {}
        for ion in self.ions:
            values = []
            for idx_eve in eve_indices:
                value = self.eve_data[ion][idx_eve]
                if normalise:
                    value = self._data_norm(value, "EVE", ion)
                values.append(value)
            eve_ion_dict[ion] = values

        return np.array(list(eve_ion_dict.values()), dtype=np.float32)

    def __str__(self):
        """Return one ``name: value`` line per instance attribute."""
        return "".join(f"{k}: {v}\n" for k, v in self.__dict__.items())
class SDOMLDataModule(pl.LightningDataModule):
    """A PyTorch Lightning DataModule for paired SDO machine learning data.

    Opens the AIA, HMI and EVE Zarr stores, loads the normalisation statistics
    and builds one ``SDOMLDataset`` per split from pre-split alignment CSV
    files.

    Note:
        Input data across the different instruments needs to be temporally
        aligned and paired.

    Args:
        hmi_path (str): Path to the HMI Zarr data, or None to skip HMI.
        aia_path (str): Path to the AIA Zarr data, or None to skip AIA.
        eve_path (str): Path to the EVE Zarr data, or None to skip EVE.
        components (list): HMI components to load; None selects ``ALL_COMPONENTS``.
        wavelengths (list): AIA wavelengths to load; None selects ``ALL_WAVELENGTHS``.
        ions (list): EVE ions to load; None selects ``ALL_IONS``.
        batch_size (int, optional): Number of samples per batch. Defaults to 32.
        num_workers (int, optional): Number of data-loading subprocesses.
            Defaults to half of the available CPU count.
        pin_memory (bool, optional): Forwarded to the DataLoaders. Defaults to False.
        persistent_workers (bool, optional): Forwarded to the DataLoaders; only
            honoured when ``num_workers`` is greater than zero. Defaults to False.
        normalization (optional): Normalisation configuration, read through
            attribute access (``.type``). None or an empty config skips loading
            the statistics. Defaults to None.
        normalization_stat_path (str, optional): Directory searched for the
            normalisation statistics files. Defaults to "".
        train_index (str, optional): Path of the training alignment CSV. Its
            name must contain the time range used to locate the statistics
            files. Defaults to "".
        val_index (str, optional): Path of the validation alignment CSV.
            Defaults to "".
        test_index (str, optional): Path of the test alignment CSV; also used
            for prediction. Defaults to "".
        hmi_mask (str, optional): Path of the mask file. Defaults to
            "hmi_mask_512x512.npy".
        apply_mask (bool, optional): Whether to load the mask and apply it to
            the image data. Defaults to True.
        num_frames (int, optional): Number of consecutive frames per sample.
            Defaults to 1.
        drop_frame_dim (bool, optional): Forwarded to the datasets. Defaults to False.
        precision (str, optional): Precision flag forwarded to the datasets
            (e.g., "32", "16"). Defaults to "32".
    """

    def __init__(
        self,
        hmi_path,
        aia_path,
        eve_path,
        components,
        wavelengths,
        ions,
        batch_size: int = 32,
        num_workers=None,
        pin_memory=False,
        persistent_workers=False,
        normalization=None,
        normalization_stat_path="",
        train_index="",
        val_index="",
        test_index="",
        hmi_mask="hmi_mask_512x512.npy",
        apply_mask=True,
        num_frames=1,
        drop_frame_dim=False,
        precision="32",
    ):
        super().__init__()

        # os.cpu_count() may return None when the count cannot be determined.
        self.num_workers = (
            num_workers if num_workers is not None else (os.cpu_count() or 1) // 2
        )
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.hmi_path = hmi_path
        self.aia_path = aia_path
        self.eve_path = eve_path
        self.batch_size = batch_size
        self.apply_mask = apply_mask
        self.num_frames = num_frames
        self.drop_frame_dim = drop_frame_dim
        self.isAIA = self.aia_path is not None
        self.isHMI = self.hmi_path is not None
        self.isEVE = self.eve_path is not None
        self.precision = precision

        # Alignment CSVs (one per split) holding the per-channel indices and timestamps.
        self.train_index = train_index
        self.val_index = val_index
        self.test_index = test_index

        self.components = components
        self.wavelengths = wavelengths
        self.ions = ions

        # Open each configured store; fall back to the full channel list when
        # none was given for an instrument that is present.
        self.aia_data = self._open_store(self.aia_path)
        if self.isAIA and self.wavelengths is None:
            self.wavelengths = ALL_WAVELENGTHS

        self.hmi_data = self._open_store(self.hmi_path)
        if self.isHMI and self.components is None:
            self.components = ALL_COMPONENTS

        self.eve_data = self._open_store(self.eve_path)
        if self.isEVE and self.ions is None:
            self.ions = ALL_IONS

        # The mask path is kept separately because setup() replaces
        # ``self.hmi_mask`` with the loaded tensor (or None).
        self.hmi_mask_path = hmi_mask
        self.hmi_mask = hmi_mask

        # A fresh dict per instance instead of a shared mutable default.
        self.normalization = {} if normalization is None else normalization
        self.normalization_stat_path = normalization_stat_path
        self.timeinterval = re.compile(
            r"(\d{4}-\d{2}-\d{2}\d{2}:\d{2}:\d{2}-\d{4}-\d{2}-\d{2}\d{2}:\d{2}:\d{2})"
        )
        self.normalization_stat = self._load_norm_stats()

    @staticmethod
    def _open_store(path):
        """Open the Zarr group stored at ``path``; return None when ``path`` is None."""
        if path is None:
            return None
        # NOTE(review): DirectoryStore looks like the zarr v2 API - confirm the pinned zarr version.
        return zarr.group(zarr.DirectoryStore(path))

    def _load_norm_stats(self):
        """Load the per-channel normalisation statistics.

        The time range is extracted from the name of ``train_index`` and used,
        together with the normalisation type, to glob the statistics files in
        ``normalization_stat_path``. The text before the first underscore of
        each file name is used as the channel key.

        Returns:
            dict or None: Mapping ``channel -> statistics``; None when no
            normalisation is configured.

        Raises:
            ValueError: If ``train_index`` contains no time range.
            FileNotFoundError: If no statistics file matches.
        """
        if not self.normalization:
            # Nothing configured (None or empty config): no statistics needed.
            return None

        match = self.timeinterval.search(str(self.train_index))
        if not match:
            raise ValueError("Can't find statistics for normalization")

        time_range = match.group(1)
        base_path = Path(self.normalization_stat_path)

        pattern = f"*{time_range}_norm-{self.normalization.type}*.json"
        # Sorted so the result is deterministic if several files share a channel prefix.
        files = sorted(base_path.glob(pattern))

        if not files:
            raise FileNotFoundError(f"No normalization stats found for {time_range}")

        stats = {}
        for f in files:
            ch = f.stem.split("_")[0]  # safer than splitting full path
            with f.open("r") as fp:
                stats[ch] = yaml.safe_load(fp)

        return stats

    def __str__(self):
        """Return one ``name: value`` line per instance attribute."""
        return "".join(f"{k}: {v}\n" for k, v in self.__dict__.items())

    def _load_mask(self):
        """Return the mask as a tensor, or None when disabled or not found."""
        # Honour a path assigned to ``hmi_mask`` after construction.
        if isinstance(self.hmi_mask, (str, os.PathLike)):
            self.hmi_mask_path = self.hmi_mask

        if not self.apply_mask:
            return None
        if self.hmi_mask_path is not None and os.path.exists(self.hmi_mask_path):
            return torch.Tensor(np.load(self.hmi_mask_path))

        logger.warning(f"HMI mask not found at {self.hmi_mask_path}, applying no mask.")
        return None

    def _build_dataset(self, index_path, mask_np):
        """Build an ``SDOMLDataset`` for the alignment CSV at ``index_path``."""
        return SDOMLDataset(
            self._load_aligndata(index_path),
            self.hmi_data,
            self.aia_data,
            self.eve_data,
            self.components,
            self.wavelengths,
            self.ions,
            normalization=self.normalization,
            normalization_stat=self.normalization_stat,
            mask=mask_np,
            num_frames=self.num_frames,
            drop_frame_dim=self.drop_frame_dim,
            precision=self.precision,
        )

    def setup(self, stage=None):
        """Load the mask and build the train, validation and test datasets.

        The three split datasets are built for every ``stage``; the predict
        dataset (based on ``test_index``) is built only for ``stage == "predict"``.
        The method can be called repeatedly: the mask is always loaded from the
        stored path rather than from the previously loaded tensor.
        """
        self.hmi_mask = self._load_mask()

        # Define mask for dataset (numpy array or None)
        mask_np = self.hmi_mask.numpy() if self.hmi_mask is not None else None
        announce = stage in ("fit", None)

        self.train_ds = self._build_dataset(self.train_index, mask_np)
        if announce:
            logger.info("Train dataloader is ready!")
            logger.info(f"Dataset size: {len(self.train_ds)}")

        self.valid_ds = self._build_dataset(self.val_index, mask_np)
        if announce:
            logger.info("Validation dataloader is ready!")
            logger.info(f"Dataset size: {len(self.valid_ds)}")

        self.test_ds = self._build_dataset(self.test_index, mask_np)
        if announce:
            logger.info("test dataloader is ready!")
            logger.info(f"Dataset size: {len(self.test_ds)}")

        # Prediction reuses the test split.
        if stage == "predict":
            self.predict_ds = self._build_dataset(self.test_index, mask_np)
            logger.info("Predict dataloader is ready!")
            logger.info(f"Dataset size: {len(self.predict_ds)}")

    def _load_aligndata(self, filename):
        """Read an alignment CSV and index it by its parsed ``Time`` column.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
        """
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Aligndata file not found: {filename}")

        df = pd.read_csv(filename)
        df["Time"] = pd.to_datetime(df["Time"])
        return df.set_index("Time")

    def _make_loader(self, dataset, **extra):
        """Create a DataLoader for ``dataset`` with the module-wide settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            # DataLoader rejects persistent workers when there are no workers.
            persistent_workers=bool(self.persistent_workers) and self.num_workers > 0,
            **extra,
        )

    def train_dataloader(self):
        """Return the shuffled training DataLoader (incomplete last batch dropped)."""
        return self._make_loader(self.train_ds, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation DataLoader."""
        return self._make_loader(self.valid_ds)

    def test_dataloader(self):
        """Return the test DataLoader."""
        return self._make_loader(self.test_ds)

    def predict_dataloader(self):
        """Return the DataLoader for the dataset built by ``setup("predict")``."""
        return self._make_loader(self.predict_ds)