"""Source code for sdofmv2.tasks.solar_wind.datamodule."""

import os

import hydra
import numpy as np
import pandas as pd
import torch
import zarr
from loguru import logger

from sdofmv2.core import SDOMLDataModule, SDOMLDataset


def parse_cadence(cadence):
    """Map a cadence string to the list of time-group keys it implies.

    Coarser cadences use a prefix of the calendar fields
    (year -> month -> day -> hour -> minute); the two sub-minute cadences
    use all calendar fields plus one seconds-level key.

    Args:
        cadence: Cadence string, one of "30s", "1s", "1min", "1h", "1D",
            "1MS" or "1YS".

    Returns:
        A new list of group-key names, or an empty list when the cadence is
        not recognised.
    """
    calendar_fields = ["year", "month", "day", "hour", "minute"]
    seconds_field = {"30s": "second_bool", "1s": "second"}
    depth_by_cadence = {"1YS": 1, "1MS": 2, "1D": 3, "1h": 4, "1min": 5}

    if cadence in seconds_field:
        return calendar_fields + [seconds_field[cadence]]
    if cadence in depth_by_cadence:
        return calendar_fields[: depth_by_cadence[cadence]]
    return []


class SWDataset(SDOMLDataset):
    """Solar Wind dataset for SDOML (Solar Dynamics Observatory Machine Learning).

    This class extends the base SDOMLDataset to handle solar wind-specific
    features, including radial and latitudinal/longitudinal parameters. It
    supports temporal filtering by year/month, class-based undersampling for
    training sets, and automated column index mapping for coordinate
    features.

    Args:
        aligndata (pd.DataFrame): DataFrame containing aligned timestamps,
            input data paths, and target labels. Its index must expose
            ``.month`` and ``.year`` (i.e. be datetime-like).
        hmi_data (zarr.Group): SDO/HMI (magnetic field) data source.
        aia_data (zarr.Group): SDO/AIA (extreme ultraviolet) data source.
        eve_data (zarr.Group): SDO/EVE (irradiance) data source.
        components (list[str]): List of HMI magnetic field components to
            include.
        wavelengths (list[int]): List of AIA wavelengths (e.g., 171, 193) to
            include.
        ions (list[str]): List of EVE ions or spectral lines to include.
        freq (str): Data cadence string (e.g., '12min' or '1h').
        months (int | list[int]): Month(s) (1-12) to include in this split.
        years (int | list[int]): Year(s) to include in this split. A single
            integer is accepted and treated as a one-element list.
        mask (optional): A precomputed spatial mask forwarded to the parent
            class. Defaults to None.
        num_frames (int): Number of consecutive temporal frames per sample.
            Defaults to 1.
        drop_frame_dim (bool): If True, removes the temporal dimension (F)
            for single-frame samples. Defaults to False.
        min_date (str, optional): Minimum timestamp boundary for filtering.
        max_date (str, optional): Maximum timestamp boundary for filtering.
        get_header (bool): Whether to retrieve and return FITS headers with
            the data. Defaults to False.
        normalization (dict, optional): Mapping of instrument keys to their
            normalization methods. ``None`` means an empty dict.
        normalization_stat (dict, optional): Precomputed statistics used for
            data scaling. ``None`` means an empty dict.
        label_type (str): The column name in ``aligndata`` used as the
            prediction target.
        radial_parameters (list[str], optional): Column names representing
            radial distance features. For each name ``p`` the column
            ``f"{p}_norm"`` must exist in ``aligndata``. ``None`` is treated
            as an empty list; note that ``__getitem__`` needs at least one
            radial parameter.
        latlon_parameters (list[str], optional): Column names representing
            spatial coordinates (latitude/longitude). ``None`` is treated as
            an empty list.
        sampling_ratio (list[float], optional): Fractions used to undersample
            specific classes; entry ``i`` applies to rows whose label equals
            ``i`` (only applied when ``datasplit`` is "train").
        random_state (int, optional): Seed for reproducible undersampling.
        datasplit (str): The intended use of this instance, e.g. "train",
            "val" or "test". Defaults to "train".

    Attributes:
        radial_parameters (list[str]): Column names of radial features.
        latlon_parameters (list[str]): Column names of lat/lon features.
        aligndata (pd.DataFrame): The filtered and potentially undersampled
            DataFrame containing alignment and label information.
        id_label (int): Column position of the target label in ``aligndata``.
        position_list (list[int]): Column positions of the lat/lon
            parameters.
        r_dist_list (list[int]): Column positions of the ``*_norm`` radial
            distance columns.
    """

    def __init__(
        self,
        aligndata,
        hmi_data,
        aia_data,
        eve_data,
        components,
        wavelengths,
        ions,
        freq,
        months,
        years,
        mask=None,
        num_frames=1,
        drop_frame_dim=False,
        min_date=None,
        max_date=None,
        get_header=False,
        normalization=None,
        normalization_stat=None,
        # set variables for solar wind here
        label_type="",
        radial_parameters=None,
        latlon_parameters=None,
        sampling_ratio=None,
        random_state=None,
        datasplit="train",
    ):
        # Build fresh dicts per instance rather than sharing one mutable
        # default object across every dataset.
        if normalization is None:
            normalization = {}
        if normalization_stat is None:
            normalization_stat = {}

        super().__init__(
            aligndata=aligndata,
            hmi_data=hmi_data,
            aia_data=aia_data,
            eve_data=eve_data,
            components=components,
            wavelengths=wavelengths,
            ions=ions,
            freq=freq,
            months=months,
            mask=mask,
            num_frames=num_frames,
            drop_frame_dim=drop_frame_dim,
            min_date=min_date,
            max_date=max_date,
            get_header=get_header,
            normalization=normalization,
            normalization_stat=normalization_stat,
        )
        self.radial_parameters = (
            [] if radial_parameters is None else radial_parameters
        )
        self.latlon_parameters = (
            [] if latlon_parameters is None else latlon_parameters
        )

        # Split the data by year and month. ``isin`` only accepts list-like
        # input, so a bare integer year/month is wrapped first.
        logger.info(f"{datasplit.upper()} set")
        logger.info(f"Data split, year: {years} & month: {months}")
        month_condition = aligndata.index.month.isin(self._as_list(months))
        year_condition = aligndata.index.year.isin(self._as_list(years))
        # This replaces whatever ``aligndata`` the parent constructor stored.
        self.aligndata = aligndata.loc[month_condition & year_condition, :]

        label_name = label_type
        self.id_label = self.aligndata.columns.get_loc(label_name)

        # Undersample each class, training split only.
        if sampling_ratio is not None and datasplit == "train":
            sampled = []
            for class_id, class_ratio in enumerate(sampling_ratio):
                logger.info(
                    f"{class_ratio * 100:.0f}% of class: {class_id} instances are sampled!"
                )
                class_rows = self.aligndata.loc[
                    self.aligndata[label_name] == class_id
                ]
                sampled.append(
                    class_rows.sample(
                        frac=class_ratio, replace=False, random_state=random_state
                    )
                )
            # NOTE: rows end up grouped by class, not in chronological order.
            self.aligndata = pd.concat(sampled, axis=0, ignore_index=False)

        # Resolve feature column names to integer positions once, so that
        # ``__getitem__`` can use fast positional ``iloc`` access.
        cols = self.aligndata.columns.to_list()
        self.position_list = [
            cols.index(f"{para}") for para in self.latlon_parameters
        ]
        self.r_dist_list = [
            cols.index(f"{para}_norm") for para in self.radial_parameters
        ]

        logger.info(f"Position list: {self.latlon_parameters}: {self.position_list}")
        logger.info(f"Radial distance: {self.radial_parameters}: {self.r_dist_list}")
        logger.info(f"Label: {self.aligndata[label_name].value_counts()}")

    @staticmethod
    def _as_list(value):
        """Wrap a scalar in a list; convert any other iterable to a list."""
        return [value] if np.isscalar(value) else list(value)

    def __len__(self):
        """Return the number of rows left after filtering/undersampling."""
        return self.aligndata.shape[0]

    def __getitem__(self, idx):
        """Return one sample.

        Args:
            idx (int): Positional row index into ``self.aligndata``.

        Returns:
            tuple: ``(image_stack, timestamps, header_stack, position,
            r_distance, label)`` when ``self.get_header`` is true, otherwise
            ``(image_stack, timestamps, position, r_distance, label)``.
            ``timestamps`` is the integer ``.value`` of the row's index
            entry, ``position`` holds the lat/lon columns converted with
            ``np.radians``, ``r_distance`` is the first ``*_norm`` radial
            column as float32, and ``label`` is the target cast to int64.
        """
        # Cast the stored label to an integer class id.
        label = self.aligndata.iloc[idx, self.id_label].astype("int64")
        position = np.radians(self.aligndata.iloc[idx, self.position_list].values)
        r_distance = self.aligndata.iloc[idx, self.r_dist_list].to_numpy(
            dtype=np.float32
        )
        timestamps = self.aligndata.index[idx].value

        # Retrieve the input (image, or image + header) from the parent class.
        if self.get_header:
            image_stack, header_stack, _ = super().__getitem__(idx=idx)
            return image_stack, timestamps, header_stack, position, r_distance[0], label

        image_stack, timestamps_parent = super().__getitem__(idx=idx)
        if timestamps_parent != timestamps:
            logger.warning(
                f"Parent: {pd.to_datetime(timestamps_parent)} & "
                f"child: {pd.to_datetime(timestamps)} different!"
            )
        return image_stack, timestamps, position, r_distance[0], label
class SWDataModule(SDOMLDataModule):
    """PyTorch Lightning DataModule for Solar Wind (SW) prediction.

    This module handles the end-to-end data pipeline for SDOML datasets,
    including loading alignment indices from Zarr (or creating them when the
    file does not exist), filtering by temporal boundaries (years/months),
    applying a longitude cutoff, and z-scoring the radial distance
    parameters.

    Args:
        hmi_path (str): Path to HMI Zarr data.
        aia_path (str): Path to AIA Zarr data.
        eve_path (str): Path to EVE Zarr data.
        components (list[str]): HMI magnetic components to use.
        wavelengths (list[int]): AIA wavelengths to use.
        ions (list[str]): EVE spectral lines to use.
        frequency (str): The sampling frequency of the instruments.
        batch_size: Number of samples per batch. Defaults to 32.
        num_workers: Number of subprocesses for data loading.
        val_months: Months assigned to validation. Defaults to [10, 1].
        test_months: Months assigned to testing. Defaults to [11, 12].
        holdout_months: Months to be completely excluded from all splits.
            Defaults to an empty list.
        radial_norm: Normalization switch. NOTE: the check is
            ``radial_norm is not None``, so the boolean default ``False``
            still normalizes; ``SWDataset`` relies on the resulting
            ``*_norm`` columns.
        cache_dir: Directory for storing temporary or cached data.
        apply_mask: If True, applies the limb mask to spatial data.
        num_frames: Number of temporal frames per sample.
        drop_frame_dim: Whether to squeeze the temporal dimension if it is 1.
        min_date: Global start date boundary.
        max_date: Global end date boundary.
        precision: Numerical precision for tensors ("16", "32", or "64").
        normalization: Dictionary of normalization settings.
        cfg: Configuration object for data cutoffs and model settings.
        train_months: Months assigned to training. Defaults to [10].
        train_years: Year(s) assigned to training.
        val_years: Year(s) assigned to validation.
        test_years: Year(s) assigned to testing.
        alignment_indices_path (str): Path to the Zarr file containing
            alignment indices and metadata.
        radial_parameters (list[str]): Column names for radial features.
            ``None`` is treated as an empty list.
        latlon_parameters (list[str]): Column names for coordinate features.
            ``None`` is treated as an empty list.
        cadence: Temporal resolution string (e.g., "1min").
        label_type: The target column name for the model to predict.
        sampling_ratio: Fraction of instances to sample per class.
        random_state: Seed for reproducible sampling.

    Attributes:
        aligndata (pd.DataFrame): The central alignment table indexed by
            ``time_sdo_loc_est``, containing data indices and target labels.
        radial_mean (float): Mean of the LAST normalized radial column (kept
            for backward compatibility; see ``radial_stats``).
        radial_std (float): Standard deviation of the LAST normalized radial
            column.
        radial_stats (dict[str, tuple[float, float]]): ``(mean, std)`` used
            to normalize each radial column.
        train_years (int | list[int]): Years allocated for the training set.
        val_years (int | list[int]): Years allocated for the validation set.
        test_years (int | list[int]): Years allocated for the test set.
        cfg (DictConfig | Any): Configuration object containing
            hyperparameters and data cutoffs (e.g.,
            ``cfg.data.in_situ.lon_cutoff``).
    """

    def __init__(
        self,
        hmi_path,
        aia_path,
        eve_path,
        components,
        wavelengths,
        ions,
        frequency,
        batch_size: int = 32,
        num_workers=None,
        val_months=None,
        test_months=None,
        holdout_months=None,
        radial_norm=False,
        cache_dir="",
        apply_mask=True,
        num_frames=1,
        drop_frame_dim=False,
        min_date=None,
        max_date=None,
        precision="32",
        normalization=None,
        # set variables for solar wind here
        cfg=None,
        train_months=None,
        train_years=2022,
        val_years=2023,
        test_years=2018,
        alignment_indices_path=None,
        radial_parameters=None,
        latlon_parameters=None,
        cadence="1min",
        label_type="",
        sampling_ratio=None,
        random_state=None,
    ):
        # Resolve the list defaults here so no mutable default object is
        # shared between instances; the effective defaults are unchanged.
        val_months = [10, 1] if val_months is None else val_months
        test_months = [11, 12] if test_months is None else test_months
        holdout_months = [] if holdout_months is None else holdout_months
        train_months = [10] if train_months is None else train_months

        super().__init__(
            hmi_path=hmi_path,
            aia_path=aia_path,
            eve_path=eve_path,
            components=components,
            wavelengths=wavelengths,
            ions=ions,
            frequency=frequency,
            batch_size=batch_size,
            num_workers=num_workers,
            val_months=val_months,
            test_months=test_months,
            holdout_months=holdout_months,
            normalization=normalization,
            cache_dir=cache_dir,
            apply_mask=apply_mask,
            num_frames=num_frames,
            drop_frame_dim=drop_frame_dim,
            min_date=min_date,
            max_date=max_date,
            precision=precision,
        )

        self.cfg = cfg
        self.alignment_indices_path = alignment_indices_path
        self.cadence = cadence
        self.train_months = train_months
        self.label_type = label_type
        self.sampling_ratio = sampling_ratio
        self.random_state = random_state
        self.train_years = train_years
        self.val_years = val_years
        self.test_years = test_years
        self.precision = precision
        self.radial_parameters = (
            [] if radial_parameters is None else radial_parameters
        )
        self.latlon_parameters = (
            [] if latlon_parameters is None else latlon_parameters
        )
        self.radial_norm = radial_norm

        # Load the alignment table from the zarr file, or build it.
        if os.path.exists(self.alignment_indices_path):
            logger.info(f"Alignment file is found: {self.alignment_indices_path}")
            root = zarr.open(self.alignment_indices_path, mode="r")
            columns = root.attrs["columns"]
            self.aligndata = pd.DataFrame(root[:, :], columns=columns)
        else:
            self.create_alignment_data()

        # The timestamp column is stored as seconds; convert it back to
        # datetimes, round to whole seconds and use it as the index.
        self.aligndata["time_sdo_loc_est"] = pd.to_datetime(
            self.aligndata["time_sdo_loc_est"], unit="s"
        ).dt.round(freq="s")
        self.aligndata = self.aligndata.set_index("time_sdo_loc_est")

        # TODO: add switch + angle cutoff to cfg file
        # Keep only rows whose longitude is within the configured cutoff:
        # magnetic footpoints on the visible solar disk when
        # "lon_footpoint" is requested, otherwise the PSP position when
        # "sc_pos_SH_lon" is. The first matching column wins.
        for lon_col in ("lon_footpoint", "sc_pos_SH_lon"):
            if lon_col in self.latlon_parameters:
                within_cutoff = (
                    self.aligndata[lon_col].abs()
                    < self.cfg.data.in_situ.lon_cutoff
                )
                self.aligndata = self.aligndata.loc[within_cutoff, :]
                break

        self.aligndata = self.aligndata.sort_index()

        # Z-score the radial columns into "<col>_norm".
        # NOTE: ``radial_norm`` defaults to False and ``False is not None``
        # is true, so this runs unless None is passed explicitly. That is
        # kept on purpose: SWDataset looks up the "<col>_norm" columns.
        self.radial_stats = {}
        if radial_norm is not None:
            for col in self.radial_parameters:
                logger.info(f"Normalizing column: {col}")
                self.radial_mean = self.aligndata[col].mean()
                self.radial_std = self.aligndata[col].std()
                self.radial_stats[col] = (self.radial_mean, self.radial_std)
                self.aligndata.loc[:, f"{col}_norm"] = (
                    self.aligndata[col] - self.radial_mean
                ) / self.radial_std

    def create_alignment_data(self):
        """Build the PSP/SDOML alignment table, save it to zarr and store it.

        Reads the interpolated PSP table from the path configured under
        ``cfg.data.in_situ``, shifts each PSP timestamp back by the
        configured propagation time, matches it to the nearest row of the
        existing ``self.aligndata`` (within ``match_tolerance`` minutes),
        writes the numeric columns to ``self.alignment_indices_path`` and
        replaces ``self.aligndata`` with the merged table.
        """
        logger.info("Creating alignment dataset")

        # loading source files
        path = os.path.join(
            self.cfg.data.in_situ.base_data_directory,
            self.cfg.data.in_situ.psp_interpolated_path,
        )
        root = zarr.open(path, mode="r")

        # preprocessing psp data files
        columns = root.attrs["columns"]
        df_psp = pd.DataFrame(root[:, :], columns=columns)
        df_psp["time"] = pd.to_datetime(root.attrs["time"])
        # drop rows with missing values in any requested feature column
        df_psp.dropna(
            subset=list(self.radial_parameters) + list(self.latlon_parameters),
            inplace=True,
        )

        # convert the propagation time (seconds) to timedeltas
        propagation_col = self.cfg.experiment.propagation_type
        df_psp[propagation_col] = pd.to_timedelta(df_psp[propagation_col], unit="s")

        # this timestamp ("time_sdo_loc_est") is used for matching sdoml and
        # psp data
        df_psp["time_sdo_loc_est"] = df_psp["time"] - df_psp[propagation_col]

        # preprocessing sdoml data
        # ``rename`` needs ``columns=``; a bare mapper targets row labels and
        # would leave the "index" column produced by ``reset_index`` as is.
        sdoml = self.aligndata.reset_index().rename(columns={"index": "Time"})
        # sdoml data should start from when the psp data start (with a
        # buffer of 4 days)
        earliest = df_psp["time_sdo_loc_est"].min() - pd.Timedelta(days=4)
        sdoml = sdoml.loc[sdoml["Time"] >= earliest, :]

        # merge_asof requires both frames to be sorted on their keys
        df_psp = df_psp.sort_values(by="time_sdo_loc_est")
        sdoml = sdoml.sort_values(by="Time")

        # find the nearest timestamps (from right dataframe) based on left key
        df_merge = pd.merge_asof(
            df_psp,
            sdoml,
            left_on="time_sdo_loc_est",
            right_on="Time",
            direction="nearest",
            allow_exact_matches=True,
            tolerance=pd.Timedelta(
                minutes=int(self.cfg.data.in_situ.match_tolerance)
            ),
            suffixes=("", "_sdoml"),
        )

        # sort the merged dataframe; rows with vp_fit_RTN_0_mean < 100 are
        # treated as outliers and removed
        df_merge = df_merge.sort_values(by="time_sdo_loc_est")
        df_merge = df_merge.loc[df_merge["vp_fit_RTN_0_mean"] >= 100, :]

        # drop rows for which no SDOML timestamp was found within tolerance
        df_merge = df_merge.dropna(subset=["Time"])
        df_merge[propagation_col] = df_merge[propagation_col].dt.total_seconds()

        # Save the data to zarr format; only numerical columns can be saved,
        # so the remaining columns are converted via int64 / 1e9.
        obj_cols = df_merge.select_dtypes(exclude="number").columns.to_list()
        logger.info(f"Object columns: {obj_cols} is converted to int")
        for col in obj_cols:
            df_merge[col] = df_merge[col].values.astype(np.int64) / 10**9

        num_cols = df_merge.select_dtypes(include="number").columns.to_list()
        z1 = zarr.open(
            self.alignment_indices_path,
            mode="w",
            shape=(len(df_merge), len(num_cols)),
            chunks=(20_000, len(num_cols)),
            dtype="f8",
        )
        z1[:, :] = df_merge[num_cols].to_numpy().astype(float)
        z1.attrs["columns"] = num_cols
        logger.info(f"Alignment data is saved: {self.alignment_indices_path}")
        self.aligndata = df_merge[num_cols]

    def _build_dataset(
        self, months, years, datasplit, sampling_ratio=None, random_state=None
    ):
        """Create the SWDataset for one split from the shared settings."""
        return SWDataset(
            self.aligndata,
            self.hmi_data,
            self.aia_data,
            self.eve_data,
            self.components,
            self.wavelengths,
            self.ions,
            self.cadence,
            months,
            # The year defaults are bare ints; the dataset filters with
            # ``isin``, which needs list-like input.
            years=[years] if np.isscalar(years) else years,
            mask=self.hmi_mask.numpy(),
            num_frames=self.num_frames,
            drop_frame_dim=self.drop_frame_dim,
            min_date=self.min_date,
            max_date=self.max_date,
            normalization=self.normalization,
            normalization_stat=self.normalization_stat,
            # set variables for solar wind here
            radial_parameters=self.radial_parameters,
            latlon_parameters=self.latlon_parameters,
            label_type=self.label_type,
            sampling_ratio=sampling_ratio,
            random_state=random_state,
            datasplit=datasplit,
        )

    def setup(self, stage=None):
        """Instantiate the train, validation, test and predict datasets.

        Only the training split is undersampled. The predict split reuses
        the test months and years.
        """
        self.train_ds = self._build_dataset(
            self.train_months,
            self.train_years,
            "train",
            sampling_ratio=self.sampling_ratio,
            random_state=self.random_state,
        )
        self.valid_ds = self._build_dataset(self.val_months, self.val_years, "val")
        self.test_ds = self._build_dataset(self.test_months, self.test_years, "test")
        self.predict_ds = self._build_dataset(
            self.test_months, self.test_years, "predict"
        )

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete batches are dropped."""
        return torch.utils.data.DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Return the validation loader."""
        return torch.utils.data.DataLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,  # shuffle true for visualization
        )

    def test_dataloader(self):
        """Return the test loader (no shuffling)."""
        return torch.utils.data.DataLoader(
            self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def predict_dataloader(self):
        """Return the prediction loader (no shuffling)."""
        return torch.utils.data.DataLoader(
            self.predict_ds, batch_size=self.batch_size, num_workers=self.num_workers
        )
@hydra.main(
    version_base=None, config_path="../configs", config_name="finetune_solarwind_config"
)
def main(cfg):
    """Initializes the solar wind data module and validates dataset alignment.

    This function sets up the SWDataModule using parameters from the Hydra
    configuration. It prints the training dataset length, checks that the
    frame offsets for the first sample stay inside the alignment table, and
    retrieves one sample to ensure the data pipeline works.

    Args:
        cfg (DictConfig): Hydra configuration object containing data paths,
            split definitions, and experiment parameters.

    Returns:
        None
    """
    sdoml_cfg = cfg.data.sdoml
    in_situ_cfg = cfg.data.in_situ

    def _sdoml_path(sub_directory):
        """Join the SDOML base directory with a sub-directory, if one is set."""
        if not sub_directory:
            return None
        return os.path.join(sdoml_cfg.base_directory, sub_directory)

    datamodule = SWDataModule(
        hmi_path=_sdoml_path(sdoml_cfg.sub_directory.hmi),
        aia_path=_sdoml_path(sdoml_cfg.sub_directory.aia),
        normalization=cfg.data.normalization,
        eve_path=_sdoml_path(sdoml_cfg.sub_directory.eve),
        components=sdoml_cfg.components,
        wavelengths=sdoml_cfg.wavelengths,
        ions=sdoml_cfg.ions,
        frequency=sdoml_cfg.frequency,
        batch_size=cfg.model.opt.batch_size,
        num_workers=cfg.data.num_workers,
        val_months=cfg.data.month_splits.val,
        train_months=cfg.data.month_splits.train,
        test_months=cfg.data.month_splits.test,
        train_years=cfg.data.year_splits.train,
        val_years=cfg.data.year_splits.val,
        test_years=cfg.data.year_splits.test,
        holdout_months=cfg.data.month_splits.holdout,
        cache_dir=os.path.join(
            sdoml_cfg.save_directory, sdoml_cfg.sub_directory.cache
        ),
        min_date=cfg.data.min_date,
        max_date=cfg.data.max_date,
        num_frames=cfg.data.num_frames,
        drop_frame_dim=cfg.data.drop_frame_dim,
        alignment_indices_path=in_situ_cfg.base_data_directory
        + in_situ_cfg.alignment_indices_path,
        # SWDataModule has no ``parameters`` argument (passing it raises
        # TypeError); it takes the radial and lat/lon column lists instead.
        # NOTE(review): these config key names are assumed - confirm them
        # against finetune_solarwind_config.
        radial_parameters=in_situ_cfg.get("radial_parameters"),
        latlon_parameters=in_situ_cfg.get("latlon_parameters"),
        cadence="1min",
        label_type=cfg.experiment.label_type,
        sampling_ratio=None,
        cfg=cfg,
    )
    datamodule.setup()

    # Check dataset and data alignment
    ds = datamodule.train_ds
    print(f"Dataset __len__: {len(ds)}")
    print(f"Aligndata rows: {len(ds.aligndata)}")
    frame_range = getattr(ds, "frame_range", [0])
    print(f"Frame range: {frame_range}")

    # Check what index 0 + frame would be
    for frame in frame_range:
        target_idx = 0 + frame
        print(f"Trying to access index: {target_idx}")
        if target_idx >= len(ds.aligndata):
            print(f"ERROR: Index {target_idx} >= DataFrame size {len(ds.aligndata)}")

    # Without headers, SWDataset.__getitem__ returns
    # (image_stack, timestamps, position, r_distance, label).
    image, timestamps, position, r_distance, label = datamodule.train_ds[0]
    print(f"Sample 0 -> timestamp: {timestamps}, label: {label}")


if __name__ == "__main__":
    main()