Downstream App: Missing Data
Data Module
- class sdofmv2.tasks.missing_data.missing_data_module.MissingDataModel(*args: Any, **kwargs: Any)
Bases: BaseModule
A model for reconstructing missing data channels using a backbone autoencoder.
This class wraps a backbone autoencoder for missing-data reconstruction. It implements a random channel drop mechanism: one channel is zeroed out, and the model is trained to reconstruct that channel, with an MSE loss computed on the dropped channel only.
- Parameters:
optimizer_dict (dict, optional) – Configuration for the optimizer. Defaults to None.
scheduler_dict (dict, optional) – Configuration for the learning rate scheduler. Defaults to None.
backbone (object, optional) – The backbone autoencoder model. Defaults to None.
freeze_encoder (bool) – Whether to freeze the encoder blocks of the backbone. Defaults to True.
*args – Variable length argument list passed to BaseModule.
**kwargs – Arbitrary keyword arguments passed to BaseModule.
- backbone
The underlying autoencoder model.
- Type:
object
- masking_ratio
The masking ratio used by the backbone.
- Type:
float
- validation_metrics
Storage for metrics computed during validation.
- Type:
list
Methods
forward(imgs[, mask_ratio]) – Performs a standard forward pass through the backbone autoencoder.
forward_random_channel_drop(imgs[, mask_ratio]) – Corrupts a random channel and performs a forward pass.
training_step(batch, batch_idx) – Executes a single training step with random channel corruption.
validation_step(batch, batch_idx) – Executes a single validation step and logs metrics.
- forward(imgs, mask_ratio=0.5)
Performs a standard forward pass through the backbone autoencoder.
- Parameters:
imgs (torch.Tensor) – Input images of shape (B, C, T, H, W).
mask_ratio (float) – Ratio of patches to mask. Defaults to 0.5.
- Returns:
- A tuple containing:
loss (torch.Tensor): The reconstruction loss.
x_hat (torch.Tensor): The unpatchified reconstructed images.
mask (torch.Tensor): The mask applied during the forward pass.
- Return type:
tuple
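Both forward methods take a mask_ratio, the fraction of patches the backbone hides. If the backbone follows the masked-autoencoder (MAE) convention, which this page does not state, a per-sample mask can be drawn by shuffling patch indices. The NumPy sketch below is illustrative only; random_masking and its return values are hypothetical, not sdofmv2 API.

```python
import numpy as np


def random_masking(num_patches: int, mask_ratio: float, batch_size: int, seed: int = 0):
    """Hypothetical MAE-style per-sample random masking (assumption, not sdofmv2 code).

    Returns:
        ids_keep:    (B, n_keep) indices of the patches that stay visible.
        mask:        (B, num_patches), 1 = masked, 0 = visible, in original patch order.
        ids_restore: (B, num_patches) indices that undo the random shuffle.
    """
    rng = np.random.default_rng(seed)
    n_keep = int(num_patches * (1 - mask_ratio))
    noise = rng.random((batch_size, num_patches))
    ids_shuffle = np.argsort(noise, axis=1)        # random permutation per sample
    ids_restore = np.argsort(ids_shuffle, axis=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :n_keep]             # first n_keep shuffled patches stay visible
    # Build the mask in shuffled order, then map it back to original patch order.
    mask = np.ones((batch_size, num_patches))
    mask[:, :n_keep] = 0
    mask = np.take_along_axis(mask, ids_restore, axis=1)
    return ids_keep, mask, ids_restore
```

With 196 patches (a 14 × 14 grid) and mask_ratio=0.5, each sample keeps 98 visible patches. Under the MAE assumption, ids_restore plays the same role as the argument of that name in the neck's forward method.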
- forward_random_channel_drop(imgs, mask_ratio=0.75)
Zeroes out one randomly chosen channel and performs a forward pass on the corrupted input.
- Parameters:
imgs (torch.Tensor) – Input images of shape (B, C, T, H, W).
mask_ratio (float) – Ratio of patches to mask. Defaults to 0.75.
- Returns:
- A tuple containing:
loss (torch.Tensor): The reconstruction loss.
x_hat (torch.Tensor): The reconstructed images.
mask (torch.Tensor): The mask applied during the forward pass.
target_idx (int): The index of the channel that was zeroed out.
- Return type:
tuple
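The corruption step can be pictured as follows. Because target_idx is returned as a single int, this sketch drops the same channel for every sample in the batch. random_channel_drop is a hypothetical NumPy helper, not the module's implementation.

```python
import numpy as np


def random_channel_drop(imgs: np.ndarray, rng: np.random.Generator):
    """Zero out one randomly chosen channel of a (B, C, T, H, W) batch.

    Hypothetical helper mirroring the documented behaviour: returns a
    corrupted copy and the index of the dropped channel (target_idx).
    """
    num_channels = imgs.shape[1]
    target_idx = int(rng.integers(num_channels))  # one channel for the whole batch
    corrupted = imgs.copy()                        # leave the original batch untouched
    corrupted[:, target_idx] = 0.0                 # zero that channel at every T, H, W
    return corrupted, target_idx
```

Returning a copy keeps the clean images available as the reconstruction target.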
- training_step(batch, batch_idx)
Executes a single training step with random channel corruption.
- Parameters:
batch (tuple) – A tuple containing (images, timestamps).
batch_idx (int) – The index of the current batch.
- Returns:
The MSE loss calculated on the dropped channel.
- Return type:
torch.Tensor
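training_step returns the MSE on the dropped channel only. Assuming x_hat has the same (B, C, T, H, W) layout as the input images, the loss can be sketched as below; dropped_channel_mse is a hypothetical name and NumPy stands in for torch.

```python
import numpy as np


def dropped_channel_mse(imgs: np.ndarray, x_hat: np.ndarray, target_idx: int) -> float:
    """MSE between reconstruction and ground truth on the dropped channel only.

    imgs and x_hat are (B, C, T, H, W); the other channels do not contribute.
    """
    target = imgs[:, target_idx]   # (B, T, H, W) ground truth of the dropped channel
    pred = x_hat[:, target_idx]    # (B, T, H, W) reconstruction of that channel
    return float(np.mean((pred - target) ** 2))
```

Restricting the loss to target_idx means the model gains nothing from copying the channels it can already see.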
Neck Module
- class sdofmv2.tasks.missing_data.necks.Norm2d(embed_dim: int)
Bases: Module
Applies Layer Normalization over the channel dimension of 4D (B, C, H, W) inputs by temporarily moving channels last.
This module reshapes the input from (B, C, H, W) to (B, H, W, C), applies LayerNorm, and then reshapes it back to (B, C, H, W).
- Parameters:
embed_dim (int) – The number of features in the input (C).
Methods
forward(x) – Apply layer normalization to the input.
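Norm2d's permute, normalize, permute-back sequence is equivalent to normalizing each pixel's C-dimensional feature vector. A NumPy sketch with explicit affine parameters; the eps value of 1e-6 is an assumption, since the constructor signature exposes only embed_dim.

```python
import numpy as np


def norm2d(x: np.ndarray, weight: np.ndarray, bias: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """LayerNorm over the channel axis of a (B, C, H, W) array (eps is assumed)."""
    x = x.transpose(0, 2, 3, 1)                           # (B, C, H, W) -> (B, H, W, C)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)                   # biased variance, as LayerNorm uses
    x = (x - mean) / np.sqrt(var + eps) * weight + bias   # weight and bias have shape (C,)
    return x.transpose(0, 3, 1, 2)                        # back to (B, C, H, W)
```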
- class sdofmv2.tasks.missing_data.necks.ConvTransformerTokensToEmbeddingNeck(embed_dim: int, output_embed_dim: int, num_frames: int = 1, Hp: int = 14, Wp: int = 14, drop_cls_token: bool = True)
Bases: Module
Neck that transforms the token-based output of a transformer into a single spatial embedding suitable for processing with standard layers. It performs four ConvTranspose2d operations (kernel_size=2, stride=2) on the rearranged input; each doubles the spatial resolution, for 16× upsampling overall.
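The 16× factor follows from PyTorch's ConvTranspose2d output-size formula, out = (in − 1)·stride − 2·padding + dilation·(kernel_size − 1) + output_padding + 1. With kernel_size=2, stride=2 and (assumed) default padding=0, output_padding=0 and dilation=1, each layer maps in to 2·in. The helper names below are hypothetical.

```python
def conv_transpose2d_out(size: int, kernel_size: int = 2, stride: int = 2,
                         padding: int = 0, output_padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial axis, per PyTorch's ConvTranspose2d formula."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def neck_output_size(hp: int, num_layers: int = 4) -> int:
    """Apply the kernel_size=2, stride=2 transposed convolution num_layers times."""
    size = hp
    for _ in range(num_layers):
        size = conv_transpose2d_out(size)
    return size
```

With the default Hp = Wp = 14, the token grid grows 14 → 28 → 56 → 112 → 224.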
Methods
forward(x, ids_restore) – Transform transformer tokens back to spatial embeddings.
- forward(x, ids_restore)
Transform transformer tokens back to spatial embeddings.
- Parameters:
x (torch.Tensor) – Token embeddings of shape (batch, num_tokens, token_dim).
ids_restore (torch.Tensor) – Indices for reordering tokens to original positions.
- Returns:
Upsampled embeddings of shape (batch, out_channels, 16*H, 16*W), where H and W are the height and width of the token grid (Hp and Wp).
- Return type:
torch.Tensor
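Before upsampling, the neck has to turn a token sequence back into a (batch, token_dim, Hp, Wp) grid. The sketch below assumes MAE-style conventions that this page does not spell out: x holds a CLS token followed by the visible tokens in shuffled order, masked positions get a placeholder (a learned mask token in MAE; zeros here), ids_restore undoes the shuffle, and num_frames = 1. tokens_to_grid is a hypothetical helper.

```python
import numpy as np


def tokens_to_grid(x: np.ndarray, ids_restore: np.ndarray, hp: int, wp: int,
                   drop_cls_token: bool = True) -> np.ndarray:
    """Hypothetical sketch: (B, 1 + N_visible, D) tokens -> (B, D, Hp, Wp) grid."""
    if drop_cls_token:
        x = x[:, 1:, :]                                   # remove the leading CLS token
    b, n_visible, d = x.shape
    n_total = ids_restore.shape[1]                        # Hp * Wp patches in total
    filler = np.zeros((b, n_total - n_visible, d), dtype=x.dtype)
    x = np.concatenate([x, filler], axis=1)               # still in shuffled order
    index = np.broadcast_to(ids_restore[:, :, None], (b, n_total, d))
    x = np.take_along_axis(x, index, axis=1)              # back to original patch order
    return x.transpose(0, 2, 1).reshape(b, d, hp, wp)     # (B, D, Hp, Wp), row-major patches
```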
Wrap Encoder
- class sdofmv2.tasks.missing_data.wrap_encoder.WrapEncoder(encoder: Module)
Bases: Module
A wrapper for Prithvi-style encoders that handles the temporal dimension.
This class ensures that 4D input tensors (B, C, H, W) are correctly reshaped into 5D tensors (B, C, T, H, W) before being passed to the encoder. It also manages the extraction of intermediate features.
- encoder
The underlying encoder module (e.g., a Prithvi ViT).
Methods
forward(x) – Performs a forward pass through the encoder.
forward_features(x, n[, mask_ratio, ...]) – Extracts intermediate features from specific layers of the encoder.
- forward(x: Tensor) → Tensor
Performs a forward pass through the encoder.
If the input is 4D, a singleton temporal dimension is added. The temporal dimension is squeezed from the output before returning.
- Parameters:
x – Input tensor of shape (B, C, H, W) or (B, C, T, H, W).
- Returns:
The encoded features as a torch.Tensor.
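The 4D/5D handling can be sketched as a thin function around any encoder callable. It assumes the encoder's output keeps the temporal axis at position 2 and that the axis is squeezed only when the wrapper added it; wrap_forward is a hypothetical name and NumPy stands in for torch.

```python
from typing import Callable

import numpy as np


def wrap_forward(encoder: Callable[[np.ndarray], np.ndarray], x: np.ndarray) -> np.ndarray:
    """Sketch of the documented 4D/5D handling around an encoder call."""
    added_time = x.ndim == 4
    if added_time:
        x = x[:, :, np.newaxis]         # (B, C, H, W) -> (B, C, 1, H, W)
    out = encoder(x)
    if added_time:
        out = np.squeeze(out, axis=2)   # drop the singleton T axis again (assumed position)
    return out
```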
- forward_features(x: Tensor, n: list[int], mask_ratio: float = 0.0, reshape: bool = True, norm: bool = False) → list[Tensor]
Extracts intermediate features from specific layers of the encoder.
- Parameters:
x – Input tensor of shape (B, C, H, W) or (B, C, T, H, W).
n – A list of layer indices from which to extract features.
mask_ratio – The fraction of patches to mask during the forward pass.
reshape – Whether to reshape the output features into a spatial grid.
norm – Whether to apply normalization to the extracted features.
- Returns:
A list of tensors containing the intermediate features from the requested layers.
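A simplified view of how per-layer features might be collected. This sketch assumes n holds 0-based block indices and that tokens carry no CLS token, and it leaves out mask_ratio and norm; collect_features and the callable blocks are hypothetical stand-ins for the encoder's transformer blocks.

```python
from typing import Callable, Sequence

import numpy as np


def collect_features(tokens: np.ndarray, blocks: Sequence[Callable[[np.ndarray], np.ndarray]],
                     n: Sequence[int], grid_hw: tuple, reshape: bool = True) -> list:
    """Run tokens through blocks, keeping the outputs of the blocks listed in n.

    tokens: (B, N, D) patch tokens; grid_hw = (h, w) with h * w == N.
    """
    wanted = set(n)
    feats = []
    for i, block in enumerate(blocks):
        tokens = block(tokens)
        if i in wanted:
            out = tokens
            if reshape:
                b, _, d = out.shape
                h, w = grid_hw
                out = out.transpose(0, 2, 1).reshape(b, d, h, w)  # (B, D, h, w)
            feats.append(out)   # appended in block order, not in the order of n
    return feats
```

With reshape=False the features stay as (B, N, D) token sequences.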