Source code for sdofmv2.tasks.missing_data.necks
"""Modified from https://github.com/NASA-IMPACT/hls-foundation-os/geospatial_fm/geospatial_fm.py"""
import torch
import torch.nn as nn
def _convTranspose2dOutput(
    input_size: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    output_padding: int,
):
    """Return the spatial output size of a ``ConvTranspose2d`` along one axis.

    Implements the size formula from the PyTorch documentation::

        (input_size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1

    Args:
        input_size (int): Size of the input along the axis.
        stride (int): Stride of the transposed convolution.
        padding (int): Implicit zero padding applied to the input.
        dilation (int): Spacing between kernel elements.
        kernel_size (int): Size of the kernel along the axis.
        output_padding (int): Extra size added to one side of the output.

    Returns:
        int: The output size along the axis.

    Reference:
        https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    """
    # Distance covered by the stride-spaced input positions.
    stretched = (input_size - 1) * stride
    # Number of output cells spanned by one (dilated) kernel application.
    kernel_extent = dilation * (kernel_size - 1) + 1
    return stretched - 2 * padding + kernel_extent + output_padding
[docs]
class Norm2d(nn.Module):
    """LayerNorm over the channel axis of a channels-first 4D tensor.

    ``nn.LayerNorm`` normalizes over the trailing dimension. The module
    therefore:

    1. moves the (B, C, H, W) input to channels-last (B, H, W, C);
    2. normalizes each C-vector;
    3. returns a contiguous (B, C, H, W) tensor.

    Args:
        embed_dim (int): Number of channels ``C`` to normalize over.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        # Attribute name ``ln`` determines the state_dict keys (ln.weight, ln.bias).
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        """Normalize ``x`` across its channel dimension.

        Args:
            x (torch.Tensor): Tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Contiguous tensor of shape (B, C, H, W).
        """
        channels_last = x.permute(0, 2, 3, 1)
        normalized = self.ln(channels_last)
        return normalized.permute(0, 3, 1, 2).contiguous()
[docs]
class ConvTransformerTokensToEmbeddingNeck(nn.Module):
    """Neck that turns masked transformer tokens into an upsampled spatial map.

    The forward pass works in four steps:

    1. Pad the visible tokens with a learnable mask token.
    2. Restore the tokens to their original order using ``ids_restore``.
    3. Fold the sequence into a ``(B, C, Hp, Wp)`` grid.
    4. Upsample 16x with four ``ConvTranspose2d`` layers (kernel_size=2,
       stride=2). These are split into two stages, ``fpn1`` and ``fpn2``,
       each upsampling by 4x.
    """

    def __init__(
        self,
        embed_dim: int,
        output_embed_dim: int,
        num_frames: int = 1,
        Hp: int = 14,
        Wp: int = 14,
        drop_cls_token: bool = True,
    ):
        """
        Args:
            embed_dim (int): Dimension of each input token.
            output_embed_dim (int): Number of channels produced by every
                ConvTranspose2d layer.
            num_frames (int, optional): Multiplier on the channel count seen
                by the first layer, which expects ``embed_dim * num_frames``
                input channels. Defaults to 1.
            Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
            Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
            drop_cls_token (bool, optional): Whether the input contains a cls
                token that should be dropped. Assumes the cls token is the
                first token. Defaults to True.
        """
        super().__init__()
        self.drop_cls_token = drop_cls_token
        self.Hp = Hp
        self.Wp = Wp
        # Learnable placeholder inserted at every masked (missing) token position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Shared deconvolution settings: each layer doubles H and W.
        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        deconv_kwargs = {
            "kernel_size": kernel_size,
            "stride": stride,
            "dilation": dilation,
            "padding": padding,
            "output_padding": output_padding,
        }

        # Spatial size after the four ConvTranspose2d layers (not used by forward()).
        self.H_out = Hp
        self.W_out = Wp
        for _ in range(4):
            self.H_out = _convTranspose2dOutput(
                self.H_out, stride, padding, dilation, kernel_size, output_padding
            )
            self.W_out = _convTranspose2dOutput(
                self.W_out, stride, padding, dilation, kernel_size, output_padding
            )

        self.embed_dim = embed_dim * num_frames
        self.output_embed_dim = output_embed_dim
        self.fpn1 = self._make_upsample_stage(
            self.embed_dim, self.output_embed_dim, deconv_kwargs
        )
        self.fpn2 = self._make_upsample_stage(
            self.output_embed_dim, self.output_embed_dim, deconv_kwargs
        )

    @staticmethod
    def _make_upsample_stage(in_channels, out_channels, deconv_kwargs):
        """Build one upsampling stage: deconv -> Norm2d -> GELU -> deconv.

        The layer order fixes the state_dict keys (``.0``, ``.1.ln``, ``.3``)
        and the parameter-initialization order; changing it breaks existing
        checkpoints.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, **deconv_kwargs),
            Norm2d(out_channels),
            nn.GELU(),
            nn.ConvTranspose2d(out_channels, out_channels, **deconv_kwargs),
        )

    def forward(self, x, ids_restore):
        """Transform transformer tokens back to spatial embeddings.

        Args:
            x (torch.Tensor): Visible token embeddings of shape
                (batch, num_visible_tokens[+1 cls], token_dim).
            ids_restore (torch.Tensor): Integer tensor of shape
                (batch, total_num_tokens). Output position ``i`` takes token
                ``ids_restore[:, i]`` from the sequence
                ``[visible tokens..., mask tokens...]``.

        Returns:
            torch.Tensor: Upsampled embeddings of shape
            (batch, output_embed_dim, 16*Hp, 16*Wp).
        """
        if self.drop_cls_token:
            x = x[:, 1:, :]
        B, L, D = x.shape
        # Pad with one mask token per missing position so the sequence has
        # ids_restore.shape[1] tokens (no cls token).
        mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1] - L, 1)
        x = torch.cat([x, mask_tokens], dim=1)
        # gather only reads the index, so a stride-0 expand() view suffices
        # instead of materializing a (B, N, D) int64 copy with repeat().
        index = ids_restore.unsqueeze(-1).expand(-1, -1, D)
        x = torch.gather(x, dim=1, index=index)
        # [batch, total num tokens, token dim] -> [batch, channels, Hp, Wp]
        x = x.permute(0, 2, 1).reshape(B, -1, self.Hp, self.Wp)
        x = self.fpn1(x)  # [batch, out channels, 4*Hp, 4*Wp]
        x = self.fpn2(x)  # [batch, out channels, 16*Hp, 16*Wp]
        return x