Downstream App: Missing Data

Data Module

class sdofmv2.tasks.missing_data.missing_data_module.MissingDataModel(*args: Any, **kwargs: Any)

Bases: BaseModule

A model for reconstructing missing data channels using a backbone autoencoder.

This class wraps a backbone autoencoder for missing-data reconstruction. It implements a random channel-drop mechanism: one channel is zeroed out, and the model is trained to reconstruct that channel using an MSE loss.

Parameters:
  • optimizer_dict (dict, optional) – Configuration for the optimizer. Defaults to None.

  • scheduler_dict (dict, optional) – Configuration for the learning rate scheduler. Defaults to None.

  • backbone (object, optional) – The backbone autoencoder model. Defaults to None.

  • freeze_encoder (bool) – Whether to freeze the encoder blocks of the backbone. Defaults to True.

  • *args – Variable-length argument list passed to BaseModule.

  • **kwargs – Arbitrary keyword arguments passed to BaseModule.
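The freeze_encoder option amounts to switching off gradients for the encoder's parameters. The sketch below illustrates the idea under stated assumptions: the toy backbone and its "encoder"/"decoder" submodule names are hypothetical stand-ins rather than the real backbone's attributes, freeze_module is a hypothetical helper (not part of sdofmv2), and the AdamW optimizer and learning rate are arbitrary choices, not the contents of optimizer_dict.

```python
import torch
import torch.nn as nn


def freeze_module(module: nn.Module) -> None:
    """Hypothetical helper: stop gradient updates for every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad_(False)


# Hypothetical toy backbone with separate encoder and decoder parts.
backbone = nn.ModuleDict({
    "encoder": nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 32)),
    "decoder": nn.Linear(32, 16),
})
freeze_module(backbone["encoder"])

# Only the parameters that still require gradients are handed to the optimizer.
trainable = [p for p in backbone.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)  # arbitrary optimizer settings
```

Frozen parameters still take part in the forward pass; they simply accumulate no gradients and are not updated.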

backbone

The underlying autoencoder model.

Type:

object

masking_ratio

The masking ratio used by the backbone.

Type:

float

validation_metrics

Storage for metrics computed during validation.

Type:

list

Methods

forward(imgs[, mask_ratio])

Performs a standard forward pass through the backbone autoencoder.

forward_random_channel_drop(imgs[, mask_ratio])

Corrupts a random channel and performs a forward pass.

training_step(batch, batch_idx)

Executes a single training step with random channel corruption.

validation_step(batch, batch_idx)

Executes a single validation step and logs metrics.

forward(imgs, mask_ratio=0.5)

Performs a standard forward pass through the backbone autoencoder.

Parameters:
  • imgs (torch.Tensor) – Input images of shape (B, C, T, H, W).

  • mask_ratio (float) – Ratio of patches to mask. Defaults to 0.5.

Returns:

A tuple containing:
  • loss (torch.Tensor): The reconstruction loss.

  • x_hat (torch.Tensor): The unpatchified reconstructed images.

  • mask (torch.Tensor): The mask applied during the forward pass.

Return type:

tuple
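mask_ratio sets the fraction of patch tokens the backbone hides. Assuming the backbone uses standard MAE-style random masking (an assumption; this page does not describe the backbone internals), the returned mask and the ids_restore indices consumed by ConvTransformerTokensToEmbeddingNeck.forward could be produced as in this sketch. random_masking is a hypothetical name, and the mask convention (0 = kept, 1 = masked) follows MAE.

```python
from typing import Optional, Tuple

import torch


def random_masking(
    x: torch.Tensor, mask_ratio: float, generator: Optional[torch.Generator] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MAE-style per-sample random masking of (B, N, D) tokens (hypothetical helper)."""
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))

    # A random permutation per sample: sort uniform noise.
    noise = torch.rand(B, N, generator=generator, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

    # Keep the first len_keep tokens of the shuffled order.
    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    # Binary mask in the original token order: 0 = kept, 1 = masked.
    mask = torch.ones(B, N, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_kept, mask, ids_restore
```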

forward_random_channel_drop(imgs, mask_ratio=0.75)

Corrupts a random channel and performs a forward pass.

Parameters:
  • imgs (torch.Tensor) – Input images of shape (B, C, T, H, W).

  • mask_ratio (float) – Ratio of patches to mask. Defaults to 0.75.

Returns:

A tuple containing:
  • loss (torch.Tensor): The reconstruction loss.

  • x_hat (torch.Tensor): The reconstructed images.

  • mask (torch.Tensor): The mask applied during the forward pass.

  • target_idx (int): The index of the channel that was zeroed out.

Return type:

tuple
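A minimal sketch of the corruption step, assuming inputs of shape (B, C, T, H, W) and a single channel index drawn per batch (consistent with target_idx being returned as an int). random_channel_drop is a hypothetical helper, not the library's implementation, and the backbone forward pass that follows it is omitted.

```python
from typing import Optional, Tuple

import torch


def random_channel_drop(
    imgs: torch.Tensor, generator: Optional[torch.Generator] = None
) -> Tuple[torch.Tensor, int]:
    """Zero one randomly chosen channel of a (B, C, T, H, W) batch.

    Returns the corrupted copy and the index of the dropped channel.
    """
    num_channels = imgs.shape[1]
    target_idx = int(torch.randint(num_channels, (1,), generator=generator).item())
    corrupted = imgs.clone()        # leave the original intact as the reconstruction target
    corrupted[:, target_idx] = 0.0  # the same channel is dropped for every sample
    return corrupted, target_idx
```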

training_step(batch, batch_idx)

Executes a single training step with random channel corruption.

Parameters:
  • batch (tuple) – A tuple containing (images, timestamps).

  • batch_idx (int) – The index of the current batch.

Returns:

The MSE loss calculated on the dropped channel.

Return type:

torch.Tensor
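The data flow of training_step can be sketched as follows, under assumptions: forward_fn stands in for forward_random_channel_drop and is assumed to return x_hat in the same (B, C, T, H, W) layout as the input, and the loss compares the reconstructed channel with the original, uncorrupted channel. training_step_sketch is hypothetical; logging and any other details of the real method are omitted.

```python
from typing import Callable, Tuple

import torch
import torch.nn.functional as F


def training_step_sketch(
    forward_fn: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]],
    batch: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    imgs, _timestamps = batch  # timestamps are not needed for the loss
    _loss, x_hat, _mask, target_idx = forward_fn(imgs)
    # MSE restricted to the dropped channel: reconstruction vs. original values.
    return F.mse_loss(x_hat[:, target_idx], imgs[:, target_idx])
```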

validation_step(batch, batch_idx)

Executes a single validation step and logs metrics.

Parameters:
  • batch (tuple) – A tuple containing (images, timestamps).

  • batch_idx (int) – The index of the current batch.

Neck Module

class sdofmv2.tasks.missing_data.necks.Norm2d(embed_dim: int)

Bases: Module

Applies layer normalization over the channel dimension of 4D inputs.

This module permutes the input from (B, C, H, W) to channels-last (B, H, W, C), applies LayerNorm over C, and then permutes the result back to (B, C, H, W).

Parameters:

embed_dim (int) – The number of features in the input (C).

Methods

forward(x)

Apply layer normalization to the input.

forward(x)

Apply layer normalization to the input.

Parameters:

x (torch.Tensor) – Input tensor of shape (B, C, H, W).

Returns:

Output tensor of shape (B, C, H, W) with normalized features.

Return type:

torch.Tensor
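The permute-normalize-permute pattern described for Norm2d can be sketched as below. This is a minimal sketch, not the library source: Norm2dSketch is a hypothetical name, and nn.LayerNorm is used with its default eps, which may differ from the real module.

```python
import torch
import torch.nn as nn


class Norm2dSketch(nn.Module):
    """Hypothetical stand-in for Norm2d: LayerNorm over C for (B, C, H, W) input."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim)  # default eps; the real value may differ

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)                   # (B, C, H, W) -> (B, H, W, C)
        x = self.ln(x)                              # normalize over C at every pixel
        return x.permute(0, 3, 1, 2).contiguous()   # back to (B, C, H, W)
```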

class sdofmv2.tasks.missing_data.necks.ConvTransformerTokensToEmbeddingNeck(embed_dim: int, output_embed_dim: int, num_frames: int = 1, Hp: int = 14, Wp: int = 14, drop_cls_token: bool = True)

Bases: Module

Neck that transforms the token-based output of a transformer into a single spatial embedding suitable for processing with standard layers. It performs four ConvTranspose2d operations with kernel_size=2 and stride=2 on the rearranged input, so each spatial dimension is upsampled by a factor of 2^4 = 16.

Methods

forward(x, ids_restore)

Transform transformer tokens back to spatial embeddings.

forward(x, ids_restore)

Transform transformer tokens back to spatial embeddings.

Parameters:
  • x (torch.Tensor) – Token embeddings of shape (batch, num_tokens, token_dim).

  • ids_restore (torch.Tensor) – Indices for reordering tokens to original positions.

Returns:

Upsampled embeddings of shape (batch, out_channels, 16*H, 16*W), where H × W is the token grid (Hp × Wp) and out_channels corresponds to output_embed_dim.

Return type:

torch.Tensor
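The upsampling path can be sketched as follows to show where the factor of 16 comes from. This is a simplified, hypothetical stand-in (TokensToEmbeddingSketch), not the library code: it assumes num_frames=1 and that all Hp*Wp tokens are present, omits the ids_restore reordering and any normalization layers the real neck may use, and places a GELU between stages as an assumption.

```python
import torch
import torch.nn as nn


class TokensToEmbeddingSketch(nn.Module):
    """Hypothetical, simplified stand-in for ConvTransformerTokensToEmbeddingNeck."""

    def __init__(self, embed_dim: int, output_embed_dim: int,
                 Hp: int = 14, Wp: int = 14, drop_cls_token: bool = True):
        super().__init__()
        self.Hp, self.Wp = Hp, Wp
        self.drop_cls_token = drop_cls_token
        layers = []
        in_channels = embed_dim
        for i in range(4):
            # Each kernel_size=2, stride=2 transposed conv doubles H and W.
            layers.append(nn.ConvTranspose2d(in_channels, output_embed_dim,
                                             kernel_size=2, stride=2))
            if i < 3:
                layers.append(nn.GELU())  # assumed activation between stages
            in_channels = output_embed_dim
        self.upsample = nn.Sequential(*layers)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        if self.drop_cls_token:
            tokens = tokens[:, 1:, :]  # remove the leading class token
        b, _, d = tokens.shape
        # (B, Hp*Wp, D) -> (B, D, Hp, Wp): tokens are laid out row by row.
        grid = tokens.transpose(1, 2).reshape(b, d, self.Hp, self.Wp)
        return self.upsample(grid)  # (B, output_embed_dim, 16*Hp, 16*Wp)
```

For example, with Hp=4 and Wp=6 the grid grows 4→8→16→32→64 rows and 6→12→24→48→96 columns.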

Wrap Encoder

class sdofmv2.tasks.missing_data.wrap_encoder.WrapEncoder(encoder: Module)

Bases: Module

A wrapper for Prithvi-style encoders that handles the temporal dimension.

This class ensures that 4D input tensors (B, C, H, W) are expanded into 5D tensors (B, C, T, H, W) with a singleton temporal dimension (T = 1) before being passed to the encoder. It also manages the extraction of intermediate features.

encoder

The underlying encoder module (e.g., a Prithvi ViT).

Methods

forward(x)

Performs a forward pass through the encoder.

forward_features(x, n[, mask_ratio, ...])

Extracts intermediate features from specific layers of the encoder.

forward(x: Tensor) → Tensor

Performs a forward pass through the encoder.

If the input is 4D, a singleton temporal dimension is added. The temporal dimension is squeezed from the output before returning.

Parameters:

x – Input tensor of shape (B, C, H, W) or (B, C, T, H, W).

Returns:

The encoded features as a torch.Tensor.
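The 4D-to-5D handling can be sketched as below. This is a hypothetical stand-in (WrapEncoderSketch), not the library code; it assumes the encoder's output keeps the temporal axis at dim 2 and squeezes it only when the wrapper itself added it. A 1×1×1 nn.Conv3d plays the role of the encoder purely for illustration.

```python
import torch
import torch.nn as nn


class WrapEncoderSketch(nn.Module):
    """Hypothetical stand-in for WrapEncoder.forward (not the library code)."""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_time_dim = x.dim() == 4
        if added_time_dim:
            x = x.unsqueeze(2)  # (B, C, H, W) -> (B, C, 1, H, W)
        out = self.encoder(x)
        # Assumption: the temporal axis sits at dim 2 of the encoder output.
        if added_time_dim and out.dim() == 5 and out.shape[2] == 1:
            out = out.squeeze(2)  # drop the singleton time axis again
        return out


# Illustrative encoder: a pointwise 3D convolution from 3 to 8 channels.
wrapped = WrapEncoderSketch(nn.Conv3d(3, 8, kernel_size=1))
```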

forward_features(x: Tensor, n: list[int], mask_ratio: float = 0.0, reshape: bool = True, norm: bool = False) → list[Tensor]

Extracts intermediate features from specific layers of the encoder.

Parameters:
  • x – Input tensor of shape (B, C, H, W) or (B, C, T, H, W).

  • n – A list of layer indices from which to extract features.

  • mask_ratio – The fraction of patches to mask during the forward pass.

  • reshape – Whether to reshape the output features into a spatial grid.

  • norm – Whether to apply normalization to the extracted features.

Returns:

A list of tensors containing the intermediate features from the requested layers.
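The general pattern of collecting intermediate outputs can be sketched as follows. This is a hypothetical helper (extract_intermediate), not the WrapEncoder API: it assumes 0-based layer indices, no class token, an explicitly supplied token grid, and a caller-provided normalization module; masking is omitted, which corresponds to the default mask_ratio=0.0.

```python
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn


def extract_intermediate(
    blocks: Sequence[nn.Module],
    tokens: torch.Tensor,
    n: List[int],
    grid_hw: Tuple[int, int],
    reshape: bool = True,
    norm: Optional[nn.Module] = None,
) -> List[torch.Tensor]:
    """Run tokens (B, N, D) through `blocks`, keeping outputs of the layers in `n`."""
    feats = []
    x = tokens
    for i, block in enumerate(blocks):
        x = block(x)
        if i in n:
            f = norm(x) if norm is not None else x
            if reshape:
                b, _, dim = f.shape
                h, w = grid_hw
                # (B, N, D) -> (B, D, H, W), assuming N == H * W in row-major order.
                f = f.transpose(1, 2).reshape(b, dim, h, w)
            feats.append(f)
    return feats
```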