# Source code for sdofmv2.tasks.missing_data.necks

"""Modified from https://github.com/NASA-IMPACT/hls-foundation-os/geospatial_fm/geospatial_fm.py"""

import torch
import torch.nn as nn


def _convTranspose2dOutput(
    input_size: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    output_padding: int,
):
    """Calculate the output size of a ConvTranspose2d layer.

    Args:
        input_size (int): The input size.
        stride (int): The stride of the convolution.
        padding (int): The padding added to the input.
        dilation (int): The spacing between kernel elements.
        kernel_size (int): The size of the kernel.
        output_padding (int): Additional size added to one side of the output.

    Returns:
        int: The output size after the transposed convolution.

    Reference:
        https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    """
    return (
        (input_size - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )


class Norm2d(nn.Module):
    """Layer normalization over the channel axis of a ``(B, C, H, W)`` tensor.

    ``nn.LayerNorm`` normalizes the trailing dimension. The channel axis is
    therefore moved last, normalized, and moved back to position 1.

    Args:
        embed_dim (int): Number of channels ``C`` to normalize over.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        # The attribute name ``ln`` determines the state_dict keys
        # (``ln.weight`` / ``ln.bias``).
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        """Normalize ``x`` across its channel dimension.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Contiguous tensor of shape (B, C, H, W), normalized
            over C.
        """
        # (B, C, H, W) -> (B, H, W, C) -> LayerNorm -> (B, C, H, W)
        return self.ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()
class ConvTransformerTokensToEmbeddingNeck(nn.Module):
    """Neck that turns transformer tokens back into a dense spatial embedding.

    Visible tokens are padded with a learned mask token and put back into
    their original order with ``ids_restore``. They are then folded onto an
    ``(Hp, Wp)`` grid and upsampled by four ``ConvTranspose2d`` layers
    (``kernel_size=2``, ``stride=2``). This is a factor of 16 along each
    spatial axis, so the result can be processed with standard layers.
    """

    def __init__(
        self,
        embed_dim: int,
        output_embed_dim: int,
        num_frames: int = 1,
        Hp: int = 14,
        Wp: int = 14,
        drop_cls_token: bool = True,
    ):
        """
        Args:
            embed_dim (int): Dimension of each input token.
            output_embed_dim (int): Number of output channels of the neck.
            num_frames (int, optional): Channel multiplier for the first
                deconvolution, which expects ``embed_dim * num_frames``
                input channels. Defaults to 1.
            Hp (int, optional): Height (in patches) of the embedding to be
                upscaled. Defaults to 14.
            Wp (int, optional): Width (in patches) of the embedding to be
                upscaled. Defaults to 14.
            drop_cls_token (bool, optional): Whether the input carries a cls
                token that should be dropped. The cls token is assumed to be
                the first token. Defaults to True.
        """
        super().__init__()
        self.drop_cls_token = drop_cls_token
        self.Hp = Hp
        self.Wp = Wp

        # Learned token placed at every masked position so the full patch
        # grid can be rebuilt before upsampling.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Deconvolution settings shared by all four upsampling layers. With
        # these values each layer exactly doubles the spatial size.
        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0

        # Spatial size after the four deconvolutions (two in fpn1, two in fpn2).
        self.H_out = Hp
        self.W_out = Wp
        for _ in range(4):
            self.H_out = _convTranspose2dOutput(
                self.H_out, stride, padding, dilation, kernel_size, output_padding
            )
            self.W_out = _convTranspose2dOutput(
                self.W_out, stride, padding, dilation, kernel_size, output_padding
            )

        self.embed_dim = embed_dim * num_frames
        self.output_embed_dim = output_embed_dim

        def deconv(in_channels: int) -> nn.ConvTranspose2d:
            # One stride-2 transposed convolution producing output_embed_dim channels.
            return nn.ConvTranspose2d(
                in_channels,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            )

        def upsample_block(in_channels: int) -> nn.Sequential:
            # deconv -> Norm2d -> GELU -> deconv, i.e. 4x spatial upsampling.
            # Keep this layer order: it fixes the state_dict keys
            # (e.g. fpn1.0.weight, fpn1.1.ln.weight, fpn1.3.weight).
            return nn.Sequential(
                deconv(in_channels),
                Norm2d(self.output_embed_dim),
                nn.GELU(),
                deconv(self.output_embed_dim),
            )

        self.fpn1 = upsample_block(self.embed_dim)
        self.fpn2 = upsample_block(self.output_embed_dim)

    def forward(self, x, ids_restore):
        """Transform transformer tokens back to spatial embeddings.

        Args:
            x (torch.Tensor): Token embeddings of shape
                (batch, num_tokens, token_dim). If ``drop_cls_token`` is set,
                the first token is discarded.
            ids_restore (torch.Tensor): int64 tensor of shape
                (batch, total_tokens) with the indices that reorder
                ``[visible tokens, mask tokens]`` back to their original
                positions.

        Returns:
            torch.Tensor: Upsampled embeddings of shape
            (batch, output_embed_dim, 16*Hp, 16*Wp).

        Raises:
            ValueError: If ``ids_restore`` has fewer positions than there are
                visible tokens.
        """
        if self.drop_cls_token:
            x = x[:, 1:, :]
        B, L, D = x.shape

        num_masked = ids_restore.shape[1] - L
        if num_masked < 0:
            raise ValueError(
                f"ids_restore has {ids_restore.shape[1]} positions but "
                f"{L} visible tokens were given"
            )

        # expand() yields broadcast views; torch.cat / torch.gather read them
        # without first materializing repeated copies.
        mask_tokens = self.mask_token.expand(B, num_masked, -1)
        x = torch.cat([x, mask_tokens], dim=1)  # no cls token
        index = ids_restore.unsqueeze(-1).expand(-1, -1, D)
        x = torch.gather(x, dim=1, index=index)
        # x: [batch, total num tokens, token dim] -> [batch, channels, Hp, Wp]
        x = x.permute(0, 2, 1).reshape(B, -1, self.Hp, self.Wp)
        x = self.fpn1(x)  # [batch, out channels, 4*Hp, 4*Wp]
        x = self.fpn2(x)  # [batch, out channels, 16*Hp, 16*Wp]
        return x