"""Module containing the 3D Masked Autoencoder (MAE) architecture.
This module implements a Vision Transformer (ViT) based Masked Autoencoder
adapted for 3D spatiotemporal data (e.g., sequences of solar images). It
includes tools for 3D patch embedding and masked reconstruction using
custom solar limb masks.
Adapted from: https://github.com/facebookresearch/mae/blob/main/models_mae.py
"""
import torch
import torch.nn as nn
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block
import torch.nn.functional as F
from ..utils import get_3d_sincos_pos_embed, unpatchify, patchify
from .losses import (
mae_loss,
vector_aware_loss,
pixel_weight_loss,
patch_weight_loss,
split_pixel_loss,
sparse_dense_loss,
)
class PatchEmbed(nn.Module):
"""Frames of 2D Images to 3D Patch Embedding.
This is a 3D spatiotemporal version of `timm.models.vision_transformer.PatchEmbed`.
It divides a sequence of frames into 3D tubelets (patches across space and time)
and projects them into an embedding space.
Args:
img_size (int or tuple): Spatial dimensions of the input images. If an int
is provided, it assumes a square image (height, width).
patch_size (int or tuple): Spatial dimensions of each patch. If an int
is provided, it assumes square patches.
num_frames (int): Number of frames in the input sequence.
tubelet_size (int): Temporal dimension of each patch.
in_chans (int): Number of channels in the input images.
embed_dim (int): Output (token) dimension after the 3D convolution layer.
norm_layer (torch.nn.Module, optional): Normalization layer applied after
the convolution. Defaults to None.
flatten (bool): Whether to flatten the spatial and temporal dimensions
into a single sequence.
bias (bool): Whether to use a bias term in the 3D convolution layer.
Attributes:
grid_size (tuple): The calculated grid dimensions of the patches
(time, height, width).
num_patches (int): Total number of patches generated per sequence.
proj (nn.Conv3d): The 3D convolution layer used for projection.
"""
def __init__(
self,
img_size=224,
patch_size=16,
num_frames=3,
tubelet_size=1,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_size = (
num_frames // tubelet_size,
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
)
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
self.flatten = flatten
self.proj = nn.Conv3d(
in_chans,
embed_dim,
kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
stride=(tubelet_size, patch_size[0], patch_size[1]),
bias=bias,
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# B, C, T, H, W = x.shape # batch channels frames height width
# print("input dim", x.shape)
x = self.proj(x)
# print("proj dim", x.shape)
# The output size is (B, L, C), where N=H*W/T/T, C is embid_dim
if self.flatten:
x = x.flatten(2).transpose(1, 2).contiguous() # B,C,T,H,W -> B,C,L=(T*H*W) -> B,L,C
x = self.norm(x)
return x
[docs]
class MaskedAutoencoderViT3D(nn.Module):
"""Masked Autoencoder with a 3D Vision Transformer backbone.
This model encodes a fraction of visible image patches and predicts
the missing (masked) patches. It incorporates specialized masking logic
to handle solar limb boundaries, ensuring the model focuses on the solar disk.
Args:
img_size (int): Spatial dimensions of the input images.
Defaults to 224.
patch_size (int): Spatial dimensions of each patch.
Defaults to 16.
num_frames (int): Number of frames in the input temporal sequence.
Defaults to 3.
tubelet_size (int): Temporal dimension of each patch (tubelet).
Defaults to 1.
in_chans (int): Number of input channels (e.g., wavelengths).
Defaults to 3.
embed_dim (int): Dimensionality of the token embeddings in the
encoder. Defaults to 1024.
depth (int): Number of Transformer blocks in the encoder.
Defaults to 24.
num_heads (int): Number of attention heads in the encoder's
Transformer blocks. Defaults to 16.
decoder_embed_dim (int): Dimensionality of the token embeddings
in the decoder. Defaults to 512.
decoder_depth (int): Number of Transformer blocks in the decoder.
Defaults to 8.
decoder_num_heads (int): Number of attention heads in the decoder's
Transformer blocks. Defaults to 16.
mlp_ratio (float): Ratio of the hidden dimension to the embedding
dimension in the MLP of the Transformer blocks. Defaults to 4.0.
norm_layer (str): The type of normalization layer to use.
Currently supports "LayerNorm". Defaults to "LayerNorm".
limb_mask (torch.Tensor): A 2D spatial mask tensor outlining the
solar limb, used for pixel-level loss computation. Defaults to None.
ids_limb_mask (torch.Tensor or list): 1D array of patch indices
that fall outside the solar disk. These are always masked out during
training. Defaults to None.
loss_dict (dict): Configuration dictionary specifying the loss
space (e.g., "patch" or "pixel"), loss type (e.g., "mse", "mae"), and
other loss-specific hyperparameters. Defaults to an empty dict.
"""
def __init__(
self,
img_size=224,
patch_size=16,
num_frames=3,
tubelet_size=1,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.0,
norm_layer="LayerNorm",
limb_mask=None,
ids_limb_mask=None,
loss_dict={},
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.limb_mask = limb_mask
self.loss_dict = loss_dict
# define loss # define loss
loss_functions = {
"mse": F.mse_loss,
"mae": mae_loss,
"huber": F.huber_loss,
"vector_aware_loss": vector_aware_loss,
"pixel_weight_loss": pixel_weight_loss,
"patch_weight_loss": patch_weight_loss,
"split_patch_loss": split_pixel_loss,
"sparse_dense_loss": sparse_dense_loss,
}
loss_type = self.loss_dict.get("type", "mse")
if loss_type not in loss_functions:
raise ValueError(f"Unknown loss: {loss_type}")
self.loss_fn = loss_functions[loss_type]
# define normalization
match norm_layer:
case "LayerNorm":
norm_layer = nn.LayerNorm
case _:
raise NotImplementedError(f"Norm layer [{norm_layer}] not implemented.")
# --------------------------------------------------------------------------
# Limb masking
if ids_limb_mask is not None:
ids_limb_mask = torch.as_tensor(ids_limb_mask, dtype=torch.long)
ids_limb_mask, _ = torch.sort(ids_limb_mask)
self.register_buffer("ids_limb_mask", ids_limb_mask) # indices of pixel inside solar disk
# MAE encoder specifics
self.patch_embed = PatchEmbed(
img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
) # fixed sin-cos embedding
self.blocks = nn.ModuleList(
[
Block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
)
for _ in range(depth)
]
)
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False
) # fixed sin-cos embedding
self.decoder_blocks = nn.ModuleList(
[
Block(
decoder_embed_dim,
decoder_num_heads,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
)
for _ in range(decoder_depth)
]
)
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(
decoder_embed_dim,
tubelet_size * patch_size * patch_size * in_chans,
bias=True,
) # decoder to patch
# --------------------------------------------------------------------------
self.initialize_weights()
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_3d_sincos_pos_embed(
self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
decoder_pos_embed = get_3d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=0.02)
torch.nn.init.normal_(self.mask_token, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
[docs]
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
Args:
x (torch.Tensor): Input sequence of embedded patches. Expected shape
is [N, L, D], where N is batch size, L is sequence length
(number of patches), and D is embedding dimension.
mask_ratio (float): The fraction of patches to mask out (e.g., 0.75).
If a limb mask is provided, this fraction only applies to patches
inside the solar disk.
Returns:
tuple: A tuple containing:
- x_masked (torch.Tensor): The sequence of kept, visible patches.
- mask (torch.Tensor): Binary mask indicating kept (0) vs. removed (1) patches.
- ids_restore (torch.Tensor): Indices needed to unshuffle the patches
back to their original spatial order.
"""
N, L, D = x.shape # batch, length, dim
device = x.device
# if using a solar limb mask, we're not interested in predicting outside the
# disk. For that we override these kept ids to include always keeping the mask
# as we don't want to learn that.
if self.ids_limb_mask is not None:
ids_limb = self.ids_limb_mask.to(device).long().flatten()
# Get indices of patches INSIDE the solar disk
all_indices = torch.arange(L, device=device)
mask_limb = torch.zeros(L, dtype=torch.bool, device=device)
mask_limb[ids_limb] = True
# Indices inside the disk (where we can do masking)
ids_inside_disk = all_indices[~mask_limb]
num_inside = len(ids_inside_disk)
# Calculate how many inside patches to keep
len_keep_inside = int(num_inside * (1 - mask_ratio))
# Generate random noise only for inside patches
noise_inside = torch.rand(N, num_inside, device=device)
ids_shuffle_local = torch.argsort(noise_inside, dim=1)
# Convert local indices back to global indices
ids_inside_expand = ids_inside_disk.unsqueeze(0).expand(N, -1)
local_keep = ids_shuffle_local[:, :len_keep_inside]
ids_keep = torch.gather(ids_inside_expand, dim=1, index=local_keep)
# Gather the kept patches
ids_keep_gather = ids_keep.unsqueeze(-1).expand(-1, -1, D)
x_masked = torch.gather(x, dim=1, index=ids_keep_gather)
# Create binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=device, dtype=torch.long)
mask.scatter_(1, ids_keep, 0)
# Get the removed indices for all samples at once
# limb ids repeated per-sample (always removed)
ids_limb_expand = ids_limb.unsqueeze(0).expand(N, -1) # [N, num_limb]
local_removed_inside = ids_shuffle_local[:, len_keep_inside:]
ids_removed_inside = torch.gather(
ids_inside_expand, dim=1, index=local_removed_inside
) # [N, num_removed_inside]
ids_removed = torch.cat([ids_limb_expand, ids_removed_inside], dim=1)
# Concatenate kept and removed for all samples
# ids_keep: [N, num_kept], ids_removed: [N, num_removed]
full_order = torch.cat([ids_keep, ids_removed], dim=1) # [N, L]
# Apply argsort to all samples at once
ids_restore = torch.argsort(full_order, dim=1) # [N, L]
else:
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=device, dtype=torch.long)
mask.scatter_(1, ids_keep, 0)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# print("patch_embed dim", x.shape)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
[docs]
def forward(self, imgs, mask_ratio=0.5):
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore)
# Compute patch-level loss
if self.loss_dict.space == "patch":
target = patchify(imgs, self.patch_size, self.tubelet_size)
if self.loss_dict.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target_norm = (target - mean) / (var + 1e-6) ** 0.5
else:
target_norm = target
device = target.device
batch_size, num_patches, dims = pred.shape
# Identify off-limb regions from config
if self.ids_limb_mask is not None and len(self.ids_limb_mask) > 0:
is_off_limb = torch.zeros(
(batch_size, num_patches), device=device, dtype=torch.bool
)
is_off_limb[:, self.ids_limb_mask] = True
else:
is_off_limb = None
if self.loss_dict.type == "pixel_weight_loss":
loss = self.loss_fn(
pred,
target_norm,
target,
self.loss_dict.pixel_weight_loss.base_loss,
self.loss_dict.pixel_weight_loss.threshold,
self.loss_dict.pixel_weight_loss.ar_weight_ratio,
)
# patch weight loss only works when only_masked_patches: false and exclude limb: true/false
elif self.loss_dict.type == "patch_weight_loss":
loss = self.loss_fn(
pred,
target_norm,
self.loss_dict.patch_weight_loss,
mask_hidden=mask,
mask_off_limb=is_off_limb,
)
elif self.loss_dict.type == "split_patch_loss":
loss = self.loss_fn(
pred,
target_norm if self.loss_dict.norm_pix_loss else target,
alpha=self.loss_dict.split_patch_loss.get("alpha", 1.0),
beta=self.loss_dict.split_patch_loss.get("beta", 1.0),
base_type=self.loss_dict.split_patch_loss.get("base_type", "mse"),
huber_delta=self.loss_dict.split_patch_loss.get("huber_delta", 1.0),
imgs=imgs,
patch_size=self.patch_size,
corner_size=self.loss_dict.split_patch_loss.get("corner_size", 4),
corner_ratio=self.loss_dict.split_patch_loss.get("corner_ratio", 0.25),
)
elif self.loss_dict.type == "sparse_dense_loss":
loss = self.loss_fn(
pred,
target_norm if self.loss_dict.norm_pix_loss else target,
alpha=self.loss_dict.sparse_dense_loss.get("alpha", 1.0),
beta=self.loss_dict.sparse_dense_loss.get("beta", 1.0),
base_type=self.loss_dict.sparse_dense_loss.get("base_type", "mse"),
huber_delta=self.loss_dict.sparse_dense_loss.get("huber_delta", 1.0),
imgs=imgs,
patch_size=self.patch_size,
corner_size=self.loss_dict.split_patch_loss.get("corner_size", 4),
corner_ratio=self.loss_dict.sparse_dense_loss.get("corner_ratio", 0.25),
)
else:
loss = self.loss_fn(
pred,
target_norm if self.loss_dict.norm_pix_loss else target,
)
elif self.loss_dict.space == "pixel":
# pixel-level reconstruction
pred_img = unpatchify(pred, self.img_size, self.patch_size, self.tubelet_size)
# Apply limb mask if available
if self.loss_dict.exclude_limb:
limb_mask = self.limb_mask.to(imgs.device)
pred_img = pred_img * limb_mask
imgs = imgs * limb_mask
# Compute pixel-level loss
if self.loss_dict.type == "vector_aware_loss":
loss = self.loss_fn(pred_img, imgs, self.loss_dict.vector_aware_loss.base_loss)
else:
loss = self.loss_fn(pred_img, imgs)
else:
raise ValueError("Unsupported loss space type")
return loss, pred, mask