Source code for lrnnx.models.ltv.s5

"""
S5 SSM with CUDA kernel acceleration.
Reference: https://openreview.net/forum?id=Ai8Hw3AXqks
"""

from typing import Any, Dict, Literal, Optional, Tuple

import numpy as np
import torch
from torch import Tensor
from torch.nn import Parameter

from lrnnx.models.ltv.base import LTV_LRNN
from lrnnx.ops.simplified_scan import s5_inner_fn, simplified_scan_fn
from lrnnx.utils.init import init_CV, init_log_steps, init_VinvB, make_DPLR_HiPPO

try:
    from lrnnx.ops.triton.simplified_state_update import simplified_state_update
except ImportError:
    simplified_state_update = None  # type: ignore[assignment]


[docs] class S5(LTV_LRNN): """ S5 SSM with CUDA kernel acceleration. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks Example: >>> model = S5(d_model=64, d_state=64) >>> x = torch.randn(2, 128, 64) >>> y = model(x) >>> y.shape torch.Size([2, 128, 64]) """
[docs] def __init__( self, d_model: int, d_state: int, discretization: Literal["bilinear", "zoh", "dirac"] = "zoh", conj_sym: bool = False, dt_min: float = 0.001, dt_max: float = 0.1, step_rescale: float = 1.0, use_fast_path: bool = True, device=None, dtype=None, ): """ Initialize S5 model. Args: d_model (int): Model dimension. d_state (int): State dimension. discretization (Literal["bilinear", "zoh", "dirac"], optional): Discretization method. Defaults to ``"zoh"``. conj_sym (bool, optional): If True, uses conjugate symmetry for the state space model. Defaults to False. dt_min (float, optional): Minimum value for dt initialization. Defaults to 0.001. dt_max (float, optional): Maximum value for dt initialization. Defaults to 0.1. step_rescale (float, optional): Rescale factor for step size. Defaults to 1.0. use_fast_path (bool, optional): Whether to use fused CUDA kernels. Defaults to True. device (torch.device, optional): Device for parameters. Defaults to None. dtype (torch.dtype, optional): Data type for parameters. Defaults to None. """ super().__init__(discretization=discretization) self.d_model = d_model self.d_state = d_state self.conj_sym = conj_sym self.use_fast_path = use_fast_path self.discretization = discretization self.step_rescale = step_rescale H, P = d_model, d_state if conj_sym: Lambda, _, _, V, _ = make_DPLR_HiPPO(2 * P) Lambda = Lambda[:P] V = V[:, :P] Vinv = V.conj().T local_P = 2 * P else: Lambda, _, _, V, _ = make_DPLR_HiPPO(P) Vinv = np.linalg.inv(V) local_P = P A = np.stack([Lambda.real, Lambda.imag], axis=-1) self.A = Parameter(torch.tensor(A, dtype=torch.float32, device=device)) self.B = Parameter(init_VinvB(Vinv, local_P, H).to(device=device)) self.C = Parameter(init_CV(V, local_P, H).to(device=device)) self.D = Parameter(torch.randn(H, device=device) * 0.1) self.log_dt = Parameter( init_log_steps(P, dt_min, dt_max).to(device=device) )
[docs] def forward( self, x: Tensor, integration_timesteps: Optional[Tensor] = None, lengths: Optional[Tensor] = None, inference_cache: Optional[Dict[str, Any]] = None, ) -> Tensor: """ Forward pass through S5. Args: x (torch.Tensor): Input tensor of shape ``(B, L, H)``. integration_timesteps (torch.Tensor, optional): Timesteps for async/event-driven discretization. Defaults to None. lengths (torch.Tensor, optional): Lengths of sequences, required for variable-length sequences. Defaults to None. inference_cache (Dict[str, Any], optional): Cache for autoregressive generation. Defaults to None. Returns: torch.Tensor: Output tensor of shape ``(B, L, H)``. """ if x.dim() != 3: raise ValueError(f"Expected 3D input (B, L, H), got {x.dim()}D") batch, seqlen, _ = x.shape P = self.d_state if inference_cache is not None: seqlen_offset = inference_cache.get("seqlen_offset", 0) if seqlen_offset > 0: out, inference_cache = self.step(x, inference_cache) return out A = torch.complex(self.A[..., 0], self.A[..., 1]) B_tilde = torch.complex(self.B[..., 0], self.B[..., 1]) C_tilde = torch.complex(self.C[..., 0], self.C[..., 1]) dt = self.log_dt.exp() * self.step_rescale x_t = x.transpose(1, 2) u = x_t.to(torch.complex64) delta = ( dt.unsqueeze(0).unsqueeze(-1).expand(batch, P, seqlen).contiguous() ) deltaA = None if integration_timesteps is not None: # Async/event mode: use separate deltaA for A discretization # integration_timesteps: (B, L) or (B, L, 1) if integration_timesteps.dim() == 3: integration_timesteps = integration_timesteps.squeeze(-1) deltaA = ( integration_timesteps.to(delta.dtype) .unsqueeze(1) .expand(batch, P, seqlen) * dt.unsqueeze(0).unsqueeze(-1) ).contiguous() if self.use_fast_path and x.is_cuda: y = s5_inner_fn( u, delta, A, B_tilde, C_tilde, self.D.float(), deltaA=deltaA, discretization=self.discretization, conj_sym=self.conj_sym, ) else: # Slow path: use simplified_scan_fn + manual post-processing y_complex = simplified_scan_fn( u, delta, A, B_tilde, C_tilde, deltaA=deltaA, discretization=self.discretization, return_last_state=False, ) # Apply conjugate symmetry if self.conj_sym: y = 2.0 * y_complex.real # type: ignore[union-attr] else: y = y_complex.real # type: ignore[union-attr] # Skip connection: D * real(u) u_real = u.real if u.is_complex() else u y = y + self.D.float().unsqueeze(0).unsqueeze(-1) * u_real return y.transpose(1, 2).to(x.dtype)
[docs] def step( # type: ignore[override] self, x: Tensor, inference_cache: Dict[str, Any], integration_timesteps: Optional[Tensor] = None, ) -> Tuple[Tensor, Dict[str, Any]]: """ Performs a single recurrent step of S5. When the simplified_state_update Triton kernel is available and the tensors live on CUDA, the state is updated **in-place** via the kernel (which also fuses discretization, input projection, and output projection into a single launch). Otherwise a pure-PyTorch fallback is used. Args: x (torch.Tensor): Input at current timestep, shape ``(B, 1, H)`` or ``(B, H)``. inference_cache (Dict[str, Any]): Cache dictionary containing SSM state and continuous-time parameters. integration_timesteps (torch.Tensor, optional): Optional per-step integration timesteps for event/async mode, shape ``(B,)`` or ``(B, 1)``. Defaults to None. Returns: tuple[torch.Tensor, Dict[str, Any]]: A tuple containing: - y : Output tensor at the current timestep. - inference_cache : Updated cache dictionary. """ if x.dim() == 3: x = x.squeeze(1) state = inference_cache[ "ssm_state" ] # (B, P) complex64, mutated in-place A = inference_cache["A"] # (P,) complex64 B_tilde = inference_cache["B_tilde"] # (P, H) complex64 C_tilde = inference_cache["C_tilde"] # (H, P) complex64 dt = inference_cache["dt"] # (P,) float32 # Optional per-step deltaA for event/async mode deltaA = None if integration_timesteps is not None: if integration_timesteps.dim() > 1: timestep = integration_timesteps.view(-1, 1) else: timestep = integration_timesteps.unsqueeze(-1) deltaA = timestep.to(dt.dtype) * dt.unsqueeze(0) # (B, P) if simplified_state_update is not None and x.is_cuda: # Triton fast-path (in-place state update) y = simplified_state_update( state, x.float(), dt, A, B_tilde, C_tilde, D=None, # D is 1-D here; handled below deltaA=deltaA, discretization=self.discretization, conj_sym=self.conj_sym, ) else: dtA = ( deltaA if deltaA is not None else dt.unsqueeze(0).expand(x.shape[0], -1) ) if self.discretization == "bilinear": half_dtA_A = 0.5 * dtA * A A_bar = (1.0 + half_dtA_A) / (1.0 - half_dtA_A) gamma_bar = dt / (1.0 - 0.5 * dt * A) elif self.discretization == "zoh": A_bar = torch.exp(dtA * A) gamma_bar = (torch.exp(dt * A) - 1.0) / A elif self.discretization == "dirac": A_bar = torch.exp(dtA * A) gamma_bar = torch.ones_like(dt, dtype=A.dtype) else: raise ValueError( f"Unknown discretization: {self.discretization}" ) Bu = torch.einsum("ph,bh->bp", B_tilde, x.to(B_tilde.dtype)) state.copy_(A_bar * state + gamma_bar * Bu) y_complex = torch.einsum("hp,bp->bh", C_tilde, state) y = 2.0 * y_complex.real if self.conj_sym else y_complex.real # Skip connection y = y + self.D * x inference_cache["seqlen_offset"] += 1 return y.unsqueeze(1).to(x.dtype), inference_cache
[docs] def allocate_inference_cache( self, batch_size: int, max_seqlen: int, dtype: Optional[torch.dtype] = None, **kwargs, ) -> Dict[str, Any]: """ Allocates cache for S5 autoregressive inference. Stores the **continuous-time** parameters so that simplified_state_update can fuse discretization into the kernel. Args: batch_size (int): The batch size for inference. max_seqlen (int): Maximum sequence length (unused, for interface consistency). dtype (torch.dtype, optional): Data type for allocated tensors. Defaults to None. **kwargs: Additional keyword arguments. Returns: Dict[str, Any]: Cache dictionary containing SSM state and continuous-time matrices. """ device = self.A.device P = self.d_state A = torch.complex(self.A[..., 0], self.A[..., 1]) # (P,) B_tilde = torch.complex(self.B[..., 0], self.B[..., 1]) # (P, H) C_tilde = torch.complex(self.C[..., 0], self.C[..., 1]) # (H, P) dt = self.log_dt.exp() * self.step_rescale # (P,) return { "ssm_state": torch.zeros( batch_size, P, dtype=torch.complex64, device=device ), "A": A, "B_tilde": B_tilde, "C_tilde": C_tilde, "dt": dt, "seqlen_offset": 0, }