Source code for sdofmv2.core.mae_module

import os
import h5py
import torch
import numpy as np
import pandas as pd
import lightning.pytorch as pl
import torch.nn.functional as F
from skimage.measure import block_reduce

from . import reconstruction as bench_recon
from .mae3d import MaskedAutoencoderViT3D
from .basemodule import BaseModule
from ..utils import unpatchify, patchify
from sdofmv2.utils.constants import ALL_WAVELENGTHS


[docs] class MAE(BaseModule): """Masked Autoencoder (MAE) for 3D/Spatiotemporal data reconstruction. This module implements a Vision Transformer-based autoencoder that learns representations by reconstructing masked patches of volumetric data. It supports custom ROI masking (limb masking) and automated metric tracking across training, validation, and testing phases. Args: img_size: Side length of the input image (assumed square). chan_types: List of channel names/wavelengths for logging. patch_size: Spatial size of the 2D patches. num_frames: Total number of frames (temporal depth) in the input sequence. tubelet_size: Temporal size of the 3D tubelets. in_chans: Number of input data channels. embed_dim: Embedding dimension for the encoder. depth: Number of transformer layers in the encoder. num_heads: Number of attention heads in the encoder. decoder_embed_dim: Embedding dimension for the decoder. decoder_depth: Number of transformer layers in the decoder. decoder_num_heads: Number of attention heads in the decoder. mlp_ratio: Expansion ratio for the MLP hidden dimension. norm_layer: Type of normalization layer to use (e.g., "LayerNorm"). masking_ratio: Fraction of patches to mask (0.0 to 1.0). limb_mask: An optional binary ROI mask. loss_dict: Configuration for reconstruction losses. optimizer_dict: Configuration for the optimizer. scheduler_dict: Configuration for the learning rate scheduler. *args: Variable length argument list passed to BaseModule. **kwargs: Arbitrary keyword arguments passed to BaseModule. Attributes: img_size (int): Spatial resolution of the input images (Height and Width). patch_size (int): The side length of the square patches extracted from each frame. tubelet_size (int): The temporal depth of each 3D patch (number of frames). masking_ratio (float): The fraction of patches to be masked out during the forward pass (typically 0.75). chan_types (list[str]): A list of identifiers for each input channel (e.g., specific wavelengths), used for per-channel metric logging. limb_mask (Optional[torch.Tensor]): A binary spatial mask of shape (H, W) used to restrict the model's focus to specific ROIs. loss_dict (dict): Configuration parameters and weights for the reconstruction loss functions. validation_metrics (list[dict]): A transient buffer that accumulates metric dictionaries from each `validation_step` to be processed at the epoch end. test_results (list[dict]): A transient buffer that accumulates metric dictionaries from each `test_step`. autoencoder (MaskedAutoencoderViT3D): The core transformer architecture consisting of the encoder and decoder blocks. """ def __init__( self, # MAE specific img_size=224, chan_types=ALL_WAVELENGTHS, patch_size=16, num_frames=3, tubelet_size=1, in_chans=3, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4.0, norm_layer="LayerNorm", masking_ratio=0.75, limb_mask=None, loss_dict={}, optimizer_dict={}, scheduler_dict={}, # pass to BaseModule *args, **kwargs, ): super().__init__( optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict, *args, **kwargs, ) self.save_hyperparameters() self.img_size = img_size self.patch_size = patch_size self.tubelet_size = tubelet_size self.validation_metrics = [] self.masking_ratio = masking_ratio self.chan_types = chan_types self.limb_mask = limb_mask self.loss_dict = loss_dict self.test_results = [] # block reduce limb_mask limb_mask_ids = None if limb_mask is not None: new_matrix = block_reduce( limb_mask.numpy(), block_size=(self.patch_size, self.patch_size), func=np.max, ) limb_mask_ids = torch.tensor( np.argwhere(new_matrix.reshape((img_size // self.patch_size) ** 2) == 0).reshape(-1) ) self.autoencoder = MaskedAutoencoderViT3D( img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim, depth, num_heads, decoder_embed_dim, decoder_depth, decoder_num_heads, mlp_ratio, norm_layer, limb_mask, limb_mask_ids, loss_dict, )
[docs] def training_step(self, batch, batch_idx): """Perform a single training step. Args: batch: A tuple containing (images, timestamps). batch_idx: The index of the current batch. Returns: torch.Tensor: The training loss value. """ x, timestamps = batch[:2] loss, x_hat, mask = self.autoencoder(x, mask_ratio=self.masking_ratio) self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) return loss
[docs] def validation_step(self, batch, batch_idx): """Perform a single validation step. Args: batch: A tuple containing (images, timestamps). batch_idx: The index of the current batch. """ x, timestamps = batch[:2] loss, x_hat, mask = self.autoencoder(x, mask_ratio=self.masking_ratio) x_hat_reconstructed = unpatchify(x_hat, self.img_size, self.patch_size, self.tubelet_size) # Transfer both tensors to CPU once, outside all loops x_np = x.detach().cpu().numpy() # [B, C, T, H, W] x_hat_np = x_hat_reconstructed.detach().cpu().numpy() batch_size = mask.shape[0] num_frames = x.shape[2] if self.limb_mask is not None: # Build full-resolution mask once for the entire batch: [B, 512, 512] grid_size = self.img_size // self.patch_size mask_full = ( mask.reshape(batch_size, grid_size, grid_size) .detach() .cpu() .numpy() .repeat(self.patch_size, axis=1) .repeat(self.patch_size, axis=2) .astype(bool) ) step_metrics = [ bench_recon.get_metrics_for_masked_patches( x_np[i, :, frame, mask_full[i]], x_hat_np[i, :, frame, mask_full[i]], self.chan_types, ) for i in range(batch_size) for frame in range(num_frames) ] x_patchified = patchify(x, self.patch_size, self.tubelet_size) active_mask = mask == 1 masked_mse = F.mse_loss(x_patchified[active_mask], x_hat[active_mask]) else: step_metrics = [ bench_recon.get_metrics( x_np[i, :, frame, :, :], x_hat_np[i, :, frame, :, :], self.chan_types, ) for i in range(batch_size) for frame in range(num_frames) ] self.validation_metrics.extend(step_metrics) self.log("val_loss", loss) if self.limb_mask is not None: self.log("val_MSEloss_in_masked_patches", masked_mse)
[docs] def forward(self, x, mask_ratio=None): """Perform a forward pass through the MAE. Args: x (torch.Tensor): Input images of shape (B, C, H, W). mask_ratio (float, optional): Fraction of patches to mask. If None, uses the default masking_ratio. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - x_hat: Reconstructed images. - mask: The applied mask tensor. """ if mask_ratio is None: mask_ratio = self.masking_ratio loss, x_hat, mask = self.autoencoder(x, mask_ratio=mask_ratio) x_hat = unpatchify(x_hat, self.img_size, self.patch_size, self.tubelet_size) return x_hat, mask
[docs] def forward_encoder(self, x, mask_ratio): """Perform a forward pass through the encoder only. Args: x (torch.Tensor): Input images. mask_ratio (float): Fraction of patches to mask. Returns: torch.Tensor: Encoded features from the encoder. """ return self.autoencoder.forward_encoder(x, mask_ratio=mask_ratio)
[docs] def on_validation_epoch_end(self): """Called at the end of the validation epoch. Aggregates validation metrics, logs them to the logger (WandB or default), and clears the metrics buffer. """ merged_metrics = bench_recon.merge_metrics(self.validation_metrics) batch_metrics = bench_recon.mean_metrics(merged_metrics) if isinstance(self.logger, pl.loggers.wandb.WandbLogger): # this only occurs on rank zero only df = pd.DataFrame(batch_metrics) df["mean"] = df.mean(numeric_only=True, axis=1) df["metric"] = df.index cols = df.columns.tolist() self.logger.log_table( key="val_reconstruction", dataframe=df[cols[-1:] + cols[:-1]], step=self.validation_step, ) for k, v in batch_metrics.items(): for i, j in v.items(): self.log(f"val_{k}_{i}", j) else: for k in batch_metrics.keys(): batch_metrics[k]["channel"] = k for k, v in batch_metrics.items(): self.log_dict(v, sync_dist=True) # reset self.validation_metrics.clear()
[docs] def test_step(self, batch, batch_idx): """Perform a single test step. Args: batch: A tuple containing (images, timestamps). batch_idx: The index of the current batch. """ x, timestamps = batch[:2] loss, x_hat, mask = self.autoencoder(x, mask_ratio=self.masking_ratio) x_hat_reconstructed = unpatchify(x_hat, self.img_size, self.patch_size, self.tubelet_size) # Transfer both tensors to CPU once, outside all loops x_np = x.detach().cpu().numpy() # [B, C, T, H, W] x_hat_np = x_hat_reconstructed.detach().cpu().numpy() batch_size = mask.shape[0] num_frames = x.shape[2] if self.limb_mask is not None: # Build full-resolution mask once for the entire batch: [B, 512, 512] grid_size = self.img_size // self.patch_size mask_full = ( mask.reshape(batch_size, grid_size, grid_size) .detach() .cpu() .numpy() .repeat(self.patch_size, axis=1) .repeat(self.patch_size, axis=2) .astype(bool) ) step_metrics = [ bench_recon.get_metrics_for_masked_patches( x_np[i, :, frame, mask_full[i]], x_hat_np[i, :, frame, mask_full[i]], self.chan_types, ) for i in range(batch_size) for frame in range(num_frames) ] x_patchified = patchify(x, self.patch_size, self.tubelet_size) active_mask = mask == 1 masked_mse = F.mse_loss(x_patchified[active_mask], x_hat[active_mask]) else: step_metrics = [ bench_recon.get_metrics( x_np[i, :, frame, :, :], x_hat_np[i, :, frame, :, :], self.chan_types, ) for i in range(batch_size) for frame in range(num_frames) ] self.test_results.extend(step_metrics) self.log("test_loss", loss) if self.limb_mask is not None: self.log("test_MSEloss_in_masked_patches", masked_mse)
[docs] def on_test_epoch_end(self): """Called at the end of the test epoch. Aggregates test metrics, saves them to a CSV file, logs to the logger, and clears the results buffer. """ if not self.test_results: return merged_metrics = bench_recon.merge_metrics(self.test_results) batch_metrics = bench_recon.mean_metrics(merged_metrics) df = pd.DataFrame(batch_metrics) df["mean"] = df.mean(numeric_only=True, axis=1) df["metric"] = df.index cols = df.columns.tolist() final_df = df[[cols[-1]] + cols[:-1]] output_path = os.path.join(self.trainer.default_root_dir, "test_metrics_summary.csv") final_df.to_csv(output_path, index=False) print(f"\n[INFO] Test results saved to: {output_path}") if isinstance(self.logger, pl.loggers.wandb.WandbLogger): self.logger.log_table(key="test_reconstruction_summary", dataframe=final_df) for chan, metrics in batch_metrics.items(): for m_name, val in metrics.items(): self.log(f"test_{chan}_{m_name}", val) self.test_results.clear()