lrnnx.models.lti.s5 module

Basic S5 SSM. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks

class S5[source]

Bases: LTI_LRNN

Basic S5 State Space Model. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks

Example

>>> import torch
>>> from lrnnx.models.lti.s5 import S5
>>> model = S5(d_model=64, d_state=64, discretization="zoh")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])

__init__(d_model: int, d_state: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'no_discretization'], conj_sym: bool = False)[source]

Initialize S5 model.

Parameters:
  • d_model (int) – Model dimension.

  • d_state (int) – State dimension (P in the original paper).

  • discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]) – Discretization method to use.

  • conj_sym (bool, optional) – If True, uses conjugate symmetry for the state space model. Defaults to False.

discretize() → tuple[torch.Tensor, torch.Tensor | float, torch.Tensor][source]

Discretizes the continuous-time system matrices A and B using the specified discretization method.

Returns:

A tuple containing:
  • A_bar : Discretized system matrix A, shape (N,).

  • gamma_bar : Input normalizer, shape (N,) or a float.

  • C_complex : Complex output matrix C, shape (H, N).

Return type:

tuple[torch.Tensor, Union[torch.Tensor, float], torch.Tensor]
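For a diagonal state matrix, the two most common discretizations can be written per eigenvalue. The sketch below is not the library's implementation. It assumes a scalar step size `dt` (a hypothetical name) and assumes that gamma_bar is the factor that multiplies B to give B_bar. The "dirac" and "no_discretization" options are omitted because this page does not define them.

```python
import cmath


def discretize_diag(lam, dt, method):
    """Discretize one diagonal entry `lam` of A with step size `dt`.

    Returns (a_bar, gamma_bar), where b_bar = gamma_bar * b is assumed.
    """
    if method == "zoh":
        # Zero-order hold: a_bar = exp(lam * dt), gamma_bar = (a_bar - 1) / lam
        a_bar = cmath.exp(lam * dt)
        return a_bar, (a_bar - 1) / lam
    if method == "bilinear":
        # Bilinear (Tustin) transform
        denom = 1 - 0.5 * dt * lam
        return (1 + 0.5 * dt * lam) / denom, dt / denom
    raise ValueError(f"method not covered by this sketch: {method}")


lam, dt = -0.5 + 2.0j, 0.1  # illustrative values only
a_zoh, g_zoh = discretize_diag(lam, dt, "zoh")
a_bil, g_bil = discretize_diag(lam, dt, "bilinear")
```

For a stable eigenvalue (negative real part) both methods give |a_bar| < 1, and the two results converge as `dt` shrinks.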

compute_kernel(L: int, A_bar: torch.Tensor, gamma_bar: torch.Tensor | float)[source]

Computes the kernel matrices for the S5 model: the powers A_bar^t of the discretized system matrix and the normalized input projection matrix B_bar.

Parameters:
  • L (int) – Length of the input sequence.

  • A_bar (torch.Tensor) – Discretized system matrix A, shape (N,).

  • gamma_bar (Union[torch.Tensor, float]) – Input normalizer, shape (N,) or a float.

Returns:

A tuple containing:
  • A_power : Power of the discretized system matrix A, shape (N, L).

  • B_bar : Normalized input projection matrix, shape (N, H).

Return type:

tuple[torch.Tensor, torch.Tensor]
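The documented shapes can be illustrated with plain Python lists. This is a sketch under assumptions rather than the library code: the continuous-time input matrix `B` is passed explicitly (a hypothetical argument), the powers are assumed to run over t = 0, ..., L-1, and B_bar is assumed to be gamma_bar times B, row by row.

```python
def compute_kernel_sketch(L, a_bar, gamma_bar, B):
    """a_bar: N complex values; gamma_bar: N values or one float; B: N x H nested list."""
    N = len(a_bar)
    # Broadcast a scalar normalizer to every state dimension.
    gammas = list(gamma_bar) if isinstance(gamma_bar, (list, tuple)) else [gamma_bar] * N
    # a_power[n][t] = a_bar[n] ** t, giving shape (N, L)
    a_power = [[a_bar[n] ** t for t in range(L)] for n in range(N)]
    # b_bar[n][h] = gammas[n] * B[n][h], giving shape (N, H)
    b_bar = [[gammas[n] * b for b in B[n]] for n in range(N)]
    return a_power, b_bar


a_bar = [0.9 + 0.1j, 0.5 - 0.2j]               # N = 2, illustrative values
B = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]        # N = 2, H = 3
a_power, b_bar = compute_kernel_sketch(4, a_bar, 0.5, B)
```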

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor[source]

Forward pass of the S5 SSM using FFT-based convolution.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, L, H).

  • integration_timesteps (torch.Tensor, optional) – Not used by S5 (LTI model). Kept for interface compatibility with LTV models. Defaults to None.

  • lengths (torch.Tensor, optional) – Lengths of the input sequences, shape (B,). Defaults to None. TODO: Support bidirectional models.

Returns:

Output tensor of shape (B, L, H).

Return type:

torch.Tensor
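Because the model is time-invariant, its output can be computed either by stepping the recurrence or by convolving the input with a kernel in the frequency domain. The sketch below checks that equivalence for a single state and a single channel, with the batch dimension omitted. It is an illustration under assumptions, not the library's forward pass: a naive DFT stands in for the FFT, and the kernel is taken to be Re(c * a_bar^t * b_bar).

```python
import cmath


def dft(x, inverse=False):
    """Naive O(n^2) discrete Fourier transform, standing in for an FFT."""
    n = len(x)
    sign = 1 if inverse else -1
    out = [
        sum(x[k] * cmath.exp(sign * 2j * cmath.pi * j * k / n) for k in range(n))
        for j in range(n)
    ]
    return [v / n for v in out] if inverse else out


def causal_conv_freq(kernel, u):
    """Causal convolution via the frequency domain, zero-padded to length 2L."""
    L = len(u)
    pad = [0.0] * L  # padding prevents circular wrap-around
    K = dft(list(kernel) + pad)
    U = dft(list(u) + pad)
    y = dft([a * b for a, b in zip(K, U)], inverse=True)
    return [v.real for v in y[:L]]


def recurrence(a_bar, b_bar, c, u):
    """Reference: x_t = a_bar * x_{t-1} + b_bar * u_t, y_t = Re(c * x_t)."""
    state, ys = 0j, []
    for u_t in u:
        state = a_bar * state + b_bar * u_t
        ys.append((c * state).real)
    return ys


a_bar, b_bar, c = 0.9 * cmath.exp(0.3j), 0.5 + 0.1j, 1.0 - 0.2j  # illustrative
u = [1.0, -0.5, 0.25, 2.0, 0.0, -1.0]
kernel = [(c * a_bar ** t * b_bar).real for t in range(len(u))]
y_conv = causal_conv_freq(kernel, u)
y_rec = recurrence(a_bar, b_bar, c, u)
```

The two outputs agree up to floating-point error. Padding to twice the sequence length is what makes the circular convolution implied by the DFT match the causal one.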

step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]][source]

Performs a single recurrent step of the S5 model.

Parameters:
  • x (torch.Tensor) – Input at current time step, shape (B, H).

  • inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and pre-computed matrices.

Returns:

A tuple of the output y_t, of shape (B, H), and the updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]
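One recurrent step of a diagonal SSM can be sketched with plain Python lists. Only the "lrnn_state" key comes from the documentation above. The other cache keys ("A_bar", "B_bar", "C") are hypothetical names chosen for illustration, the batch dimension is omitted, and the update rule (state = A_bar * state + B_bar x, y = Re(C state)) is an assumption about how the step works.

```python
def step_sketch(x, cache):
    """x: H floats. cache: 'lrnn_state' (N complex) plus hypothetical matrices."""
    a_bar, b_bar, c = cache["A_bar"], cache["B_bar"], cache["C"]
    N, H = len(a_bar), len(x)
    # Elementwise (diagonal) state update plus input projection.
    state = [
        a_bar[n] * cache["lrnn_state"][n] + sum(b_bar[n][h] * x[h] for h in range(H))
        for n in range(N)
    ]
    # Project the complex state back to H real outputs.
    y = [sum(c[h][n] * state[n] for n in range(N)).real for h in range(H)]
    # Return a new dict so the caller's cache is not mutated.
    return y, dict(cache, lrnn_state=state)


cache0 = {
    "lrnn_state": [0j, 0j],              # N = 2, starts at zero
    "A_bar": [0.9 + 0.1j, 0.5 - 0.2j],   # hypothetical key
    "B_bar": [[1.0], [0.5]],             # hypothetical key, shape (N, H)
    "C": [[1.0 - 1.0j, 2.0 + 0.0j]],     # hypothetical key, shape (H, N)
}
y1, cache1 = step_sketch([1.0], cache0)
y2, cache2 = step_sketch([0.0], cache1)
```

Feeding the returned cache into the next call is what carries the state across time steps.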

allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → Dict[str, Any][source]

Allocates cache for inference.

Parameters:
  • batch_size (int) – The batch size for the input data.

  • max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.

  • dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.

Returns:

Cache dict with “lrnn_state” and pre-computed discrete matrices.

Return type:

Dict[str, Any]