lrnnx.models.ltv.s5 module
S5 SSM with CUDA kernel acceleration. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks
- class S5
Bases: LTV_LRNN
S5 SSM with CUDA kernel acceleration. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks
Example
>>> import torch
>>> from lrnnx.models.ltv.s5 import S5
>>> model = S5(d_model=64, d_state=64)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
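To make explicit what the layer computes, the NumPy sketch below writes the S5 recurrence as a plain sequential loop, following the formulation in the referenced S5 paper: a diagonal complex state matrix Lambda, zero-order-hold (ZOH) discretization, state update x_k = Lambda_bar x_{k-1} + B_bar u_k, and output y_k = Re(C x_k) + D u_k. It is not the lrnnx implementation. The function name s5_reference_scan, the argument layout, and the per-state log step size log_dt are assumptions made for illustration, and the conj_sym branch (store one state per conjugate pair and double the real part of the output) mirrors the S5 paper's reference code rather than confirmed behavior of this class.

```python
import numpy as np


def s5_reference_scan(u, Lambda, B, C, D, log_dt, conj_sym=False):
    """Sequential (non-fused) S5 recurrence with ZOH discretization (sketch).

    u:      real input, shape (batch, L, H)
    Lambda: complex diagonal of the state matrix, shape (P,)
    B:      complex input matrix, shape (P, H)
    C:      complex output matrix, shape (H, P)
    D:      real feedthrough, shape (H,)
    log_dt: log step size per state, shape (P,)  (assumed parameterization)
    """
    dt = np.exp(log_dt)
    Lambda_bar = np.exp(Lambda * dt)                     # ZOH: exp(Lambda * dt)
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B   # ZOH: Lambda^-1 (Lambda_bar - I) B
    batch, L, _ = u.shape
    x = np.zeros((batch, Lambda.shape[0]), dtype=np.complex128)
    ys = []
    for k in range(L):
        x = Lambda_bar * x + u[:, k] @ B_bar.T          # x_k = Lambda_bar x_{k-1} + B_bar u_k
        y = (x @ C.T).real                               # Re(C x_k)
        if conj_sym:
            y = 2.0 * y                                  # only one state per conjugate pair is stored
        ys.append(y + D * u[:, k])                       # + D u_k
    return np.stack(ys, axis=1)                          # (batch, L, H)
```

The loop is O(L) sequential steps; a fused kernel can instead evaluate the same linear recurrence with a parallel scan, which is what makes acceleration possible. The conj_sym check works because a real input drives conjugate eigenvalue pairs into conjugate states, so their contributions to C x sum to twice the real part of one of them.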
- __init__(d_model: int, d_state: int, discretization: Literal['bilinear', 'zoh', 'dirac'] = 'zoh', conj_sym: bool = False, dt_min: float = 0.001, dt_max: float = 0.1, step_rescale: float = 1.0, use_fast_path: bool = True, device=None, dtype=None)
Initialize S5 model.
- Parameters:
d_model (int) – Model dimension.
d_state (int) – State dimension.
discretization (Literal["bilinear", "zoh", "dirac"], optional) – Discretization method. Defaults to "zoh".
conj_sym (bool, optional) – If True, uses conjugate symmetry for the state space model. Defaults to False.
dt_min (float, optional) – Minimum value for dt initialization. Defaults to 0.001.
dt_max (float, optional) – Maximum value for dt initialization. Defaults to 0.1.
step_rescale (float, optional) – Rescale factor for step size. Defaults to 1.0.
use_fast_path (bool, optional) – Whether to use fused CUDA kernels. Defaults to True.
device (torch.device, optional) – Device for parameters. Defaults to None.
dtype (torch.dtype, optional) – Data type for parameters. Defaults to None.
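The discretization, dt_min/dt_max and step_rescale parameters are easiest to see for a diagonal state matrix, where discretization is elementwise. This sketch assumes the standard S5 formulas: ZOH gives Lambda_bar = exp(Lambda * dt) and B_bar = Lambda^-1 (Lambda_bar - 1) B; bilinear gives Lambda_bar = (1 + Lambda*dt/2) / (1 - Lambda*dt/2) and B_bar = dt / (1 - Lambda*dt/2) B. It further assumes that initial per-state step sizes are drawn log-uniformly from [dt_min, dt_max] and that step_rescale multiplies dt. The "dirac" option is deliberately not covered because its exact definition is not given on this page; init_log_dt and discretize are hypothetical helper names.

```python
import numpy as np


def init_log_dt(P, dt_min=0.001, dt_max=0.1, rng=None):
    """Sample per-state log step sizes log-uniformly in [log dt_min, log dt_max] (assumed scheme)."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(np.log(dt_min), np.log(dt_max), size=P)


def discretize(Lambda, B, dt, method="zoh", step_rescale=1.0):
    """Discretize a diagonal continuous-time SSM (Lambda, B) with step size dt (sketch)."""
    dt = step_rescale * dt                                 # assumed: step_rescale scales dt
    if method == "zoh":
        Lambda_bar = np.exp(Lambda * dt)
        B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B
    elif method == "bilinear":
        denom = 1.0 - Lambda * dt / 2.0
        Lambda_bar = (1.0 + Lambda * dt / 2.0) / denom
        B_bar = (dt / denom)[:, None] * B
    else:
        raise ValueError(f"method {method!r} is not covered by this sketch")
    return Lambda_bar, B_bar
```

Both rules map a stable continuous eigenvalue (negative real part) to a discrete one inside the unit circle, and they agree to high order when |Lambda * dt| is small, so the choice mainly matters for coarse step sizes.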
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None, inference_cache: Dict[str, Any] | None = None) → torch.Tensor
Forward pass through S5.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, L, H).
integration_timesteps (torch.Tensor, optional) – Timesteps for async/event-driven discretization. Defaults to None.
lengths (torch.Tensor, optional) – Per-sequence lengths; required when the batch contains sequences of different lengths. Defaults to None.
inference_cache (Dict[str, Any], optional) – Cache for autoregressive generation. Defaults to None.
- Returns:
Output tensor of shape (B, L, H).
- Return type:
torch.Tensor
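In event-driven mode the only structural change to the recurrence is that the step size can differ per timestep, so discretization moves inside the loop. The sketch below assumes integration_timesteps has shape (B, L) and multiplies the learned per-state step size; both the layout and the scaling rule are assumptions for illustration, not documented lrnnx behavior, and s5_scan_with_timesteps is a hypothetical name.

```python
import numpy as np


def s5_scan_with_timesteps(u, Lambda, B, C, D, log_dt, integration_timesteps=None):
    """ZOH scan where step k may use its own step size (sketch).

    Assumed layout: integration_timesteps has shape (batch, L) and step k uses
    dt * integration_timesteps[:, k]. None means a fixed step of dt.
    """
    batch, L, _ = u.shape
    dt = np.exp(log_dt)                                          # (P,)
    if integration_timesteps is None:
        integration_timesteps = np.ones((batch, L))
    integration_timesteps = np.asarray(integration_timesteps, dtype=float)
    x = np.zeros((batch, Lambda.shape[0]), dtype=np.complex128)
    ys = []
    for k in range(L):
        dt_k = dt[None, :] * integration_timesteps[:, k, None]  # (batch, P)
        Lambda_bar = np.exp(Lambda * dt_k)                       # per-step ZOH
        B_scale = (Lambda_bar - 1.0) / Lambda                    # diagonal of Lambda^-1 (Lambda_bar - I)
        x = Lambda_bar * x + B_scale * (u[:, k] @ B.T)           # B_bar u_k = B_scale * (B u_k)
        ys.append((x @ C.T).real + D * u[:, k])
    return np.stack(ys, axis=1)                                  # (batch, L, H)
```

Under ZOH a zero timestep gives Lambda_bar = 1 and B_bar = 0, so the state is carried over unchanged and the output reduces to the feedthrough term D u_k.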
- step(x: torch.Tensor, inference_cache: Dict[str, Any], integration_timesteps: torch.Tensor | None = None) → Tuple[torch.Tensor, Dict[str, Any]]
Performs a single recurrent step of S5.
When the simplified_state_update Triton kernel is available and the tensors live on CUDA, the state is updated in-place via the kernel (which also fuses discretization, input projection, and output projection into a single launch). Otherwise a pure-PyTorch fallback is used.
- Parameters:
x (torch.Tensor) – Input at the current timestep, shape (B, 1, H) or (B, H).
inference_cache (Dict[str, Any]) – Cache dictionary containing the SSM state and continuous-time parameters.
integration_timesteps (torch.Tensor, optional) – Optional per-step integration timesteps for event/async mode, shape (B,) or (B, 1). Defaults to None.
- Returns:
- A tuple containing:
y : Output tensor at the current timestep.
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
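A pure-NumPy analogue of the non-kernel fallback might look like the sketch below: it accepts (B, 1, H) or (B, H) inputs, discretizes the stored continuous-time parameters with ZOH, updates the state held in the cache dictionary, and returns (y, cache) with y in the same layout as the input. The cache keys ("state", "Lambda", "B", "C", "D", "log_dt"), the name s5_step, and the rule that integration_timesteps multiplies the step size are hypothetical; the real cache layout is whatever allocate_inference_cache returns.

```python
import numpy as np


def s5_step(x, cache, integration_timesteps=None):
    """One recurrent S5 step with ZOH discretization (sketch, hypothetical cache keys)."""
    keep_time_axis = x.ndim == 3
    u = x[:, 0] if keep_time_axis else x                        # (B, H)
    dt = np.exp(cache["log_dt"])[None, :]                       # (1, P)
    if integration_timesteps is not None:
        # (B,) or (B, 1) -> (B, 1); assumed to scale the learned step size
        dt = dt * np.reshape(integration_timesteps, (-1, 1))
    Lambda = cache["Lambda"]
    Lambda_bar = np.exp(Lambda * dt)                            # discretize on the fly
    B_scale = (Lambda_bar - 1.0) / Lambda
    cache["state"] = Lambda_bar * cache["state"] + B_scale * (u @ cache["B"].T)
    y = (cache["state"] @ cache["C"].T).real + cache["D"] * u   # Re(C x_k) + D u_k
    return (y[:, None] if keep_time_axis else y), cache
```

Keeping the parameters in continuous time is what lets a per-step timestep be applied without re-running any precomputation; the cost is one exponential per state per step, which the fused kernel described above absorbs into the same launch.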
- allocate_inference_cache(batch_size: int, max_seqlen: int, dtype: torch.dtype | None = None) → Dict[str, Any]
Allocates cache for S5 autoregressive inference.
Stores the continuous-time parameters so that simplified_state_update can fuse discretization into the kernel.
- Parameters:
batch_size (int) – The batch size for inference.
max_seqlen (int) – Maximum sequence length (unused, for interface consistency).
dtype (torch.dtype, optional) – Data type for allocated tensors. Defaults to None.
- Returns:
Cache dictionary containing SSM state and continuous-time matrices.
- Return type:
Dict[str, Any]
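Because the recurrent state has a fixed size, the cache does not grow with the number of generated steps, which is why max_seqlen can be ignored. The sketch below assumes a complex state of shape (batch_size, P) and stores the undiscretized parameters next to it; allocate_cache and its key names are hypothetical and chosen only for illustration.

```python
import numpy as np


def allocate_cache(batch_size, max_seqlen, Lambda, B, C, D, log_dt, dtype=np.complex128):
    """Allocate a recurrent-inference cache for a diagonal S5 layer (sketch).

    max_seqlen is accepted only for interface consistency: the state has
    shape (batch_size, P) no matter how many steps are generated.
    """
    del max_seqlen  # unused by design
    return {
        "state": np.zeros((batch_size, Lambda.shape[0]), dtype=dtype),
        # Continuous-time parameters are kept undiscretized so that each step
        # can discretize them itself (e.g. with a per-step timestep).
        "Lambda": Lambda,
        "B": B,
        "C": C,
        "D": D,
        "log_dt": log_dt,
    }
```

A generation loop would allocate this once and then carry only the returned dictionary from step to step, so memory stays O(batch_size * P) regardless of sequence length.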