import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
def cos_sin_transformation(
    position: torch.Tensor, max_power: int = 4, include_raw_coordinates: bool = False
) -> torch.Tensor:
    """Generate harmonic (Fourier) positional encodings.

    Every coordinate ``p`` is multiplied by ``2**k`` for ``k = 0..max_power``.
    The output holds the cosines of all scaled values followed by the sines
    of all scaled values; within each half, the ``max_power + 1`` frequencies
    of one coordinate are stored contiguously.

    Args:
        position (torch.Tensor): Position coordinates of shape
            (batch, position_size). A 1-D tensor of shape (position_size,)
            is treated as a batch of one.
        max_power (int, optional): Largest exponent ``k`` used for the
            ``2**k`` frequency scaling. Defaults to 4.
        include_raw_coordinates (bool, optional): Whether to append the
            un-encoded coordinates as extra trailing feature columns.
            Defaults to False.

    Returns:
        torch.Tensor: Encodings of shape
        (batch, 2 * position_size * (max_power + 1)), plus ``position_size``
        additional trailing columns when ``include_raw_coordinates`` is True.
    """
    if position.ndim == 1:
        position = position.unsqueeze(0)  # (position_size,) -> (1, position_size)
    batch_size = position.size(0)

    powers = 2.0 ** torch.arange(
        max_power + 1, device=position.device, dtype=position.dtype
    )  # (max_power + 1,)

    # reshape (not view): the broadcast product can inherit non-contiguous
    # strides from `position`, in which case view() would raise.
    scaled_pos = (position.unsqueeze(-1) * powers).reshape(batch_size, -1)

    encoded = torch.cat([torch.cos(scaled_pos), torch.sin(scaled_pos)], dim=-1)

    if include_raw_coordinates:
        # Append along the feature axis; torch.cat defaults to dim=0, which
        # would try to stack the coordinates as extra batch rows instead.
        raw = position.reshape(batch_size, -1).to(dtype=encoded.dtype)
        encoded = torch.cat([encoded, raw], dim=-1)
    return encoded
# class Transformer
class SimpleLinear(nn.Module):
    """Multi-Layer Perceptron head for flattened feature processing.

    This head flattens the input feature map and concatenates it with harmonic
    positional encodings (sine/cosine) and radial distance before passing
    the combined vector through a non-linear MLP.

    Args:
        d_output (int): Dimension of the final output.
        input_feature_dim (int): Number of features per sample after the
            input has been flattened.
        max_position_element (int): Highest power used in sine/cosine
            transform. Defaults to 4.
        position_size (int): Number of raw coordinate variables. Defaults to 4.
        hidden_dim (int): Width of the first hidden layer; the second hidden
            layer uses ``hidden_dim // 2``. Defaults to 16.
        p_drop (float): Dropout probability. Defaults to 0.1.

    Attributes:
        d_output (int): Output dimensionality.
        input_feature_dim (int): Dimension of the input features.
        hidden_dim (int): Hidden layer width.
        p_drop (float): Dropout rate used in the network.
        max_position_element (int): Complexity of the harmonic encoding.
        position_size (int): Number of coordinate inputs.
        network (nn.Sequential): The MLP architecture
            (Linear -> ReLU -> Dropout, twice, then a final Linear).
    """

    def __init__(
        self,
        d_output,
        input_feature_dim,
        max_position_element=4,
        position_size=4,
        hidden_dim=16,
        p_drop=0.1,
    ):
        super().__init__()
        self.d_output = d_output
        self.input_feature_dim = input_feature_dim
        self.hidden_dim = hidden_dim
        self.p_drop = p_drop
        self.max_position_element = max_position_element
        self.position_size = position_size

        # Input width = flattened features
        #   + 2 (cos, sin) * position_size * (max_position_element + 1) encodings
        #   + 1 scalar radial distance.
        r_distance_dim = 1
        pos_encoding_dim = 2 * position_size * (max_position_element + 1)
        total_input_dim = input_feature_dim + pos_encoding_dim + r_distance_dim

        self.network = nn.Sequential(
            nn.Linear(total_input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=self.p_drop),
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(p=self.p_drop),
            nn.Linear(self.hidden_dim // 2, self.d_output),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialize linear layers: Xavier-uniform weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x, position, r_distance):
        """Perform a forward pass through the MLP head.

        Args:
            x (torch.Tensor): Input features of shape (batch, ...); any
                trailing dimensions are flattened into one.
            position (torch.Tensor): Position coordinates of shape
                (batch, position_size). A single position is broadcast
                across the batch.
            r_distance (torch.Tensor): Radial distance values with ``batch``
                elements, e.g. shape (batch,) or (batch, 1).

        Returns:
            torch.Tensor: Output logits of shape (batch, d_output).
        """
        batch_size = x.size(0)

        # Flatten while preserving the batch dimension. reshape() also
        # handles non-contiguous inputs (e.g. after permute), unlike view().
        if x.dim() > 2:
            x = x.reshape(batch_size, -1)

        # The encoding is detached, so no gradient flows back into `position`.
        pos_encoded = cos_sin_transformation(
            position, max_power=self.max_position_element
        )
        pos_encoded = pos_encoded.detach().to(dtype=x.dtype, device=x.device)

        # A single encoded position (leading dim 1) is shared by every sample.
        if pos_encoded.size(0) != batch_size:
            pos_encoded = pos_encoded.expand(batch_size, -1)

        r_tensor = r_distance.to(dtype=x.dtype, device=x.device)
        r_tensor = r_tensor.reshape(batch_size, -1)

        combined = torch.cat([x, pos_encoded, r_tensor], dim=-1)
        return self.network(combined)
class SkipLinearHead(nn.Module):
    """Deep MLP head with skip connections for coordinate-aware regression.

    This module implements a deep architecture where the initial concatenated
    input (features + encodings) is re-injected at specified layers. This
    residual-style connection helps maintain high-frequency coordinate
    information throughout the depth of the network.

    Args:
        d_output (int): Dimension of the final output.
        input_feature_dim (int): Dimension of a single input frame.
        max_position_element (int): Highest power for harmonic encoding. Defaults to 4.
        position_size (int): Number of raw coordinate variables. Defaults to 4.
        hidden_dim (int): Latent width of the hidden layers. Defaults to 16.
        skips (Sequence[int]): Layer indices after whose activation the
            initial input is concatenated back into the hidden state, so
            that the following layer sees both. An index that refers to the
            last layer has no following layer and is ignored.
            Defaults to (4,).
        include_raw_coordinates (bool): If True, appends non-encoded coordinates
            to the input vector. Defaults to False.
        num_hidden_layers (int): Total number of linear layers in the backbone.
            Defaults to 8.
        number_of_frames (int): Number of temporal frames to flatten. Defaults to 1.

    Attributes:
        d_output (int): Output dimensionality.
        hidden_dim (int): Latent width of the network.
        max_position_element (int): Harmonic encoding complexity.
        position_size (int): Number of coordinate inputs.
        skips (list[int]): Indices of layers performing skip-connections.
        include_raw_coordinates (bool): Whether raw coordinates are part of
            the input vector.
        num_hidden_layers (int): Total depth of the MLP.
        pts_linears (nn.ModuleList): Collection of linear layers including
            skip-connection logic.
        output_linear (nn.Linear): Final layer mapping to output dimension.
    """

    def __init__(
        self,
        d_output,
        input_feature_dim,
        max_position_element=4,
        position_size=4,
        hidden_dim=16,
        skips=(4,),
        include_raw_coordinates=False,
        num_hidden_layers=8,
        number_of_frames=1,
    ):
        super().__init__()
        self.d_output = d_output
        # Unused placeholder, kept only so existing `head.network` reads
        # keep returning None.
        self.network = None
        self.hidden_dim = hidden_dim
        self.max_position_element = max_position_element
        self.position_size = position_size
        # Private copy: never alias the caller's (or a default) sequence.
        self.skips = list(skips)
        self.include_raw_coordinates = include_raw_coordinates
        self.num_hidden_layers = num_hidden_layers

        # Input width = flattened frames
        #   + 2 (cos, sin) * position_size * (max_position_element + 1) encodings
        #   + 1 scalar radial distance
        #   (+ position_size raw coordinates when requested).
        r_distance_dim = 1
        pos_encoding_dim = 2 * position_size * (max_position_element + 1)
        total_input_dim = (
            input_feature_dim * number_of_frames + pos_encoding_dim + r_distance_dim
        )
        if include_raw_coordinates:
            total_input_dim += position_size

        # pts_linears[i + 1] consumes the output of layer i, so it needs the
        # widened input exactly when layer i is a skip layer.
        layers = [nn.Linear(total_input_dim, hidden_dim)]
        for i in range(self.num_hidden_layers - 1):
            in_features = hidden_dim
            if i in self.skips:
                in_features += total_input_dim
            layers.append(nn.Linear(in_features, hidden_dim))
        self.pts_linears = nn.ModuleList(layers)

        self.output_linear = nn.Linear(hidden_dim, d_output)
        self._init_weights()

    def _init_weights(self):
        """Initialize linear layers: Xavier-uniform weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x, position, r_distance):
        """Perform a forward pass through the skip-connected MLP head.

        Args:
            x (torch.Tensor): Input features of shape (batch, ...); any
                trailing dimensions are flattened into one.
            position (torch.Tensor): Position coordinates of shape
                (batch, position_size). A single position is broadcast
                across the batch.
            r_distance (torch.Tensor): Radial distance values with ``batch``
                elements, e.g. shape (batch,) or (batch, 1).

        Returns:
            torch.Tensor: Output logits of shape (batch, d_output).
        """
        batch_size = x.size(0)

        # Flatten while preserving the batch dimension. reshape() also
        # handles non-contiguous inputs (e.g. after permute), unlike view().
        if x.dim() > 2:
            x = x.reshape(batch_size, -1)

        # The encoding is detached, so no gradient flows back into `position`.
        pos_encoded = cos_sin_transformation(
            position, max_power=self.max_position_element
        )
        pos_encoded = pos_encoded.detach().to(dtype=x.dtype, device=x.device)
        if pos_encoded.size(0) != batch_size:
            pos_encoded = pos_encoded.expand(batch_size, -1)

        parts = [x, pos_encoded]

        if self.include_raw_coordinates:
            # The first layer was sized for the raw coordinates in __init__,
            # so they have to be part of the input vector here as well.
            raw = position.unsqueeze(0) if position.ndim == 1 else position
            raw = raw.detach().to(dtype=x.dtype, device=x.device)
            raw = raw.reshape(raw.size(0), -1)
            if raw.size(0) != batch_size:
                raw = raw.expand(batch_size, -1)
            parts.append(raw)

        r_tensor = r_distance.to(dtype=x.dtype, device=x.device)
        parts.append(r_tensor.reshape(batch_size, -1))

        init_net = torch.cat(parts, dim=-1)

        hidden = init_net
        last_index = len(self.pts_linears) - 1
        for i, layer in enumerate(self.pts_linears):
            hidden = F.relu(layer(hidden))
            # Re-inject the input only when a following layer was built to
            # receive it; output_linear always expects hidden_dim features.
            if i in self.skips and i < last_index:
                hidden = torch.cat([hidden, init_net], dim=-1)

        return self.output_linear(hidden)
class ClsLinear(nn.Module):
    """MLP head designed for Transformer [CLS] token representations.

    This module specifically extracts the class (CLS) token from the first
    index of the input sequence and combines it with physical metadata
    (positional encodings and radial distance) for final prediction.

    Args:
        d_output (int): Dimension of the final output.
        embedding_dim (int): Dimensionality of the tokens in the input sequence.
        max_position_element (int): Highest power for harmonic encoding. Defaults to 4.
        position_size (int): Number of raw coordinate variables. Defaults to 4.
        hidden_dim (int): Hidden width of the MLP; the second hidden layer
            uses ``hidden_dim // 2``. Defaults to 16.
        p_drop (float): Dropout probability. Defaults to 0.1.

    Attributes:
        d_output (int): Output dimensionality.
        embedding_dim (int): Dimension of the input tokens (D).
        hidden_dim (int): Hidden layer width.
        p_drop (float): Dropout rate.
        max_position_element (int): Complexity of the harmonic encoding.
        position_size (int): Number of coordinate inputs.
        network (nn.Sequential): MLP layers processing the combined CLS and
            metadata vector.
    """

    def __init__(
        self,
        d_output,
        embedding_dim,
        max_position_element=4,
        position_size=4,
        hidden_dim=16,
        p_drop=0.1,
    ):
        super().__init__()
        self.d_output = d_output
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.p_drop = p_drop
        self.max_position_element = max_position_element
        self.position_size = position_size

        # Input width = CLS embedding
        #   + 2 (cos, sin) * position_size * (max_position_element + 1) encodings
        #   + 1 scalar radial distance.
        r_distance_dim = 1
        pos_encoding_dim = 2 * position_size * (max_position_element + 1)
        total_input_dim = embedding_dim + pos_encoding_dim + r_distance_dim

        self.network = nn.Sequential(
            nn.Linear(total_input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=self.p_drop),
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(p=self.p_drop),
            nn.Linear(self.hidden_dim // 2, self.d_output),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialize linear layers: Xavier-uniform weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x, position, r_distance):
        """Perform a forward pass using the CLS token.

        Args:
            x (torch.Tensor): Input sequence with CLS token at index 0,
                shape (batch, num_tokens, embed_dim).
            position (torch.Tensor): Position coordinates of shape
                (batch, position_size). A single position is broadcast
                across the batch.
            r_distance (torch.Tensor): Radial distance values with ``batch``
                elements, e.g. shape (batch,) or (batch, 1).

        Returns:
            torch.Tensor: Output logits of shape (batch, d_output).
        """
        batch_size = x.size(0)
        cls_tokens = x[:, 0, :]  # (batch, embed_dim)

        # The encoding is detached, so no gradient flows back into `position`.
        pos_encoded = cos_sin_transformation(
            position, max_power=self.max_position_element
        )
        pos_encoded = pos_encoded.detach().to(dtype=x.dtype, device=x.device)

        # A single encoded position (leading dim 1) is shared by every sample.
        if pos_encoded.size(0) != batch_size:
            pos_encoded = pos_encoded.expand(batch_size, -1)

        r_tensor = r_distance.to(dtype=x.dtype, device=x.device)
        r_tensor = r_tensor.reshape(batch_size, -1)

        combined = torch.cat([cls_tokens, pos_encoded, r_tensor], dim=-1)
        return self.network(combined)