lrnnx.models.lti package

Linear Time-Invariant (LTI) LRNN models.

class LTI_LRNN

Bases: LRNN

Base class for all LTI LRNN models.

Note

LTI models do not support async discretization as that requires time-varying dynamics. For async/event-driven models, use LTV models.

Example

>>> from lrnnx.models.lti import LTI_LRNN
>>> my_lrnn = LTI_LRNN("zoh")
>>> # LTI_LRNN is abstract: implement the abstract methods in a subclass,
>>> # then create a dummy input tensor and run its forward pass
__init__(discretization: Literal['zoh', 'bilinear', 'dirac', 'no_discretization'])

Initialize the LTI LRNN base class.

Parameters:

discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]) – Discretization method to use.

abstractmethod discretize() → tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]

Discretizes the A, B and C matrices using a learned step size delta. This could be done inside compute_kernel() itself, but keeping it as a separate method gives subclasses more flexibility.

Returns:

A tuple of tensors representing the discretized A, B, C matrices, ideally of shapes (B, N), (B, N, H) or float, and (B, H, N) respectively.

Return type:

tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]

abstractmethod compute_kernel() → tuple[torch.Tensor, torch.Tensor]

Computes the convolution kernel for efficient parallel processing.

This function is only relevant for LTI models; for LTV models it would have to materialize a large tensor in memory at every timestep, which is inefficient. Reference: https://github.com/kunibald413/aTENNuate/blob/15a27dab00d3bf2c27cbbbc3bd41a3d9196dca1e/attenuate/model.py#L30

Parameters:

*args – Model-specific arguments (e.g., sequence length, discretized matrices). See subclass implementations for details.

Returns:

A tuple containing:
  • K : Powers of A matrix (A^0, A^1, …, A^{L-1}), shape (N, L)

  • B_norm : Normalized input projection matrix, shape (N, H)

Return type:

tuple[torch.Tensor, torch.Tensor]
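
For a diagonal A, unrolling the recurrence h_t = A h_{t-1} + B u_t, y_t = C h_t from a zero initial state gives y_t = Σ_k (C A^k B) u_{t-k}, so the powers A^0, …, A^{L-1} are all that is needed to build the convolution kernel. The sketch below checks that identity for one input and one output channel with plain Python floats. The helper names are hypothetical, not lrnnx functions, and the direct O(L²) sum only stands in for the FFT-based convolution that S5.forward() documents.

```python
import math
from typing import List


def powers_kernel(a_diag: List[float], length: int) -> List[List[float]]:
    """K[n][k] = a_diag[n] ** k, i.e. A^0 ... A^{L-1} for a diagonal A; shape (N, L)."""
    return [[a ** k for k in range(length)] for a in a_diag]


def conv_output(a_diag, b, c, u):
    """Causal convolution y_t = sum_k (sum_n c[n] * K[n][k] * b[n]) * u[t - k]."""
    K = powers_kernel(a_diag, len(u))
    # Contract the state axis to get one scalar kernel value per lag k.
    kernel = [sum(c[n] * K[n][k] * b[n] for n in range(len(a_diag))) for k in range(len(u))]
    return [sum(kernel[k] * u[t - k] for k in range(t + 1)) for t in range(len(u))]


def recurrent_output(a_diag, b, c, u):
    """Sequential scan h_t = A h_{t-1} + B u_t, y_t = C h_t, starting from h = 0."""
    h = [0.0] * len(a_diag)
    ys = []
    for u_t in u:
        h = [a * h_n + b_n * u_t for a, h_n, b_n in zip(a_diag, h, b)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c, h)))
    return ys


a_diag = [0.9, 0.5, -0.3]   # diagonal of A (N = 3)
b = [1.0, 0.5, 2.0]         # B for a single input channel
c = [0.2, -1.0, 0.7]        # C for a single output channel
u = [1.0, 0.0, -2.0, 3.0, 0.5]

y_conv = conv_output(a_diag, b, c, u)
y_rec = recurrent_output(a_diag, b, c, u)
print(all(math.isclose(p, q, rel_tol=1e-9, abs_tol=1e-12) for p, q in zip(y_conv, y_rec)))  # True
```

Both paths give the same outputs; the kernel form is what makes parallel (and FFT-based) evaluation possible for LTI models.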

abstractmethod step(x: torch.Tensor, inference_cache: Dict[str, Any]) → Tuple[torch.Tensor, Dict[str, Any]]

Performs a single recurrent step of the LTI model.

This method is used for autoregressive inference, where inputs are processed one timestep at a time.

Parameters:
  • x (torch.Tensor) – Input at current timestep, shape (B, H).

  • inference_cache (Dict[str, Any]) – Cache dictionary from allocate_inference_cache() containing recurrent state and pre-computed matrices. Updated in-place and returned.

Returns:

A tuple containing:
  • y : Output at current timestep, shape (B, H).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]

abstractmethod allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype: torch.dtype | None = None) → Dict[str, Any]

Allocates initial state and caches matrices for efficient inference.

For LTI models, the system matrices (A, B, C) are time-invariant, so they can be pre-computed once and reused for all timesteps during autoregressive generation.

Parameters:
  • batch_size (int) – The batch size for inference.

  • max_seqlen (int, optional) – Maximum sequence length (unused for LTI, kept for interface consistency with LTV models). Defaults to 1.

  • dtype (torch.dtype, optional) – Data type for allocated tensors. Defaults to None.

Returns:

Cache dictionary containing initial state and pre-computed matrices for use in step().

Return type:

Dict[str, Any]
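
The cache contract described above (allocate once, then call step() once per timestep and feed the returned cache back in) can be sketched with a toy model. ToyDiagonalLTI is hypothetical and not part of lrnnx: it is a real diagonal system with a single channel, uses plain lists instead of tensors (so x has shape (B,) rather than (B, H)), and stores the time-invariant matrices under made-up keys next to “lrnn_state”, the state key documented for the LRU and S5 caches.

```python
from typing import Any, Dict, List, Tuple


class ToyDiagonalLTI:
    """Hypothetical stand-in mimicking the allocate_inference_cache()/step() contract."""

    def __init__(self, a_diag: List[float], b: List[float], c: List[float]):
        self.a_diag, self.b, self.c = a_diag, b, c

    def allocate_inference_cache(self, batch_size: int, max_seqlen: int = 1, dtype=None) -> Dict[str, Any]:
        # LTI: the matrices never change, so they are stored once next to a zero state.
        # max_seqlen and dtype are accepted only for interface consistency.
        return {
            "lrnn_state": [[0.0] * len(self.a_diag) for _ in range(batch_size)],
            "A": self.a_diag,  # hypothetical key
            "B": self.b,       # hypothetical key
            "C": self.c,       # hypothetical key
        }

    def step(self, x: List[float], inference_cache: Dict[str, Any]) -> Tuple[List[float], Dict[str, Any]]:
        A, B, C = inference_cache["A"], inference_cache["B"], inference_cache["C"]
        ys = []
        for i, x_i in enumerate(x):
            h = [a * h_n + b_n * x_i for a, h_n, b_n in zip(A, inference_cache["lrnn_state"][i], B)]
            inference_cache["lrnn_state"][i] = h  # the cache is updated in place
            ys.append(sum(c_n * h_n for c_n, h_n in zip(C, h)))
        return ys, inference_cache


model = ToyDiagonalLTI(a_diag=[0.5, 0.25], b=[1.0, 1.0], c=[1.0, 1.0])
cache = model.allocate_inference_cache(batch_size=2)
inputs = [[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]]  # three timesteps, batch of 2
outputs = []
for x_t in inputs:
    y_t, cache = model.step(x_t, cache)
    outputs.append(y_t)
print(outputs)  # [[2.0, 4.0], [0.75, 1.5], [0.3125, 0.625]]
```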

class Centaurus(d_model: int, d_state: int, sub_state_dim: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'async'] = 'zoh', mode: Literal['neck', 'pointwise', 'pw', 's5', 'dws', 'full'] = 'neck')

Bases: object

Backwards-compatible wrapper that returns a mode-specific class instance.

Example

>>> model = Centaurus(d_model=64, d_state=64, sub_state_dim=8, mode="neck")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
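
One way a wrapper can return a mode-specific class instance is to override __new__. The sketch below illustrates that pattern with empty stand-in classes. The mode-to-class table, in particular mapping "pointwise", "pw" and "s5" to the pointwise bottleneck, is an assumption inferred from the class descriptions on this page, not taken from the lrnnx source.

```python
class CentaurusBase:
    """Stand-in for the shared base class; it only stores the constructor arguments."""

    def __init__(self, d_model: int, d_state: int, sub_state_dim: int, discretization: str = "zoh"):
        self.d_model = d_model
        self.d_state = d_state
        self.sub_state_dim = sub_state_dim
        self.discretization = discretization


# Empty stand-ins for the four variants (not the real lrnnx classes).
class CentaurusNeck(CentaurusBase): ...
class CentaurusPWNeck(CentaurusBase): ...
class CentaurusDWS(CentaurusBase): ...
class CentaurusFull(CentaurusBase): ...


# Hypothetical mode -> class table.
_MODE_TO_CLASS = {
    "neck": CentaurusNeck,
    "pointwise": CentaurusPWNeck,
    "pw": CentaurusPWNeck,
    "s5": CentaurusPWNeck,
    "dws": CentaurusDWS,
    "full": CentaurusFull,
}


class Centaurus:
    """Factory: __new__ builds and returns an instance of the mode-specific class."""

    def __new__(cls, d_model, d_state, sub_state_dim, discretization="zoh", mode="neck"):
        try:
            target = _MODE_TO_CLASS[mode]
        except KeyError:
            raise ValueError(f"unknown Centaurus mode: {mode!r}") from None
        # The returned object is not a Centaurus instance, so Python skips Centaurus.__init__.
        return target(d_model, d_state, sub_state_dim, discretization)


m = Centaurus(d_model=64, d_state=64, sub_state_dim=8, mode="dws")
print(type(m).__name__)  # CentaurusDWS
```

Because the wrapper derives only from object, callers get the concrete variant directly, which keeps older code that calls Centaurus(..., mode=...) working.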
class CentaurusDWS

Bases: CentaurusBase

Depthwise-separable block with one state per channel.

Example

>>> model = CentaurusDWS(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
class CentaurusFull

Bases: CentaurusBase

Fully connected block with a state per (in, out) pair.

Example

>>> model = CentaurusFull(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
class CentaurusNeck

Bases: CentaurusBase

Bottleneck block with dense in/out projections.

Example

>>> model = CentaurusNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
class CentaurusPWNeck

Bases: CentaurusBase

Pointwise bottleneck (s5 in public implementations) that flattens (N, M) -> (N*M).

This variant removes E-mixing and repeats delta over M sub-states per state, yielding independent SISO lanes aggregated in a single flattened axis.
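
A minimal sketch of the flattening and delta repetition described above, using plain Python lists in place of tensors. The pole values and step sizes are made up, and turning each lane into a decay with ZOH (exp(delta * A)) is an illustrative assumption, not a description of lrnnx's internals.

```python
import math
from typing import List


def flatten_substates(x_nm: List[List[float]]) -> List[float]:
    """Flatten an (N, M) nested list into one (N*M,) axis in row-major order."""
    return [v for row in x_nm for v in row]


def repeat_delta(delta: List[float], m: int) -> List[float]:
    """Repeat each per-state step size M times, (N,) -> (N*M,), like torch.repeat_interleave."""
    return [d for d in delta for _ in range(m)]


N, M = 2, 3
A = [[-1.0, -2.0, -3.0], [-0.5, -1.5, -2.5]]  # hypothetical continuous-time poles, shape (N, M)
delta = [0.1, 0.2]                            # hypothetical step sizes, one per state, shape (N,)

a_flat = flatten_substates(A)
d_flat = repeat_delta(delta, M)
# Assuming ZOH: lane k = n*M + j decays independently with exp(delta[n] * A[n][j]).
a_bar = [math.exp(d * a) for d, a in zip(d_flat, a_flat)]
```

Since no lane mixes with another, the N*M lanes behave as independent single-input, single-output systems that merely share a step size within each group of M.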

Example

>>> model = CentaurusPWNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
compute_kernel() → tuple[torch.Tensor, torch.Tensor]

Computes the discrete-time latent convolution kernel with intra-state mode mixing using the shared Centaurus formulation.

Returns:

A tuple containing:
  • k : Latent kernel of shape (N, L), where N is the number of state channels.

  • empty : Placeholder returned only for compatibility with the standard LTI interface.

Return type:

tuple[torch.Tensor, torch.Tensor]

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor

Forward pass through a Centaurus LTI mode variant.

Parameters:
  • x (torch.Tensor) – Input sequence of shape (B, L, H_in).

  • integration_timesteps (torch.Tensor, optional) – Placeholder for async models. Not used in the current implementation. Defaults to None.

  • lengths (torch.Tensor, optional) – Placeholder for future bidirectional models. Not used in the current implementation. Defaults to None.

Returns:

Output sequence of shape (B, L, H_out), where H_out is the output channel dimension.

Return type:

torch.Tensor

allocate_inference_cache(batch_size: int, max_seqlen: int = 1) → Dict[str, Any]

Allocate initial streaming state and cache matrices.

Parameters:
  • batch_size (int) – The batch size.

  • max_seqlen (int, optional) – Maximum sequence length. Defaults to 1.

Returns:

Cache dict with initial state and precomputed discrete parameters.

Return type:

Dict[str, Any]

step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]]

Single-timestep streaming update for a Centaurus variant.

This method performs one recurrent update of the Centaurus block using the cached discrete-time parameters in the (B, N, M) layout.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, H_in) - the current timestep input.

  • inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache().

Returns:

A tuple containing:
  • y : Real-valued output tensor of shape (B, H_out).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]

class LRU

Bases: LTI_LRNN

Linear Recurrent Unit (LRU) layer.

Paper: https://arxiv.org/abs/2303.06349

Example

>>> model = LRU(d_model=64, d_state=64)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
__init__(d_model: int, d_state: int, r_min: float = 0, r_max: float = 1, max_phase: float = 6.283185307179586) → None

Initialize LRU layer.

Parameters:
  • d_model (int) – Model dimension.

  • d_state (int) – State dimension.

  • r_min (float, optional) – Minimum radius for Lambda initialization. Defaults to 0.

  • r_max (float, optional) – Maximum radius for Lambda initialization. Defaults to 1.

  • max_phase (float, optional) – Maximum phase for Lambda initialization. Defaults to 2 * math.pi.
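
The r_min, r_max and max_phase arguments bound the region from which the LRU paper samples its diagonal eigenvalues: |λ|² uniform between r_min² and r_max², and phase uniform in [0, max_phase]. The sketch below reproduces that recipe in plain Python. It follows the paper, not the lrnnx source, and builds Lambda directly, whereas the paper trains log ν and log θ and forms λ = exp(-exp(log ν) + i exp(log θ)).

```python
import cmath
import math
import random
from typing import List


def init_lru_lambda(d_state: int, r_min: float = 0.0, r_max: float = 1.0,
                    max_phase: float = 2 * math.pi, seed: int = 0) -> List[complex]:
    """Sample d_state eigenvalues with r_min <= |lambda| <= r_max and phase in [0, max_phase]."""
    rng = random.Random(seed)
    lambdas = []
    for _ in range(d_state):
        u1 = 1.0 - rng.random()  # in (0, 1], avoids log(0) when r_min == 0
        u2 = rng.random()
        # |lambda|^2 = u1 * (r_max^2 - r_min^2) + r_min^2, so nu = -log|lambda| >= 0.
        nu = -0.5 * math.log(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2)
        theta = max_phase * u2
        lambdas.append(cmath.exp(complex(-nu, theta)))  # lambda = exp(-nu + i * theta)
    return lambdas


lambdas = init_lru_lambda(8, r_min=0.4, r_max=0.9, max_phase=math.pi / 10)
```

Sampling |λ|² rather than |λ| uniformly spreads the eigenvalues evenly by area over the ring, and a small max_phase keeps the initial oscillation frequencies low.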

discretize() → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

LRU uses no_discretization, so this method does not discretize anything; it only prepares the A, B and C matrices.

Returns:

A tuple containing:
  • A : Diagonal matrix of Lambda values, shape (N, N).

  • B : Complex input projection matrix, shape (N, H).

  • C : Complex output projection matrix, shape (H, N).

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]

compute_kernel(L: int, Lambda: torch.Tensor, B_complex: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

Computes Lambda and the normalized B matrix for LRU.

Parameters:
  • L (int) – Length of the input sequence.

  • Lambda (torch.Tensor) – Complex eigenvalues/diagonal elements, shape (N,).

  • B_complex (torch.Tensor) – Complex input projection matrix, shape (N, H).

Returns:

A tuple containing:
  • Lambda : Complex eigenvalues/diagonal elements, shape (N,).

  • B_norm : Normalized complex input projection matrix, shape (N, H).

Return type:

tuple[torch.Tensor, torch.Tensor]
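
The LRU paper normalizes each row of B by γ_n = sqrt(1 - |λ_n|²), so that a unit-variance white-noise input yields a stationary state variance that does not blow up as |λ_n| approaches 1. The sketch below assumes B_norm is computed this way (lrnnx may, for example, keep γ as a separate log-parameter) and checks numerically that Σ_k |γ λ^k|² = γ² / (1 - |λ|²) = 1.

```python
import math
from typing import List


def normalize_b(lambdas: List[complex], b_complex: List[List[complex]]) -> List[List[complex]]:
    """Scale row n of B (shape (N, H)) by gamma_n = sqrt(1 - |lambda_n|^2)."""
    gammas = [math.sqrt(1.0 - abs(lam) ** 2) for lam in lambdas]
    return [[g * b for b in row] for g, row in zip(gammas, b_complex)]


def stationary_gain(lam: complex, gamma: float, terms: int = 2000) -> float:
    """sum_k |gamma * lam^k|^2: state variance induced by unit-variance white noise."""
    return sum(abs(gamma * lam ** k) ** 2 for k in range(terms))


lambdas = [0.6 + 0.0j, 0.8j]            # |lambda| = 0.6 and 0.8
b = [[1 + 1j, 2 + 0j], [1j, -1 + 0j]]   # shape (N, H) = (2, 2)
b_norm = normalize_b(lambdas, b)         # rows scaled by 0.8 and 0.6
```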

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor

Forward pass of the LRU layer.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, L, H).

  • integration_timesteps (torch.Tensor, optional) – Not yet implemented. Defaults to None.

  • lengths (torch.Tensor, optional) – Not yet implemented. Defaults to None.

Returns:

Output tensor of shape (B, L, H).

Return type:

torch.Tensor

step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]]

Single-step inference for the LRU layer.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, H) - single timestep.

  • inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and pre-computed matrices.

Returns:

A tuple containing:
  • y : Output tensor of shape (B, H).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]
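
A framework-free sketch of one LRU step for a batch, assuming the cache holds the complex diagonal Lambda, the normalized B and C under hypothetical keys next to the documented “lrnn_state”. The LRU paper's readout is y_t = Re(C h_t) + D x_t; the D skip term is left out here because these docs do not say whether lrnnx applies it inside step().

```python
from typing import Any, Dict, List, Tuple


def lru_step(x: List[List[float]], cache: Dict[str, Any]) -> Tuple[List[List[float]], Dict[str, Any]]:
    """One update per batch element: h <- Lambda * h + B_norm @ x_t, y_t = Re(C @ h).

    x has shape (B, H); cache["lrnn_state"] has shape (B, N) and holds complex values.
    "Lambda", "B_norm" and "C" are hypothetical key names for the precomputed matrices.
    """
    lam, b_norm, c = cache["Lambda"], cache["B_norm"], cache["C"]
    n_state, ys = len(lam), []
    for i, x_i in enumerate(x):
        h_prev = cache["lrnn_state"][i]
        h = [lam[n] * h_prev[n] + sum(b_norm[n][j] * x_i[j] for j in range(len(x_i)))
             for n in range(n_state)]
        cache["lrnn_state"][i] = h  # update the cached state in place
        # Only the real part of the complex readout is returned.
        ys.append([sum(c[k][n] * h[n] for n in range(n_state)).real for k in range(len(c))])
    return ys, cache


# Toy system: N = 1, H = 1, batch of 1, eigenvalue 0.5j (a damped quarter-turn per step).
cache = {"lrnn_state": [[0j]], "Lambda": [0.5j], "B_norm": [[1.0 + 0j]], "C": [[1.0 + 0j]]}
outputs = []
for x_t in ([[1.0]], [[0.0]], [[0.0]]):
    y_t, cache = lru_step(x_t, cache)
    outputs.append(y_t[0][0])
print(outputs)  # [1.0, 0.0, -0.25]
```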

allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → Dict[str, Any]

Allocate initial state and cached matrices for inference.

Parameters:
  • batch_size (int) – Batch size.

  • max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.

  • dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.

Returns:

Cache dict with “lrnn_state” and pre-computed discrete matrices.

Return type:

Dict[str, Any]

class S4

Bases: LTI_LRNN

General block design wrapping an inner layer. Currently only layer=FFTConv is supported, but other layers are easy to incorporate.

Other options are all experimental and should not need to be configured.

Example

>>> model = S4(d_model=64, d_state=64, l_max=1024, transposed=False)
>>> x = torch.randn(2, 1024, 64)
>>> y, _ = model(x)
>>> y.shape
torch.Size([2, 1024, 64])
__init__(d_model, bottleneck=None, gate=None, final_act='glu', postact=None, dropout=0.0, tie_dropout=False, transposed=True, l_max=None, channels=1, d_state=64, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', dt_fast=False, rank=1, n_ssm=None, init='legs', deterministic=False, real_transform='exp', imag_transform='none', is_real=False, lr=None, wd=0.0, verbose=True)

Initialize S4 block.

Parameters:
  • d_model (int) – Model dimension.

  • bottleneck (int, optional) – Reduce dimension of inner layer (e.g. used in GSS). Defaults to None.

  • gate (int, optional) – Add multiplicative gating (e.g. used in GSS). Defaults to None.

  • final_act (str, optional) – Activation after final linear layer. 'id' for no activation, None for no linear layer at all. Defaults to "glu".

  • postact (str, optional) – Deprecated, use final_act. Defaults to None.

  • dropout (float, optional) – Standard dropout argument. Defaults to 0.0.

  • tie_dropout (bool, optional) – Tie dropout mask across sequence length, emulating nn.Dropout1d. Defaults to False.

  • transposed (bool, optional) – Backbone axis ordering (B, L, H) (False) or (B, H, L) (True). Defaults to True.

  • l_max (int, optional) – Maximum sequence length for the kernel. Defaults to None.

  • channels (int, optional) – Number of channels/heads. Defaults to 1.

  • d_state (int, optional) – State dimension (N). Defaults to 64.

  • dt_min (float, optional) – Minimum value for dt initialization. Defaults to 0.001.

  • dt_max (float, optional) – Maximum value for dt initialization. Defaults to 0.1.

  • dt_tie (bool, optional) – Tie dt across channels. Defaults to True.

  • dt_transform (str, optional) – Transformation to apply to dt. Defaults to "exp".

  • dt_fast (bool, optional) – Fast dt initialization. Defaults to False.

  • rank (int, optional) – Rank of the low-rank correction for DPLR. Defaults to 1.

  • n_ssm (int, optional) – Number of independent SSMs. Defaults to None.

  • init (str, optional) – Initialization method for the A matrix (e.g., "legs"). Defaults to "legs".

  • deterministic (bool, optional) – Use deterministic initialization. Defaults to False.

  • real_transform (str, optional) – Transformation for the real part of A. Defaults to "exp".

  • imag_transform (str, optional) – Transformation for the imaginary part of A. Defaults to "none".

  • is_real (bool, optional) – Whether to use real-valued SSMs. Defaults to False.

  • lr (float, optional) – Specific learning rate for SSM parameters. Defaults to None.

  • wd (float, optional) – Specific weight decay for SSM parameters. Defaults to 0.0.

  • verbose (bool, optional) – Print initialization information. Defaults to True.
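
dt_min and dt_max bound the initial step sizes. The reference S4 code draws log(dt) uniformly between log(dt_min) and log(dt_max), so every decade of step sizes is equally represented; the sketch below reproduces that scheme in plain Python. Treating it as lrnnx's exact behaviour (including how dt_tie and dt_fast modify it) is an assumption.

```python
import math
import random
from typing import List


def init_dt(n: int, dt_min: float = 0.001, dt_max: float = 0.1, seed: int = 0) -> List[float]:
    """Draw n step sizes log-uniformly in [dt_min, dt_max)."""
    rng = random.Random(seed)
    lo, hi = math.log(dt_min), math.log(dt_max)
    # The trainable parameter would be log(dt); dt_transform="exp" maps it back to dt > 0.
    log_dt = [lo + (hi - lo) * rng.random() for _ in range(n)]
    return [math.exp(v) for v in log_dt]


dts = init_dt(64)  # one step size per channel with the documented defaults
```

Sampling in log space lets a single layer start with both short and long memory timescales across its channels.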

forward(x, lengths=None)

Forward pass of the S4 block.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, H, L) if self.transposed else (B, L, H).

  • lengths (torch.Tensor | int, optional) – Lengths of the sequences in the batch for padding masking. Defaults to None.

Returns:

A tuple containing:
  • y : Output tensor of the same shape as x.

  • state : The next recurrent state, or None.

Return type:

tuple[torch.Tensor, torch.Tensor | None]

step(x: torch.Tensor, inference_cache: dict) → tuple

Perform a single recurrent step of the S4 model.

Parameters:
  • x (torch.Tensor) – Input at current timestep, shape (B, H).

  • inference_cache (dict) – Cache from allocate_inference_cache().

Returns:

A tuple containing:
  • y_t : Output tensor at the current timestep of shape (B, H).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, dict]

allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → dict

Allocate cache for step-by-step inference.

Calls setup_step() to prepare discrete-time matrices (dA, dB, dC), then creates a zero-initialised hidden state.

Parameters:
  • batch_size (int) – Batch size for inference.

  • max_seqlen (int, optional) – Unused, kept for interface consistency. Defaults to 1.

  • dtype (torch.dtype, optional) – Unused. Defaults to None.

Returns:

Cache dict with “lrnn_state” key.

Return type:

dict

default_state(*batch_shape, device=None)
property d_output
class S4D

Bases: LTI_LRNN

General block design wrapping an inner layer. Currently only layer=FFTConv is supported, but other layers are easy to incorporate.

Other options are all experimental and should not need to be configured.

Example

>>> model = S4D(d_model=64, d_state=64, l_max=1024, transposed=False)
>>> x = torch.randn(2, 1024, 64)
>>> y, _ = model(x)
>>> y.shape
torch.Size([2, 1024, 64])
__init__(d_model, bottleneck=None, gate=None, final_act='glu', postact=None, dropout=0.0, tie_dropout=False, transposed=True, l_max=None, channels=1, d_state=64, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', dt_fast=False, rank=1, n_ssm=None, init='legs', deterministic=False, real_transform='exp', imag_transform='none', is_real=False, lr=None, wd=0.0, verbose=True, disc='zoh')

Initialize S4D block.

Parameters:
  • d_model (int) – Model dimension.

  • bottleneck (int, optional) – Reduce dimension of inner layer (e.g. used in GSS). Defaults to None.

  • gate (int, optional) – Add multiplicative gating (e.g. used in GSS). Defaults to None.

  • final_act (str, optional) – Activation after final linear layer. 'id' for no activation, None for no linear layer at all. Defaults to "glu".

  • postact (str, optional) – Deprecated, use final_act. Defaults to None.

  • dropout (float, optional) – Standard dropout argument. Defaults to 0.0.

  • tie_dropout (bool, optional) – Tie dropout mask across sequence length, emulating nn.Dropout1d. Defaults to False.

  • transposed (bool, optional) – Backbone axis ordering (B, L, H) (False) or (B, H, L) (True). Defaults to True.

  • l_max (int, optional) – Maximum sequence length for the kernel. Defaults to None.

  • channels (int, optional) – Number of channels/heads. Defaults to 1.

  • d_state (int, optional) – State dimension (N). Defaults to 64.

  • dt_min (float, optional) – Minimum value for dt initialization. Defaults to 0.001.

  • dt_max (float, optional) – Maximum value for dt initialization. Defaults to 0.1.

  • dt_tie (bool, optional) – Tie dt across channels. Defaults to True.

  • dt_transform (str, optional) – Transformation to apply to dt. Defaults to "exp".

  • dt_fast (bool, optional) – Fast dt initialization. Defaults to False.

  • rank (int, optional) – Rank of the low-rank correction for DPLR. Defaults to 1.

  • n_ssm (int, optional) – Number of independent SSMs. Defaults to None.

  • init (str, optional) – Initialization method for the A matrix (e.g., "legs"). Defaults to "legs".

  • deterministic (bool, optional) – Use deterministic initialization. Defaults to False.

  • real_transform (str, optional) – Transformation for the real part of A. Defaults to "exp".

  • imag_transform (str, optional) – Transformation for the imaginary part of A. Defaults to "none".

  • is_real (bool, optional) – Whether to use real-valued SSMs. Defaults to False.

  • lr (float, optional) – Specific learning rate for SSM parameters. Defaults to None.

  • wd (float, optional) – Specific weight decay for SSM parameters. Defaults to 0.0.

  • verbose (bool, optional) – Print initialization information. Defaults to True.

  • disc (str, optional) – S4D-specific discretization method. Defaults to "zoh".

forward(x, lengths=None)

Forward pass of the S4D block.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, H, L) if self.transposed else (B, L, H).

  • lengths (torch.Tensor | int, optional) – Lengths of the sequences in the batch for padding masking. Defaults to None.

Returns:

A tuple containing:
  • y : Output tensor of the same shape as x.

  • state : The next recurrent state, or None.

Return type:

tuple[torch.Tensor, torch.Tensor | None]

step(x: torch.Tensor, inference_cache: dict) → tuple

Perform a single recurrent step of the S4D model.

Parameters:
  • x (torch.Tensor) – Input at current timestep, shape (B, H).

  • inference_cache (dict) – Cache from allocate_inference_cache().

Returns:

A tuple containing:
  • y_t : Output tensor at the current timestep of shape (B, H).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, dict]

allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → dict

Allocate cache for step-by-step inference.

Calls setup_step() to prepare discrete-time matrices (dA, dB, dC), then creates a zero-initialised hidden state.

Parameters:
  • batch_size (int) – Batch size for inference.

  • max_seqlen (int, optional) – Unused, kept for interface consistency. Defaults to 1.

  • dtype (torch.dtype, optional) – Unused. Defaults to None.

Returns:

Cache dict with “lrnn_state” key.

Return type:

dict

default_state(*batch_shape, device=None)
property d_output
class S5

Bases: LTI_LRNN

Basic S5 State Space Model. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks

Example

>>> model = S5(d_model=64, d_state=64, discretization="zoh")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
__init__(d_model: int, d_state: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'no_discretization'], conj_sym: bool = False)

Initialize S5 model.

Parameters:
  • d_model (int) – Model dimension.

  • d_state (int) – State dimension (P in the original paper).

  • discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]) – Discretization method to use.

  • conj_sym (bool, optional) – If True, uses conjugate symmetry for the state space model. Defaults to False.
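
With conj_sym=True, the S5 paper keeps only one eigenvalue from each complex-conjugate pair, halving the state that is stored and updated. This works because, for a real input, the states of paired modes and the matching columns of C are complex conjugates, so each pair contributes z + conj(z) = 2 Re(z) to the output. The numeric check below is a framework-free sketch of that identity, not lrnnx's code; the values are made up.

```python
import math
from typing import List


def readout_full(c: List[complex], h: List[complex]) -> complex:
    """y = C h summed over all P modes, conjugate partners included."""
    return sum(c_n * h_n for c_n, h_n in zip(c, h))


def readout_conj_sym(c_half: List[complex], h_half: List[complex]) -> float:
    """y = 2 * Re(C~ h~) using only one mode from each conjugate pair (P/2 modes)."""
    return 2.0 * sum(c_n * h_n for c_n, h_n in zip(c_half, h_half)).real


c_half = [1 + 2j, -0.5 + 0.3j]
h_half = [0.4 - 0.1j, 2 + 1j]
# The discarded partners are the complex conjugates of the kept modes.
c_full = c_half + [z.conjugate() for z in c_half]
h_full = h_half + [z.conjugate() for z in h_half]

y_full = readout_full(c_full, h_full)
y_half = readout_conj_sym(c_half, h_half)
print(math.isclose(y_full.real, y_half), abs(y_full.imag) < 1e-12)  # True True
```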

discretize() → tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]

Discretizes the continuous-time system matrices A and B using the specified discretization method.

Returns:

A tuple containing:
  • A_bar : Discretized system matrix A, shape (N,).

  • gamma_bar : Input normalizer, shape (N,) or a float.

  • C_complex : Complex output matrix C, shape (H, N).

Return type:

tuple[torch.Tensor, Union[torch.Tensor, float], torch.Tensor]
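
For a diagonal S5 system each eigenvalue λ_n is discretized independently, and gamma_bar is the per-state factor that turns B into B_bar. The sketch below writes out the standard zero-order-hold and bilinear formulas from the S5 paper for a single pole; the dirac and no_discretization options are not covered, and lrnnx may arrange the computation differently.

```python
import cmath
from typing import Tuple


def discretize_zoh(lam: complex, delta: float) -> Tuple[complex, complex]:
    """Zero-order hold: A_bar = exp(delta*lam), gamma_bar = (A_bar - 1) / lam."""
    a_bar = cmath.exp(delta * lam)
    return a_bar, (a_bar - 1.0) / lam


def discretize_bilinear(lam: complex, delta: float) -> Tuple[complex, complex]:
    """Bilinear: A_bar = (1 + delta*lam/2) / (1 - delta*lam/2), gamma_bar = delta / (1 - delta*lam/2)."""
    denom = 1.0 - delta * lam / 2.0
    return (1.0 + delta * lam / 2.0) / denom, delta / denom


lam, delta = -0.5 + 2.0j, 1e-3  # hypothetical stable pole and step size
a_zoh, g_zoh = discretize_zoh(lam, delta)
a_bil, g_bil = discretize_bilinear(lam, delta)
# B_bar would then be gamma_bar times the corresponding row of B (one factor per state).
```

For small delta both methods agree to third order in delta*lam, and both map a pole with negative real part to |A_bar| < 1, so stability is preserved.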

compute_kernel(L: int, A_bar: torch.Tensor, gamma_bar: torch.Tensor | float)

Computes the kernel matrices for the S5 model: A^t and B_bar.

Parameters:
  • L (int) – Length of the input sequence.

  • A_bar (torch.Tensor) – Discretized system matrix A, shape (N,).

  • gamma_bar (Union[torch.Tensor, float]) – Input normalizer, shape (N,) or a float.

Returns:

A tuple containing:
  • A_power : Power of the discretized system matrix A, shape (N, L).

  • B_bar : Normalized input projection matrix, shape (N, H).

Return type:

tuple[torch.Tensor, torch.Tensor]

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor

Forward pass of the S5 SSM using FFT-based convolution.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, L, H).

  • integration_timesteps (torch.Tensor, optional) – Not used by S5 (LTI model). Kept for interface compatibility with LTV models. Defaults to None.

  • lengths (torch.Tensor, optional) – Lengths of the input sequences, shape (B,). Bidirectional models are not yet supported. Defaults to None.

Returns:

Output tensor of shape (B, L, H).

Return type:

torch.Tensor

step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]]

Performs a single recurrent step of the S5 model.

Parameters:
  • x (torch.Tensor) – Input at current time step, shape (B, H).

  • inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and pre-computed matrices.

Returns:

Output y_t of shape (B, H)

and updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]

allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → Dict[str, Any]

Allocates cache for inference.

Parameters:
  • batch_size (int) – The batch size for the input data.

  • max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.

  • dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.

Returns:

Cache dict with “lrnn_state” and pre-computed discrete matrices.

Return type:

Dict[str, Any]

Submodules