"""Classifier using Linear RNN models with support for token embeddings.
Reference: https://github.com/Efficient-Scalable-Machine-Learning/event-ssm
"""
from typing import List, Literal, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
from torch import Tensor
from lrnnx.architectures.embedding import TokenEmbedding
from lrnnx.models.lti.centaurus import Centaurus
from lrnnx.models.lti.lru import LRU
from lrnnx.models.lti.s5 import S5
def _get_mixer_class_from_string(mixer_type: str):
    """Resolve a mixer name to the LRNN class registered under it.

    Args:
        mixer_type: One of the registered names (``"LRU"``, ``"S5"``,
            ``"Centaurus"``). The lookup is case-sensitive.

    Returns:
        The model class registered under ``mixer_type``.

    Raises:
        ValueError: If ``mixer_type`` is not a registered name.
    """
    # Insertion order matters: it is the order shown in the error message.
    # "Mamba" is intentionally not registered here.
    known_mixers = dict(LRU=LRU, S5=S5, Centaurus=Centaurus)

    mixer_cls = known_mixers.get(mixer_type)
    if mixer_cls is None:
        raise ValueError(
            f"Unknown mixer type: {mixer_type}. Available: {list(known_mixers)}"
        )
    return mixer_cls
class SequencePooling(nn.Module):
    """
    Pooling layer for sequence data with support for variable lengths.

    The mode is selected by ``stride``:

    * ``stride > 1`` -- *intermediate* pooling. The sequence axis is shortened
      by a factor of ``stride`` using ``pooling_type`` in
      ``{"stride", "mean", "max"}``; timesteps and lengths are updated to
      match the shortened sequence.
    * ``stride <= 1`` -- *final* pooling. The sequence axis is collapsed to a
      single vector using ``pooling_type`` in ``{"last", "mean", "max"}``.
      ``pooling_type="stride"`` with ``stride == 1`` keeps every step and is
      therefore an identity operation.

    When ``lengths`` is given, padded steps never contribute to a pooled
    value, and outputs that have no valid step at all are set to zero.
    """

    def __init__(self, pooling_type="last", stride=1):
        """
        Initialize the pooling layer.

        Args:
            pooling_type (str): Pooling mode ("last", "mean", "max", "stride").
            stride (int): Stride for pooling (only used for intermediate pooling).
        """
        super().__init__()
        self.pooling_type = pooling_type
        self.stride = stride

    def forward(
        self,
        x: Tensor,
        lengths: Optional[Tensor] = None,
        integration_timesteps: Optional[Tensor] = None,
    ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """
        Pool sequences, either reducing length (intermediate) or to a single vector (final).

        Args:
            x (torch.Tensor): Input of shape ``(B, L, D)``.
            lengths (torch.Tensor, optional): Actual sequence lengths of shape ``(B,)``.
                Defaults to None.
            integration_timesteps (torch.Tensor, optional): Timesteps of shape ``(B, L)``.
                Defaults to None.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
            ``(pooled, integration_timesteps, lengths)``. For intermediate
            pooling ``pooled`` is ``(B, L', D)`` and timesteps / lengths are
            updated; for final pooling ``pooled`` is ``(B, D)`` and timesteps /
            lengths are passed through unchanged.

        Raises:
            ValueError: If ``pooling_type`` is not valid for the selected mode.
        """
        if self.stride > 1:
            return self._pool_intermediate(x, lengths, integration_timesteps)
        if self.pooling_type == "stride" and self.stride == 1:
            # Taking every step is a no-op; do not fall into final pooling.
            return x, integration_timesteps, lengths
        return self._pool_final(x, lengths), integration_timesteps, lengths

    @staticmethod
    def _valid_mask(lengths: Tensor, seq_len: int, device) -> Tensor:
        """Return a boolean ``(B, L)`` mask that is True on non-padded steps."""
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        return positions < lengths.unsqueeze(1)

    def _pool_intermediate(
        self,
        x: Tensor,
        lengths: Optional[Tensor],
        integration_timesteps: Optional[Tensor],
    ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Shorten the sequence axis by ``self.stride`` (requires ``stride > 1``)."""
        k = self.stride

        if self.pooling_type == "stride":
            # Plain subsampling: keep every k-th step and its timestep.
            x = x[:, ::k, :]
            if integration_timesteps is not None:
                integration_timesteps = integration_timesteps[:, ::k]
            if lengths is not None:
                lengths = torch.ceil(lengths.float() / k).long()
                lengths = torch.clamp(lengths, max=x.shape[1])
            return x, integration_timesteps, lengths

        if self.pooling_type not in ("mean", "max"):
            raise ValueError(
                f"Unknown intermediate pooling strategy: {self.pooling_type}"
            )
        use_mean = self.pooling_type == "mean"

        if lengths is None:
            pool_fn = (
                nn.functional.avg_pool1d
                if use_mean
                else nn.functional.max_pool1d
            )
            # (B, L, D) -> (B, D, L) -> pool -> (B, L', D)
            x = pool_fn(x.transpose(1, 2), kernel_size=k, stride=k).transpose(
                1, 2
            )
        else:
            mask = self._valid_mask(lengths, x.shape[1], x.device)  # (B, L)
            padded = ~mask.unsqueeze(2)  # (B, L, 1)
            # Number of valid steps inside each non-overlapping window.
            valid_per_window = mask.unfold(1, k, k).sum(dim=-1)  # (B, L')

            if use_mean:
                # Average over the valid steps only, so a partially padded
                # window is not dragged towards zero by its padding.
                window_sum = (
                    x.masked_fill(padded, 0).unfold(1, k, k).sum(dim=-1)
                )  # (B, L', D)
                divisor = valid_per_window.clamp(min=1).unsqueeze(2)
                x = window_sum / divisor.to(x.dtype)
            else:
                # Padding is replaced by the smallest representable value so it
                # can never win the max; the fill keeps the dtype of ``x``.
                filled = x.masked_fill(padded, torch.finfo(x.dtype).min)
                x = nn.functional.max_pool1d(
                    filled.transpose(1, 2), kernel_size=k, stride=k
                ).transpose(1, 2)
                # Windows made only of padding would otherwise expose the
                # sentinel to later layers; zero them instead.
                x = x.masked_fill((valid_per_window == 0).unsqueeze(2), 0)

            # New lengths: count windows with any valid token.
            lengths = (valid_per_window > 0).sum(dim=1)

        # Update timesteps by summing over pooling windows.
        if integration_timesteps is not None:
            integration_timesteps = integration_timesteps.unfold(1, k, k).sum(
                dim=2
            )
        return x, integration_timesteps, lengths

    def _pool_final(self, x: Tensor, lengths: Optional[Tensor]) -> Tensor:
        """Collapse ``(B, L, D)`` to ``(B, D)`` using ``self.pooling_type``."""
        batch, seq_len, _ = x.shape

        if self.pooling_type == "last":
            if lengths is None:
                return x[:, -1, :]
            # Use the actual last timestep of each variable-length sequence.
            rows = torch.arange(batch, device=x.device)
            last_indices = torch.clamp(lengths - 1, 0, seq_len - 1)
            return x[rows, last_indices, :]

        if self.pooling_type == "mean":
            if lengths is None:
                return x.mean(dim=1)
            mask = self._valid_mask(lengths, seq_len, x.device).unsqueeze(2)
            total = x.masked_fill(~mask, 0).sum(dim=1)  # (B, D)
            # Divide by the number of steps actually summed (guards against
            # lengths of 0 or lengths larger than the padded sequence).
            count = mask.sum(dim=1).clamp(min=1)  # (B, 1)
            return total / count.to(x.dtype)

        if self.pooling_type == "max":
            if lengths is None:
                return x.max(dim=1)[0]
            mask = self._valid_mask(lengths, seq_len, x.device)  # (B, L)
            filled = x.masked_fill(
                ~mask.unsqueeze(2), torch.finfo(x.dtype).min
            )
            pooled = filled.max(dim=1)[0]
            # Sequences with no valid step yield zeros rather than the sentinel.
            is_empty = ~mask.any(dim=1)
            return pooled.masked_fill(is_empty.unsqueeze(1), 0)

        raise ValueError(f"Unknown pooling strategy: {self.pooling_type}")
class ClassifierBlock(nn.Module):
    """
    A single processing block in the Classifier.

    Each block contains:
    - LRNN layer for temporal processing (instantiated from lrnn_cls)
    - Dropout for regularization
    - Residual connection
    - Layer normalization
    - Optional intermediate pooling for sequence length reduction
    - For the final block only: pooling to a single vector and an output head
    """

    def __init__(
        self,
        d_model,
        d_state,
        lrnn_cls: Type[nn.Module],
        num_classes: int = 0,
        output_dim: int = 1,
        pooling: Literal["mean", "last", "max"] = "last",
        dropout: float = 0.1,
        intermediate_pooling: Literal[
            "none", "stride", "mean", "max"
        ] = "none",
        pooling_factor: int = 2,
        is_final: bool = False,
        **lrnn_params,
    ):
        """
        Initialize a processing block used inside the classifier.

        The block performs sequence processing and when is_final=True
        produces a single output vector. Set num_classes > 0 to enable
        classification (the block returns logits over num_classes); otherwise
        the block produces regression outputs of shape output_dim.

        Args:
            d_model (int): Feature dimension normalized by the LayerNorm and
                consumed by the output head.
            d_state (int): Accepted for interface symmetry; not used by the
                block itself (pass it through ``lrnn_params`` if the LRNN
                class needs it).
            lrnn_cls (type): LRNN class, instantiated as ``lrnn_cls(**lrnn_params)``.
            num_classes (int, optional): Number of classes; ``> 0`` selects a
                classification head. Defaults to 0.
            output_dim (int, optional): Size of the regression head used when
                ``num_classes <= 0``. Defaults to 1.
            pooling (str, optional): Final pooling strategy. Defaults to ``"last"``.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
            intermediate_pooling (str, optional): Intermediate pooling strategy.
                Defaults to ``"none"``.
            pooling_factor (int, optional): Sequence length reduction factor for
                intermediate pooling. A factor of 1 disables intermediate
                pooling. Defaults to 2.
            is_final (bool, optional): Whether this block ends the stack and
                owns the output head. Defaults to False.
            **lrnn_params: Constructor arguments forwarded to ``lrnn_cls``.

        Raises:
            TypeError: If ``lrnn_cls`` cannot be built from ``lrnn_params``.
            ValueError: If intermediate pooling is requested with
                ``pooling_factor < 1``.
        """
        super().__init__()
        # Instantiate LRNN layer directly from lrnn_params.
        # The user must provide all required constructor arguments in lrnn_params.
        # Examples:
        # - LRU: lrnn_params={"d_model": d_model, "d_state": d_state}
        # - S5: lrnn_params={"d_model": d_model, "d_state": d_state, "discretization": "zoh"}
        # - Centaurus: lrnn_params={"d_model": d_model, "d_state": d_state, "sub_state_dim": d_state}
        # - Mamba: lrnn_params={"d_model": d_model, "d_state": d_state}
        try:
            self.lrnn = lrnn_cls(**lrnn_params)
        except TypeError as e:
            # Chain the original error so the failing constructor stays visible
            # in the traceback.
            raise TypeError(
                f"Could not instantiate {getattr(lrnn_cls, '__name__', str(lrnn_cls))} "
                f"with provided lrnn_params: {lrnn_params}. "
                f"Ensure you pass all required constructor arguments for the LRNN class. "
                f"Error: {e}"
            ) from e

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.intermediate_pooling = intermediate_pooling
        self.pooling_factor = pooling_factor
        self.is_final = is_final

        # Intermediate pooler. SequencePooling only reduces the sequence length
        # when stride > 1; with stride 1 it would perform *final* pooling and
        # collapse the sequence, so a factor of 1 is treated as "no pooling".
        self.pooler: Optional[SequencePooling] = None
        if intermediate_pooling != "none":
            if pooling_factor < 1:
                raise ValueError(
                    f"pooling_factor must be >= 1, got {pooling_factor}"
                )
            if pooling_factor > 1:
                self.pooler = SequencePooling(
                    pooling_type=intermediate_pooling, stride=pooling_factor
                )

        # Final pooling and output head for the last block.
        self.final_pooler: Optional[SequencePooling] = None
        if is_final:
            self.final_pooler = SequencePooling(pooling_type=pooling)
            head_dim = num_classes if num_classes > 0 else output_dim
            self.output_proj = nn.Linear(d_model, head_dim)

    def forward(
        self,
        x: Tensor,
        integration_timesteps: Optional[Tensor] = None,
        lengths: Optional[Tensor] = None,
    ) -> Union[Tensor, Tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
        """
        Forward pass through a single classifier block.

        Args:
            x (torch.Tensor): Input of shape ``(B, L, D)``.
            integration_timesteps (torch.Tensor, optional): Timesteps for LTV models.
                Defaults to None.
            lengths (torch.Tensor, optional): Actual sequence lengths.
                Defaults to None.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
            Final block returns the output head applied to the pooled vector
            (``(B, num_classes)`` or ``(B, output_dim)``);
            non-final blocks return ``(x, integration_timesteps, lengths)``.
        """
        # Standard block processing: LRNN -> dropout -> residual -> norm.
        residual = x
        x = self.lrnn(x, integration_timesteps, lengths)
        x = self.dropout(x)
        x = self.norm(x + residual)

        # Apply intermediate pooling if one was configured.
        if self.pooler is not None:
            x, integration_timesteps, lengths = self.pooler(
                x, lengths, integration_timesteps
            )

        if not self.is_final:
            return x, integration_timesteps, lengths

        # final_pooler is always created together with is_final=True.
        if self.final_pooler is None:
            raise RuntimeError(
                "ClassifierBlock marked as final has no final pooler"
            )
        pooled, _, _ = self.final_pooler(x, lengths, integration_timesteps)
        return self.output_proj(pooled)
class Classifier(nn.Module):
    """
    Sequence classifier or regressor built from a stack of LRNN blocks.

    The input is either embedded (token IDs, when ``vocab_size`` is given) or
    linearly projected (continuous features) to ``d_model``, passed through
    ``n_layers`` :class:`ClassifierBlock` instances, pooled over the sequence
    axis by the last block and mapped to ``num_classes`` logits (when
    ``num_classes > 0``) or ``output_dim`` regression values.

    Args:
        input_dim (int): Number of input features.
        num_classes (int): Number of output classes.
        d_model (int): Hidden dimension of the model.
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int = 0,
        output_dim: int = 1,
        d_model: int = 128,
        d_state: int = 64,
        n_layers: int = 4,
        lrnn_cls: Union[Type[nn.Module], List[Type[nn.Module]]] = LRU,
        pooling: Literal["mean", "last", "max"] = "last",
        dropout: float = 0.1,
        intermediate_pooling: Union[
            Literal["none", "stride", "mean", "max"],
            List[Literal["none", "stride", "mean", "max"]],
        ] = "none",
        pooling_factor: Union[int, List[int]] = 2,
        vocab_size: Optional[int] = None,
        embedding_dim: Optional[int] = None,
        max_position_embeddings: Optional[int] = None,
        padding_idx: Optional[int] = 0,
        lrnn_params: Optional[dict] = None,
    ):
        """
        Initializes the Classifier.

        Args:
            input_dim (int): Number of input features (ignored when vocab_size is provided).
            num_classes (int, optional): Number of output classes. Defaults to 0.
            output_dim (int, optional): Number of regression outputs. Defaults to 1.
            d_model (int, optional): Hidden dimension of the model. Defaults to 128.
            d_state (int, optional): State dimension for the LRNN layers. Defaults to 64.
            n_layers (int, optional): Number of LRNN layers. Defaults to 4.
            lrnn_cls (type | str | list, optional): LRNN class (or registered
                name), or a list with one entry per layer. Defaults to LRU.
            pooling (str, optional): Pooling strategy for sequence outputs. Defaults to ``"last"``.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
            intermediate_pooling (str | list[str], optional): Pooling strategy for
                each layer. Defaults to ``"none"``.
            pooling_factor (int | list[int], optional): Factor by which to reduce
                sequence length. Defaults to 2.
            vocab_size (int, optional): Size of vocabulary for token embeddings. Defaults to None.
            embedding_dim (int, optional): Dimension of embeddings (defaults to d_model). Defaults to None.
            max_position_embeddings (int, optional): Max sequence length for positional
                embeddings. Defaults to None.
            padding_idx (int, optional): Index of padding token for embedding layer. Defaults to 0.
            lrnn_params (dict, optional): Constructor arguments forwarded to every
                LRNN module. Defaults to None.

        Raises:
            ValueError: If ``n_layers < 1``, if a per-layer list does not have
                exactly ``n_layers`` entries, or if a mixer name is unknown.
        """
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be at least 1, got {n_layers}")

        self.d_model = d_model
        self.pooling = pooling

        # Determine if using token embeddings
        if vocab_size is not None:
            # Create token embedding layer
            emb_dim = embedding_dim if embedding_dim is not None else d_model
            self.embedding = TokenEmbedding(
                vocab_size=vocab_size,
                embedding_dim=emb_dim,
                padding_idx=padding_idx,
                max_position_embeddings=max_position_embeddings,
                use_position=False,  # set True if you want learned positional embeddings
                dropout=dropout,
            )
            self.has_embedding = True
            # Project embeddings to model dimension if needed
            self.embed_proj: nn.Module = (
                nn.Linear(emb_dim, d_model)
                if emb_dim != d_model
                else nn.Identity()
            )
        else:
            # For raw features, use standard projection
            self.has_embedding = False
            self.input_proj = nn.Linear(input_dim, d_model)

        # Expand every per-layer option to exactly one entry per layer, so the
        # final block is always paired with the last entry of a full-length list.
        pooling_types = self._per_layer(
            intermediate_pooling,
            isinstance(intermediate_pooling, str),
            n_layers,
            "intermediate_pooling",
        )
        pooling_factors = self._per_layer(
            pooling_factor,
            isinstance(pooling_factor, int),
            n_layers,
            "pooling_factor",
        )
        mixers = self._per_layer(
            lrnn_cls,
            not isinstance(lrnn_cls, (list, tuple)),
            n_layers,
            "lrnn_cls",
        )
        # Convert any strings to classes using the helper
        mixers = [
            _get_mixer_class_from_string(item)
            if isinstance(item, str)
            else item
            for item in mixers
        ]

        lrnn_params = lrnn_params or {}

        # Stack of blocks (all but last have no output head)
        self.blocks = nn.ModuleList(
            ClassifierBlock(
                d_model=d_model,
                d_state=d_state,
                num_classes=0,
                output_dim=0,
                pooling=pooling,
                lrnn_cls=mixers[i],
                dropout=dropout,
                intermediate_pooling=pooling_types[i],
                pooling_factor=pooling_factors[i],
                is_final=False,
                **lrnn_params,
            )
            for i in range(n_layers - 1)
        )

        # Last block has output head
        self.final_block = ClassifierBlock(
            d_model=d_model,
            d_state=d_state,
            num_classes=num_classes,
            output_dim=output_dim,
            pooling=pooling,
            lrnn_cls=mixers[-1],
            dropout=dropout,
            intermediate_pooling=pooling_types[-1],
            pooling_factor=pooling_factors[-1],
            is_final=True,
            **lrnn_params,
        )

    @staticmethod
    def _per_layer(value, is_scalar: bool, n_layers: int, name: str) -> list:
        """Broadcast a scalar option to all layers, or validate a per-layer list."""
        values = [value] * n_layers if is_scalar else list(value)
        if len(values) != n_layers:
            raise ValueError(
                f"{name} list length ({len(values)}) must match n_layers ({n_layers})"
            )
        return values

    def forward(
        self,
        x: Tensor,
        lengths: Optional[Tensor] = None,
        integration_timesteps: Optional[Tensor] = None,
    ) -> Union[Tensor, Tuple[Tensor, Optional[Tensor], Optional[Tensor]]]:
        """
        Forward pass of the classifier/regressor.

        Args:
            x (torch.Tensor): Input tensor. Token IDs of shape ``(B, L)`` when using embeddings,
                or continuous features of shape ``(B, L, input_dim)`` otherwise.
            lengths (torch.Tensor, optional): Actual sequence lengths of shape ``(B,)``.
                Defaults to None.
            integration_timesteps (torch.Tensor, optional): Timesteps of shape ``(B, L)``
                for LTV models. Defaults to None.

        Returns:
            torch.Tensor: Logits of shape ``(B, num_classes)`` or regression values
            of shape ``(B, output_dim)``.
        """
        if self.has_embedding:
            # Token ids (B, L) are embedded first; inputs that already carry a
            # feature axis (B, L, D) go straight to the projection.
            if x.dim() == 2:
                x = self.embedding(x)
            x = self.embed_proj(x)
        else:
            # Raw continuous inputs -> project from input_dim -> d_model
            x = self.input_proj(x)

        # Non-final blocks return '(x, integration_timesteps, lengths)'. Note the
        # blocks take (x, integration_timesteps, lengths), a different argument
        # order than this method.
        for block in self.blocks:
            x, integration_timesteps, lengths = block(
                x, integration_timesteps, lengths
            )

        # Final block pools over the sequence and applies the output head.
        return self.final_block(x, integration_timesteps, lengths)