lrnnx.models.lti.s5 module
Basic S5 SSM. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks
- class S5
Bases: LTI_LRNN
Basic S5 State Space Model. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks
Example
>>> import torch
>>> from lrnnx.models.lti.s5 import S5
>>> model = S5(d_model=64, d_state=64, discretization="zoh")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- __init__(d_model: int, d_state: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'no_discretization'], conj_sym: bool = False)
Initializes the S5 model.
- Parameters:
d_model (int) – Model dimension.
d_state (int) – State dimension (P in the original paper).
discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]) – Discretization method to use.
conj_sym (bool, optional) – If True, uses conjugate symmetry for the state space model. Defaults to False.
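The docstring does not say what conjugate symmetry buys. In S5-style models the usual idea is: with real inputs and state eigenvalues that come in complex-conjugate pairs, half of the states are redundant, so only one eigenvalue of each pair needs to be stored and the output is recovered as twice the real part. The sketch below checks that identity on a small random diagonal system. Sizes and values are made up, `run` is a hypothetical helper, and this is not the lrnnx implementation.

```python
import torch

torch.manual_seed(0)
H, P_half, L = 3, 4, 5  # hypothetical sizes: model dim, kept eigenvalues, sequence length

# Keep one (discrete, |lambda| < 1) eigenvalue per conjugate pair, plus matching B rows / C columns.
lam_half = torch.polar(0.5 + 0.4 * torch.rand(P_half), torch.randn(P_half))
B_half = torch.randn(P_half, H, dtype=torch.cfloat)
C_half = torch.randn(H, P_half, dtype=torch.cfloat)

# The full system appends the complex conjugate of every kept quantity.
lam_full = torch.cat([lam_half, lam_half.conj()])
B_full = torch.cat([B_half, B_half.conj()], dim=0)
C_full = torch.cat([C_half, C_half.conj()], dim=1)


def run(lam, B, C, u):
    """Diagonal recurrence x_k = lam * x_{k-1} + B u_k, y_k = C x_k (complex outputs)."""
    x = torch.zeros(lam.shape[0], dtype=torch.cfloat)
    ys = []
    for k in range(u.shape[0]):
        x = lam * x + B @ u[k].to(torch.cfloat)
        ys.append(C @ x)
    return torch.stack(ys)  # (L, H)


u = torch.randn(L, H)  # real input sequence
y_full = run(lam_full, B_full, C_full, u)
y_half = run(lam_half, B_half, C_half, u)

# Conjugate pairs cancel the imaginary part: Re(C_full x_full) == 2 * Re(C_half x_half).
matches = torch.allclose(y_full.real, 2 * y_half.real, atol=1e-4)
```

The full-size output is real up to rounding, and the half-size system reproduces it with half the state, which is the saving conjugate symmetry is meant to provide.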
- discretize() → tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]
Discretizes the continuous-time system matrices A and B using the specified discretization method.
- Returns:
- A tuple containing:
A_bar : Discretized system matrix A, shape (N,).
gamma_bar : Input normalizer, shape (N,), or a float.
C_complex : Complex output matrix C, shape (H, N).
- Return type:
tuple[torch.Tensor, Union[torch.Tensor, float], torch.Tensor]
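The discretization modes are listed without formulas. For a diagonal continuous-time system, the standard zero-order-hold and bilinear rules (as used in the S5 paper) yield A_bar and a per-state normalizer gamma_bar, with the discrete input matrix obtained as gamma_bar times each row of B. A minimal sketch, assuming a complex diagonal Lambda and a per-state step size parameterized by its logarithm (`log_step` and `discretize_diag` are hypothetical names); the 'dirac' and 'no_discretization' modes are not covered.

```python
import math

import torch


def discretize_diag(lam: torch.Tensor, log_step: torch.Tensor, method: str):
    """Discretize a diagonal continuous-time system dx/dt = Lambda x + B u (sketch).

    lam:      complex eigenvalues Lambda, shape (N,)
    log_step: log of the per-state step size Delta, shape (N,)  (assumed parameterization)
    Returns (A_bar, gamma_bar); the discrete input matrix is gamma_bar[:, None] * B.
    """
    step = torch.exp(log_step).to(lam.dtype)  # Delta as a complex tensor
    z = lam * step
    if method == "zoh":
        A_bar = torch.exp(z)                  # exp(Lambda * Delta)
        gamma_bar = (A_bar - 1) / lam         # Lambda^{-1} (A_bar - I), elementwise
    elif method == "bilinear":
        denom = 1 - z / 2
        A_bar = (1 + z / 2) / denom           # (I - Lambda*Delta/2)^{-1} (I + Lambda*Delta/2)
        gamma_bar = step / denom              # (I - Lambda*Delta/2)^{-1} Delta
    else:
        raise ValueError(f"this sketch only covers 'zoh' and 'bilinear', not {method!r}")
    return A_bar, gamma_bar


# Arbitrary stable eigenvalues (negative real part) and a small step size.
N = 8
lam = torch.complex(torch.full((N,), -0.5), torch.arange(N, dtype=torch.float32))
log_step = torch.full((N,), math.log(0.01))

A_zoh, g_zoh = discretize_diag(lam, log_step, "zoh")
A_bil, g_bil = discretize_diag(lam, log_step, "bilinear")
```

For small Delta the two rules agree closely, and with Re(Lambda) < 0 both map every eigenvalue strictly inside the unit circle, so the discrete recurrence stays stable.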
- compute_kernel(L: int, A_bar: torch.Tensor, gamma_bar: torch.Tensor | float)
Computes the kernel matrices for the S5 model: the powers A_bar^t and the normalized input matrix B_bar.
- Parameters:
L (int) – Length of the input sequence.
A_bar (torch.Tensor) – Discretized system matrix A, shape (N,).
gamma_bar (Union[torch.Tensor, float]) – Input normalizer, shape (N,), or a float.
- Returns:
- A tuple containing:
A_power : Powers of the discretized system matrix A, shape (N, L).
B_bar : Normalized input projection matrix, shape (N, H).
- Return type:
tuple[torch.Tensor, torch.Tensor]
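One plausible reading of the two outputs, sketched under assumptions: A_power stacks A_bar**t for t = 0..L-1 column by column, and B_bar scales each row of a complex input matrix B of shape (N, H) by the normalizer, whether that is a per-state tensor or a single float. `compute_kernel_sketch` is a hypothetical helper, and passing B explicitly is an assumption (the real method presumably reads it from the module's parameters).

```python
import torch


def compute_kernel_sketch(L, A_bar, gamma_bar, B):
    """Return (A_power, B_bar) with the shapes documented for compute_kernel().

    A_bar: (N,) complex; gamma_bar: (N,) tensor or a float; B: (N, H) complex (assumed input).
    """
    t = torch.arange(L, device=A_bar.device)
    A_power = A_bar[:, None] ** t[None, :]   # (N, L); column t holds A_bar ** t
    if isinstance(gamma_bar, torch.Tensor):
        B_bar = gamma_bar[:, None] * B       # scale row n of B by gamma_bar[n]
    else:
        B_bar = gamma_bar * B                # a scalar normalizer scales all rows equally
    return A_power, B_bar


N, H, L = 4, 3, 5
A_bar = torch.polar(torch.full((N,), 0.9), torch.linspace(0.0, 1.0, N))
B = torch.randn(N, H, dtype=torch.cfloat)
A_power, B_bar = compute_kernel_sketch(L, A_bar, torch.full((N,), 0.1, dtype=torch.cfloat), B)
```

Keeping the kernel as per-state powers of a diagonal A_bar, rather than a dense (L, H, H) kernel, keeps the cost linear in N.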
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor
Forward pass of the S5 SSM using FFT-based convolution.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, L, H).
integration_timesteps (torch.Tensor, optional) – Not used by S5 (LTI model); kept for interface compatibility with LTV models. Defaults to None.
lengths (torch.Tensor, optional) – Lengths of the input sequences, shape (B,). TODO: Support bidirectional models. Defaults to None.
- Returns:
Output tensor of shape (B, L, H).
- Return type:
torch.Tensor
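Because A_bar is diagonal, each of the N state channels is an independent recurrence x_k = A_bar x_{k-1} + (B_bar u)_k, i.e. a causal convolution of the projected input with the kernel A_bar**t, which is what makes an FFT-based forward pass possible. A minimal sketch of that idea, assuming the shapes documented above (A_power (N, L), B_bar (N, H), C (H, N)); a feedthrough term D u, any conj_sym scaling, and the handling of `lengths` are left out, and `s5_forward_fft` is a hypothetical name.

```python
import torch


def s5_forward_fft(u, A_power, B_bar, C):
    """FFT-based S5 forward pass (sketch; D feedthrough and conj_sym scaling omitted).

    u: (B, L, H) real input; A_power: (N, L); B_bar: (N, H); C: (H, N), all complex.
    """
    L = u.shape[1]
    Bu = u.to(B_bar.dtype) @ B_bar.transpose(0, 1)           # (B, L, N): inputs projected into state space
    n = 2 * L                                                 # zero-pad so the FFT convolution is linear, not circular
    Bu_f = torch.fft.fft(Bu, n=n, dim=1)                      # (B, n, N)
    K_f = torch.fft.fft(A_power.transpose(0, 1), n=n, dim=0)  # (n, N): per-state kernel A_bar**t
    x = torch.fft.ifft(Bu_f * K_f, dim=1)[:, :L]              # (B, L, N): x_k = sum_j A_bar**j Bu_{k-j}
    return (x @ C.transpose(0, 1)).real                       # (B, L, H)


torch.manual_seed(0)
Bsz, L, H, N = 2, 6, 3, 4
A_bar = torch.polar(0.5 + 0.4 * torch.rand(N), torch.randn(N))
A_power = A_bar[:, None] ** torch.arange(L)[None, :]
B_bar = torch.randn(N, H, dtype=torch.cfloat)
C = torch.randn(H, N, dtype=torch.cfloat)
u = torch.randn(Bsz, L, H)
y = s5_forward_fft(u, A_power, B_bar, C)
```

Padding to length 2L keeps the FFT's circular convolution from wrapping late inputs into early outputs, so the result matches a step-by-step recurrence up to floating-point error.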
- step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]]
Performs a single recurrent step of the S5 model.
- Parameters:
x (torch.Tensor) – Input at the current time step, shape (B, H).
inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and the pre-computed matrices.
- Returns:
- Output y_t of shape (B, H) and the updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
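The recurrent view produces the same outputs one time step at a time, which is what step() is for. A sketch of a single step, assuming the cache holds the complex state under “lrnn_state” (as documented) and the pre-computed matrices under the hypothetical keys "A_bar", "B_bar", and "C"; `s5_step_sketch` is likewise a hypothetical name.

```python
import torch


def s5_step_sketch(u_t, cache):
    """One S5 recurrence step: x_k = A_bar * x_{k-1} + B_bar u_k, y_k = Re(C x_k).

    u_t: (B, H) real. cache["lrnn_state"]: (B, N) complex. The keys "A_bar", "B_bar"
    and "C" are hypothetical names for the pre-computed discrete matrices.
    """
    B_bar = cache["B_bar"]
    x = cache["A_bar"] * cache["lrnn_state"] + u_t.to(B_bar.dtype) @ B_bar.transpose(0, 1)
    cache["lrnn_state"] = x                        # carry the new state to the next call
    y_t = (x @ cache["C"].transpose(0, 1)).real    # (B, H)
    return y_t, cache


torch.manual_seed(0)
Bsz, H, N = 2, 3, 4
cache = {
    "lrnn_state": torch.zeros(Bsz, N, dtype=torch.cfloat),
    "A_bar": torch.polar(torch.full((N,), 0.8), torch.randn(N)),
    "B_bar": torch.randn(N, H, dtype=torch.cfloat),
    "C": torch.randn(H, N, dtype=torch.cfloat),
}
u0, u1 = torch.randn(Bsz, H), torch.randn(Bsz, H)
y0, cache = s5_step_sketch(u0, cache)
y1, cache = s5_step_sketch(u1, cache)
```

Each call costs O(N H) per batch element regardless of how many steps came before, since the only thing carried forward is the (B, N) state.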
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → Dict[str, Any]
Allocates cache for inference.
- Parameters:
batch_size (int) – The batch size for the input data.
max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.
dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.
- Returns:
- Cache dict with “lrnn_state” and pre-computed discrete matrices.
- Return type:
Dict[str, Any]
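Because S5 is LTI, its recurrent state has a fixed size (B, N) no matter how long the sequence grows, which is consistent with max_seqlen being unused. A minimal sketch of such a cache, assuming the discretized matrices are computed once up front and stored next to a zero state; only “lrnn_state” is named in the docs, while "A_bar", "B_bar", "C", and `allocate_cache_sketch` are hypothetical.

```python
import torch


def allocate_cache_sketch(batch_size, A_bar, B_bar, C):
    """Zero recurrent state plus the discretized matrices, computed once before decoding.

    Only "lrnn_state" is named in the docs; "A_bar", "B_bar", "C" are hypothetical keys.
    """
    N = A_bar.shape[0]
    return {
        # Complex state, one row per sequence in the batch.
        "lrnn_state": torch.zeros(batch_size, N, dtype=A_bar.dtype, device=A_bar.device),
        "A_bar": A_bar,
        "B_bar": B_bar,
        "C": C,
    }


N, H = 4, 3
cache = allocate_cache_sketch(
    batch_size=2,
    A_bar=torch.full((N,), 0.9 + 0.1j, dtype=torch.cfloat),
    B_bar=torch.randn(N, H, dtype=torch.cfloat),
    C=torch.randn(H, N, dtype=torch.cfloat),
)
```

Pre-computing the discrete matrices here means each later step only performs the recurrence update, without re-running the discretization.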