# Source code for lrnnx.models.lti.s5

"""
Basic S5 SSM.
Reference: https://openreview.net/forum?id=Ai8Hw3AXqks
"""

import math
from typing import Any, Dict, Literal, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from lrnnx.core.convolution import opt_ssm_forward
from lrnnx.models.lti.base import LTI_LRNN


class S5(LTI_LRNN):
    """
    Basic S5 State Space Model.

    A single multi-input, multi-output SSM with a diagonal complex state
    matrix ``A`` of size ``N`` is shared across all ``H = d_model`` channels
    (``B`` is ``(N, H)``, ``C`` is ``(H, N)``).

    Reference: https://openreview.net/forum?id=Ai8Hw3AXqks

    Example:
        >>> model = S5(d_model=64, d_state=64, discretization="zoh")
        >>> x = torch.randn(2, 128, 64)
        >>> y = model(x)
        >>> y.shape
        torch.Size([2, 128, 64])
    """

    def __init__(
        self,
        d_model: int,
        d_state: int,  # P in the paper: the actual state dimension of the system.
        discretization: Literal[
            "zoh",
            "bilinear",
            "dirac",
            "no_discretization",  # discretization methods available to the user.
        ],
        conj_sym: bool = False,  # conjugate symmetry; not implemented yet.
    ):
        """
        Initialize S5 model.

        Args:
            d_model (int): Model dimension (H).
            d_state (int): State dimension (P in the original paper).
            discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]):
                Discretization method to use; forwarded to the base class.
            conj_sym (bool, optional): If True, uses conjugate symmetry for the
                state space model. Not implemented yet. Defaults to False.

        Raises:
            NotImplementedError: If ``conj_sym`` is True.
        """
        super().__init__(discretization=discretization)
        self.d_model = d_model
        self.d_state = d_state
        # Aliases kept in case existing code/tests read these names.
        self.hid_dim = d_model
        self.state_dim = d_state
        if conj_sym:
            raise NotImplementedError(
                "Conjugate symmetry is not implemented yet."
            )
        self.conj_sym = conj_sym
        self._init_parameters()

    def _init_parameters(self):
        """
        Initializes the parameters of the S5 model.

        Creates:
            A: ``(N, 2)`` — column 0 is the inverse softplus of ``-Re(A)``,
                column 1 is ``Im(A)`` of the diagonal continuous-time A.
            B: ``(N, H)`` input matrix.
            log_dt: ``(N,)`` log time steps, log-spaced over [1e-3, 1e-1].
            C: ``(H, N, 2)`` output matrix stored as (real, imag) pairs.
            D: ``(H, H)`` feedthrough matrix, applied to the input as ``x @ D``.
        """

        def init_parameter(mat):
            return Parameter(torch.tensor(mat, dtype=torch.float))

        def normal_parameter(fan_in, shape):
            # Gaussian init with std = sqrt(2 / fan_in).
            return Parameter(torch.randn(*shape) * math.sqrt(2 / fan_in))

        H = self.d_model  # hidden dimension (input)
        N = self.d_state  # state dimension (inner dimension for the SSM)

        # Real and imaginary parts of the continuous-time diagonal A.
        A_real = 0.5 * np.ones(N)
        A_imag = math.pi * np.arange(N)
        # The real part is stored through an inverse softplus; ``discretize``
        # maps it back with -softplus(.), which keeps Re(A) strictly negative.
        log_A_real = np.log(np.exp(A_real) - 1)
        # Stack into complex-diagonal format (Re, Im): shape (N, 2).
        A = np.stack([log_A_real, A_imag], axis=-1)

        # Log-spaced time scales, one per state.
        log_dt = np.linspace(np.log(0.001), np.log(0.1), N)

        B = np.ones((N, H)) / math.sqrt(H)  # shape: (N, H)

        self.A = init_parameter(A)
        self.B = init_parameter(B)
        self.log_dt = init_parameter(log_dt)
        self.C = normal_parameter(N, (H, N, 2))
        self.D = normal_parameter(H, (H, H))  # feedthrough, used as x @ D

    def discretize(
        self,
    ) -> tuple[torch.Tensor, Union[torch.Tensor, float], torch.Tensor]:
        """
        Discretizes the continuous-time system using the configured method.

        Returns:
            tuple[torch.Tensor, Union[torch.Tensor, float], torch.Tensor]:
                - A_bar : Discretized system matrix A, shape ``(N,)``.
                - gamma_bar : Input normalizer, shape ``(N,)`` or a scalar.
                - C_complex : Complex output matrix C, shape ``(H, N)``.
        """
        log_A_real, A_imag = self.A.T  # each (N,)
        # dt is learned in log-space, which keeps it positive and lets the
        # model cover a wide range of temporal scales.
        dt = self.log_dt.exp()

        # Continuous-time complex A with strictly negative real part.
        A_complex = -F.softplus(log_A_real) + 1j * A_imag  # shape: (N,)

        # LTI model: no per-step integration timesteps.
        A_bar, gamma_bar = self.discretize_fn(A_complex, dt, None)

        C_complex = self.C[..., 0] + 1j * self.C[..., 1]  # (H, N)
        return A_bar, gamma_bar, C_complex

    def _compute_B_bar(
        self, gamma_bar: Union[Tensor, float]
    ) -> Tensor:
        """
        Scales the rows of ``B`` by the input normalizer ``gamma_bar``.

        Args:
            gamma_bar (Union[torch.Tensor, float]): A Python scalar, a 0-d
                tensor, or a 1-D tensor of shape ``(N,)``.

        Returns:
            torch.Tensor: ``B_bar`` of shape ``(N, H)``.
        """
        # Scalars (Python numbers or 0-d tensors) broadcast over all of B.
        if isinstance(gamma_bar, (int, float, complex)) or (
            torch.is_tensor(gamma_bar) and gamma_bar.dim() == 0
        ):
            return gamma_bar * self.B  # (N, H)
        assert (
            gamma_bar.dim() == 1
        ), f"gamma_bar should be 1D tensor, got {gamma_bar.dim()}D tensor"
        # Per-state normalizer: scale row n of B by gamma_bar[n].
        return gamma_bar.unsqueeze(-1) * self.B

    def compute_kernel(
        self,
        L: int,
        A_bar: Tensor,
        gamma_bar: Union[Tensor, float],
    ):
        """
        Computes the kernel matrices for the S5 model: A^t and B_bar.

        Args:
            L (int): Length of the input sequence.
            A_bar (torch.Tensor): Discretized system matrix A, shape ``(N,)``.
            gamma_bar (Union[torch.Tensor, float]): Input normalizer, shape
                ``(N,)`` or a scalar.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - A_power : ``A_bar ** t`` for ``t = 0..L-1``, shape ``(N, L)``.
                - B_bar : Normalized input projection matrix, shape ``(N, H)``.
        """
        B_bar = self._compute_B_bar(gamma_bar)

        lrange = torch.arange(L, device=A_bar.device)
        A_power = A_bar[:, None] ** lrange[None, :]  # (N, L)
        return A_power, B_bar

    def forward(
        self,
        x: Tensor,
        integration_timesteps: Optional[Tensor] = None,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Forward pass of the S5 SSM via ``opt_ssm_forward`` plus the
        feedthrough term ``x @ D``.

        Args:
            x (torch.Tensor): Input tensor of shape ``(B, L, H)``.
            integration_timesteps (torch.Tensor, optional): Not used by S5
                (LTI model). Kept for interface compatibility with LTV
                models. Defaults to None.
            lengths (torch.Tensor, optional): Lengths of the input sequences,
                shape ``(B,)``. Currently unused. TODO: Support bidirectional
                models. Defaults to None.

        Returns:
            torch.Tensor: Output tensor of shape ``(B, L, H)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Input tensor must be of shape (B, L, H), got {x.dim()}D tensor with shape {x.shape}"
            )

        L = x.shape[1]
        A_bar, gamma_bar, C_complex = self.discretize()
        K, B_bar = self.compute_kernel(L, A_bar, gamma_bar)
        return opt_ssm_forward(x, K, B_bar, C_complex) + x @ self.D

    def step(
        self,
        x: torch.Tensor,
        inference_cache: Dict[str, Any],
        **kwargs,
    ) -> tuple[torch.Tensor, Dict[str, Any]]:
        """
        Performs a single recurrent step of the S5 model.

        Args:
            x (torch.Tensor): Input at current time step, shape ``(B, H)``.
            inference_cache (Dict[str, Any]): Cache from
                ``allocate_inference_cache()`` containing "lrnn_state" and
                pre-computed matrices. "lrnn_state" is updated in place.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            tuple[torch.Tensor, Dict[str, Any]]: Output y_t of shape
            ``(B, H)`` and the (same, updated) cache dictionary.

        Raises:
            ValueError: If ``x`` is not 2-dimensional.
        """
        if x.dim() != 2:
            raise ValueError(
                f"Input tensor must be of shape (B, H), got {x.dim()}D tensor with shape {x.shape}"
            )

        state = inference_cache["lrnn_state"]
        A_bar = inference_cache["A_bar"]
        B_bar = inference_cache["B_bar"]
        C_complex = inference_cache["C_complex"]

        # Recurrent update: s_t = A_bar * s_{t-1} + B_bar @ u_t
        input_projection = torch.einsum(
            "bh,nh->bn", x.to(B_bar.dtype), B_bar
        )
        new_state = A_bar * state + input_projection

        # Output uses the freshly updated state: y_t = Re(C @ s_t) + u_t @ D
        state_output = torch.einsum("hn,bn->bh", C_complex, new_state)
        y = state_output.real + x @ self.D  # (B, H)

        # Write the new state back into the cached tensor in place.
        inference_cache["lrnn_state"].copy_(new_state)
        return y, inference_cache

    def allocate_inference_cache(
        self,
        batch_size: int,
        max_seqlen: int = 1,
        dtype=None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Allocates cache for inference.

        Args:
            batch_size (int): The batch size for the input data.
            max_seqlen (int, optional): Maximum sequence length (unused, kept
                for interface consistency with LTV models). Defaults to 1.
            dtype (torch.dtype, optional): Data type for allocated tensors
                (unused). Defaults to None.
            **kwargs: Additional model-specific arguments (unused).

        Returns:
            Dict[str, Any]: Cache dict with a zero "lrnn_state" of shape
            ``(batch_size, d_state)`` and the pre-computed discrete matrices
            "A_bar", "B_bar" and "C_complex".
        """
        device = self.A.device

        # LTI model: the discrete matrices are computed once and reused.
        A_bar, gamma_bar, C_complex = self.discretize()

        # Match the state's complex precision to A_bar so that ``step``'s
        # in-place ``copy_`` never silently downcasts (e.g. complex128 models).
        state_dtype = (
            A_bar.dtype
            if torch.is_tensor(A_bar) and A_bar.is_complex()
            else torch.complex64
        )

        B_bar = self._compute_B_bar(gamma_bar)
        if not B_bar.is_complex():
            # Keep the recurrent update entirely in complex arithmetic.
            B_bar = B_bar.to(state_dtype)

        initial_state = torch.zeros(
            batch_size, self.d_state, dtype=state_dtype, device=device
        )

        return {
            "lrnn_state": initial_state,
            "A_bar": A_bar,
            "B_bar": B_bar,
            "C_complex": C_complex,
        }