lrnnx.models.lti package
Linear Time-Invariant (LTI) LRNN models.
- class LTI_LRNN
Bases: LRNN
Base class for all LTI LRNN models.
Note
LTI models do not support async discretization as that requires time-varying dynamics. For async/event-driven models, use LTV models.
Example
>>> from lrnnx.models.lti import LTI_LRNN
>>> my_lrnn = LTI_LRNN("zoh")
>>> # create dummy input tensor and perform forward pass
>>> # in subclass
- __init__(discretization: Literal['zoh', 'bilinear', 'dirac', 'no_discretization'])
Initialize the LTI LRNN base class.
- Parameters:
discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]) – Discretization method to use.
- abstractmethod discretize() -> tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]
This function discretizes the A, B and C matrices, with a learned step-size delta. This could be done inside the compute_kernel method itself, but doing this explicitly outside allows for more flexibility later.
- Returns:
A tuple of tensors representing the discretized A, B, C matrices, ideally of shapes (B, N), (B, N, H) or float, and (B, H, N) respectively.
- Return type:
tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]
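As a rough illustration of what a discretize() implementation computes, the sketch below applies the textbook zero-order-hold and bilinear formulas to a single diagonal mode. It uses plain Python complex numbers instead of tensors, and the function name discretize_diagonal is hypothetical; the formulas actually used by the subclasses are not shown on this page and may differ.

```python
import cmath


def discretize_diagonal(lam: complex, b: complex, delta: float, method: str = "zoh"):
    """Discretize one diagonal mode x' = lam * x + b * u with step size delta.

    Hypothetical helper for illustration, not part of lrnnx.
    """
    if method == "zoh":
        a_bar = cmath.exp(lam * delta)
        # Exact when the input is held constant over each step.
        b_bar = (a_bar - 1.0) / lam * b
    elif method == "bilinear":
        denom = 1.0 - 0.5 * delta * lam
        a_bar = (1.0 + 0.5 * delta * lam) / denom
        b_bar = delta / denom * b
    else:
        raise ValueError(f"unsupported method: {method}")
    return a_bar, b_bar


lam, b, delta = complex(-0.5, 2.0), complex(1.0, 0.0), 0.01
a_zoh, b_zoh = discretize_diagonal(lam, b, delta, "zoh")
a_bil, b_bil = discretize_diagonal(lam, b, delta, "bilinear")
```

For a small step size the two methods agree closely, and a mode with negative real part stays inside the unit circle after discretization.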
- abstractmethod compute_kernel() -> tuple[torch.Tensor, torch.Tensor]
Computes the convolution kernel for efficient parallel processing.
This function is only relevant for LTI models; for LTV models it would materialize a huge tensor in memory at every timestep, which is inefficient. Reference: https://github.com/kunibald413/aTENNuate/blob/15a27dab00d3bf2c27cbbbc3bd41a3d9196dca1e/attenuate/model.py#L30
- Parameters:
*args – Model-specific arguments (e.g., sequence length, discretized matrices). See subclass implementations for details.
- Returns:
- A tuple containing:
K : Powers of A matrix (A^0, A^1, …, A^{L-1}), shape (N, L)
B_norm : Normalized input projection matrix, shape (N, H)
- Return type:
tuple[torch.Tensor, torch.Tensor]
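The kernel of powers described above is what lets an LTI model replace the step-by-step recurrence with one convolution. The following minimal sketch, for a diagonal system with a single scalar input channel and plain Python lists in place of tensors, builds K[n][l] = A_bar[n] ** l and checks that the convolution and the recurrence give the same outputs. All function names are illustrative assumptions.

```python
def kernel_powers(a_bar, length):
    """Return K with K[n][l] = a_bar[n] ** l, i.e. shape (N, L)."""
    return [[a ** l for l in range(length)] for a in a_bar]


def conv_output(u, a_bar, b_bar, c):
    """y[t] = Re(sum_n c[n] * sum_{s<=t} K[n][t-s] * b_bar[n] * u[s])."""
    K = kernel_powers(a_bar, len(u))
    y = []
    for t in range(len(u)):
        acc = 0j
        for n in range(len(a_bar)):
            acc += c[n] * sum(K[n][t - s] * b_bar[n] * u[s] for s in range(t + 1))
        y.append(acc.real)
    return y


def recurrent_output(u, a_bar, b_bar, c):
    """Same system, evaluated one timestep at a time."""
    x = [0j] * len(a_bar)
    y = []
    for u_t in u:
        x = [a * x_n + b * u_t for a, x_n, b in zip(a_bar, x, b_bar)]
        y.append(sum(c_n * x_n for c_n, x_n in zip(c, x)).real)
    return y


a_bar = [0.9 + 0.1j, 0.5 - 0.3j]
b_bar = [1.0 + 0.0j, 0.5 + 0.5j]
c = [0.3 - 0.2j, 1.0 + 0.0j]
u = [1.0, -2.0, 0.5, 3.0]
y_conv = conv_output(u, a_bar, b_bar, c)
y_rec = recurrent_output(u, a_bar, b_bar, c)
```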
- abstractmethod step(x: torch.Tensor, inference_cache: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]
Performs a single recurrent step of the LTI model.
This method is used for autoregressive inference, where inputs are processed one timestep at a time.
- Parameters:
x (torch.Tensor) – Input at current timestep, shape (B, H).
inference_cache (Dict[str, Any]) – Cache dictionary from allocate_inference_cache() containing recurrent state and pre-computed matrices. Updated in-place and returned.
- Returns:
- A tuple containing:
y : Output at current timestep, shape (B, H).
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
- abstractmethod allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype: torch.dtype | None = None) -> Dict[str, Any]
Allocates initial state and caches matrices for efficient inference.
For LTI models, the system matrices (A, B, C) are time-invariant, so they can be pre-computed once and reused for all timesteps during autoregressive generation.
- Parameters:
batch_size (int) – The batch size for inference.
max_seqlen (int, optional) – Maximum sequence length (unused for LTI, kept for interface consistency with LTV models). Defaults to 1.
dtype (torch.dtype, optional) – Data type for allocated tensors. Defaults to None.
- Returns:
Cache dictionary containing initial state and pre-computed matrices for use in step().
- Return type:
Dict[str, Any]
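The interplay between allocate_inference_cache() and step() can be sketched as follows, again with plain Python lists and one scalar input per batch element. Only the "lrnn_state" key appears in this documentation; the other cache keys and the function bodies are assumptions made for illustration.

```python
def allocate_inference_cache(batch_size, a_bar, b_bar, c):
    """Zero state per batch element plus the time-invariant matrices, stored once."""
    return {
        "lrnn_state": [[0j] * len(a_bar) for _ in range(batch_size)],
        "A_bar": a_bar,  # hypothetical key names for the cached matrices
        "B_bar": b_bar,
        "C": c,
    }


def step(x, cache):
    """x holds one scalar input per batch element; the cache is updated in place."""
    a_bar, b_bar, c = cache["A_bar"], cache["B_bar"], cache["C"]
    y = []
    for i, u_t in enumerate(x):
        state = [a * s + b * u_t for a, s, b in zip(a_bar, cache["lrnn_state"][i], b_bar)]
        cache["lrnn_state"][i] = state
        y.append(sum(c_n * s for c_n, s in zip(c, state)).real)
    return y, cache


cache = allocate_inference_cache(2, a_bar=[0.5 + 0j], b_bar=[1 + 0j], c=[2 + 0j])
outputs = []
for x_t in [[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]]:  # three timesteps, batch of two
    y_t, cache = step(x_t, cache)
    outputs.append(y_t)
```

Because the matrices never change between timesteps, they are stored once at allocation time and only the state entry is rewritten on each call.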
- class Centaurus(d_model: int, d_state: int, sub_state_dim: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'async'] = 'zoh', mode: Literal['neck', 'pointwise', 'pw', 's5', 'dws', 'full'] = 'neck')
Bases: object
Backwards-compatible wrapper that returns a mode-specific class instance.
Example
>>> model = Centaurus(d_model=64, d_state=64, sub_state_dim=8, mode="neck")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
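One common way to build a wrapper that returns a mode-specific instance is to dispatch inside __new__. The sketch below uses stand-in classes; the mapping from mode strings to classes (in particular treating "pointwise", "pw" and "s5" as aliases) is an assumption inferred from the class descriptions on this page, not taken from the library source.

```python
class CentaurusBaseSketch:
    """Stand-in for the real base class; stores the constructor arguments only."""

    def __init__(self, d_model, d_state, sub_state_dim):
        self.d_model = d_model
        self.d_state = d_state
        self.sub_state_dim = sub_state_dim


class NeckSketch(CentaurusBaseSketch): pass
class PWNeckSketch(CentaurusBaseSketch): pass
class DWSSketch(CentaurusBaseSketch): pass
class FullSketch(CentaurusBaseSketch): pass


_MODE_TO_CLASS = {
    "neck": NeckSketch,
    "pointwise": PWNeckSketch,  # assumed aliases for the pointwise variant
    "pw": PWNeckSketch,
    "s5": PWNeckSketch,
    "dws": DWSSketch,
    "full": FullSketch,
}


class CentaurusSketch:
    def __new__(cls, d_model, d_state, sub_state_dim, mode="neck"):
        try:
            target = _MODE_TO_CLASS[mode]
        except KeyError:
            raise ValueError(f"unknown mode: {mode!r}") from None
        # Returning an instance of another class skips CentaurusSketch.__init__.
        return target(d_model, d_state, sub_state_dim)


model = CentaurusSketch(64, 64, 8, mode="dws")
```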
- class CentaurusDWS
Bases: CentaurusBase
Depthwise-separable block with one state per channel.
Example
>>> model = CentaurusDWS(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- class CentaurusFull
Bases: CentaurusBase
Fully connected block with a state per (in, out) pair.
Example
>>> model = CentaurusFull(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- class CentaurusNeck
Bases: CentaurusBase
Bottleneck block with dense in/out projections.
Example
>>> model = CentaurusNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- class CentaurusPWNeck
Bases: CentaurusBase
Pointwise bottleneck (s5 in public implementations) that flattens (N, M) -> (N*M).
This variant removes E-mixing and repeats delta over M sub-states per state, yielding independent SISO lanes aggregated in a single flattened axis.
Example
>>> model = CentaurusPWNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
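The "repeat delta over M sub-states and flatten" idea can be pictured with nested lists. This is only a sketch under assumptions: the helper name is hypothetical, and the row-major flattening order is a guess rather than something this page specifies.

```python
def repeat_and_flatten(delta, lam):
    """delta: length-N list; lam: N x M nested list. Returns two length N*M lists."""
    flat_delta, flat_lam = [], []
    for d_n, row in zip(delta, lam):
        for lam_nm in row:
            flat_delta.append(d_n)  # every sub-state of state n shares its step size
            flat_lam.append(lam_nm)
    return flat_delta, flat_lam


delta = [0.1, 0.2]                # N = 2 states
lam = [[1, 2, 3], [4, 5, 6]]      # M = 3 sub-states per state
flat_delta, flat_lam = repeat_and_flatten(delta, lam)
```

After flattening, each of the N*M entries can be treated as its own independent single-input, single-output lane.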
- compute_kernel() -> tuple[torch.Tensor, torch.Tensor]
Computes the discrete-time latent convolution kernel with intra-state mode mixing using the shared Centaurus formulation.
- Returns:
- A tuple containing:
k : Latent kernel of shape (N, L), where N is the number of state channels.
empty : Placeholder for compatibility with standard LTI interface expectations.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) -> torch.Tensor
Forward pass through a Centaurus LTI mode variant.
- Parameters:
x (torch.Tensor) – Input sequence of shape (B, L, H_in).
integration_timesteps (torch.Tensor, optional) – Placeholder for async models. Not used in the current implementation. Defaults to None.
lengths (torch.Tensor, optional) – Placeholder for future bidirectional models. Not used in the current implementation. Defaults to None.
- Returns:
Output sequence of shape (B, L, H_out), where H_out is the output channel dimension.
- Return type:
torch.Tensor
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1) -> Dict[str, Any]
Allocate initial streaming state and cache matrices.
- step(x: torch.Tensor, inference_cache: Dict[str, Any]) -> tuple[torch.Tensor, Dict[str, Any]]
Single-timestep streaming update for a Centaurus variant.
This method performs one recurrent update of the Centaurus block using the cached discrete-time parameters in the (B, N, M) layout.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H_in) - the current timestep input.
inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache().
- Returns:
- A tuple containing:
y : Output tensor of shape (B, H_out) (real).
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
- class LRU
Bases: LTI_LRNN
Linear Recurrent Unit (LRU) layer.
Paper: https://arxiv.org/abs/2303.06349
Example
>>> model = LRU(d_model=64, d_state=64)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- __init__(d_model: int, d_state: int, r_min: float = 0, r_max: float = 1, max_phase: float = 6.283185307179586) -> None
Initialize LRU layer.
- Parameters:
d_model (int) – Model dimension.
d_state (int) – State dimension.
r_min (float, optional) – Minimum radius for Lambda initialization. Defaults to 0.
r_max (float, optional) – Maximum radius for Lambda initialization. Defaults to 1.
max_phase (float, optional) – Maximum phase for Lambda initialization. Defaults to 2 * math.pi.
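The roles of r_min, r_max and max_phase can be illustrated with the ring initialization described in the LRU paper: eigenvalues are drawn uniformly over the area of the ring between radius r_min and r_max, with phase in [0, max_phase). The sketch below samples the radius and phase directly with the standard library; the paper stores them in a log-parameterized form, and this library's exact code is not shown here, so treat init_lambda as a hypothetical helper.

```python
import cmath
import math
import random


def init_lambda(d_state, r_min=0.0, r_max=1.0, max_phase=2 * math.pi, seed=0):
    """Sample d_state complex eigenvalues inside the ring r_min <= |z| <= r_max."""
    rng = random.Random(seed)
    lambdas = []
    for _ in range(d_state):
        u1, u2 = rng.random(), rng.random()
        # The square root makes the samples uniform over the ring's area.
        radius = math.sqrt(u1 * (r_max ** 2 - r_min ** 2) + r_min ** 2)
        phase = max_phase * u2
        lambdas.append(cmath.rect(radius, phase))
    return lambdas


lambdas = init_lambda(d_state=8, r_min=0.4, r_max=0.9)
```

Choosing r_min close to r_max concentrates the eigenvalues near a single radius, which controls how quickly the state forgets past inputs.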
- discretize() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
LRU uses no_discretization, so this method only prepares the matrices.
- Returns:
- A tuple containing:
A : Diagonal matrix of Lambda values, shape (N, N).
B : Complex input projection matrix, shape (N, H).
C : Complex output projection matrix, shape (H, N).
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- compute_kernel(L: int, Lambda: torch.Tensor, B_complex: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
Compute Lambda and normalized B matrix for LRU.
- Parameters:
L (int) – Length of the input sequence.
Lambda (torch.Tensor) – Complex eigenvalues/diagonal elements, shape (N,).
B_complex (torch.Tensor) – Complex input projection matrix, shape (N, H).
- Returns:
- A tuple containing:
Lambda : Complex eigenvalues/diagonal elements, shape (N,).
B_norm : Normalized complex input projection matrix, shape (N, H).
- Return type:
tuple[torch.Tensor, torch.Tensor]
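In the LRU paper, the input projection is normalized per state by gamma_n = sqrt(1 - |lambda_n|^2), which keeps the state magnitude roughly independent of how close each eigenvalue is to the unit circle. The sketch below applies that scaling to nested lists; whether compute_kernel uses exactly this factor is an assumption, and normalize_b is a hypothetical name.

```python
import math


def normalize_b(lambdas, b):
    """b is an N x H nested list; row n is scaled by sqrt(1 - |lambda_n|^2)."""
    out = []
    for lam, row in zip(lambdas, b):
        gamma = math.sqrt(1.0 - abs(lam) ** 2)
        out.append([gamma * b_nh for b_nh in row])
    return out


lambdas = [0.6 + 0.0j, 0.0 + 0.8j]
b = [[1 + 0j, 2 + 0j], [1j, 1 + 1j]]
b_norm = normalize_b(lambdas, b)
```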
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) -> torch.Tensor
Forward pass of the LRU layer.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, L, H).
integration_timesteps (torch.Tensor, optional) – <To be implemented>. Defaults to None.
lengths (torch.Tensor, optional) – <To be implemented>. Defaults to None.
- Returns:
Output tensor of shape (B, L, H).
- Return type:
torch.Tensor
- step(x: torch.Tensor, inference_cache: Dict[str, Any]) -> tuple[torch.Tensor, Dict[str, Any]]
Single step inference for LRU layer.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H), a single timestep.
inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and pre-computed matrices.
- Returns:
- A tuple containing:
y : Output tensor of shape (B, H).
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) -> Dict[str, Any]
Allocate initial state and cached matrices for inference.
- Parameters:
batch_size (int) – Batch size.
max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.
dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.
- Returns:
Cache dict with “lrnn_state” and pre-computed discrete matrices.
- Return type:
Dict[str, Any]
- class S4
Bases: LTI_LRNN
General block design wrapping an inner layer. Currently only layer=FFTConv is supported, but others are easy to incorporate.
Other options are all experimental and should not need to be configured.
Example
>>> model = S4(d_model=64, d_state=64, l_max=1024)
>>> x = torch.randn(2, 1024, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 1024, 64])
- __init__(d_model, bottleneck=None, gate=None, final_act='glu', postact=None, dropout=0.0, tie_dropout=False, transposed=True, l_max=None, channels=1, d_state=64, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', dt_fast=False, rank=1, n_ssm=None, init='legs', deterministic=False, real_transform='exp', imag_transform='none', is_real=False, lr=None, wd=0.0, verbose=True)
Initialize S4 block.
- Parameters:
d_model (int) – Model dimension.
bottleneck (int, optional) – Reduce dimension of inner layer (e.g. used in GSS). Defaults to None.
gate (int, optional) – Add multiplicative gating (e.g. used in GSS). Defaults to None.
final_act (str, optional) – Activation after final linear layer. 'id' for no activation, None for no linear layer at all. Defaults to "glu".
postact (str, optional) – Deprecated, use final_act. Defaults to None.
dropout (float, optional) – Standard dropout argument. Defaults to 0.0.
tie_dropout (bool, optional) – Tie dropout mask across sequence length, emulating nn.Dropout1d. Defaults to False.
transposed (bool, optional) – Backbone axis ordering (B, L, H) (False) or (B, H, L) (True). Defaults to True.
l_max (int, optional) – Maximum sequence length for the kernel. Defaults to None.
channels (int, optional) – Number of channels/heads. Defaults to 1.
d_state (int, optional) – State dimension (N). Defaults to 64.
dt_min (float, optional) – Minimum value for dt initialization. Defaults to 0.001.
dt_max (float, optional) – Maximum value for dt initialization. Defaults to 0.1.
dt_tie (bool, optional) – Tie dt across channels. Defaults to True.
dt_transform (str, optional) – Transformation to apply to dt. Defaults to "exp".
dt_fast (bool, optional) – Fast dt initialization. Defaults to False.
rank (int, optional) – Rank of the low-rank correction for DPLR. Defaults to 1.
n_ssm (int, optional) – Number of independent SSMs. Defaults to None.
init (str, optional) – Initialization method for the A matrix (e.g., "legs"). Defaults to "legs".
deterministic (bool, optional) – Use deterministic initialization. Defaults to False.
real_transform (str, optional) – Transformation for the real part of A. Defaults to "exp".
imag_transform (str, optional) – Transformation for the imaginary part of A. Defaults to "none".
is_real (bool, optional) – Whether to use real-valued SSMs. Defaults to False.
lr (float, optional) – Specific learning rate for SSM parameters. Defaults to None.
wd (float, optional) – Specific weight decay for SSM parameters. Defaults to 0.0.
verbose (bool, optional) – Print initialization information. Defaults to True.
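To make dt_min, dt_max and dt_transform more concrete: S4-style models commonly draw the step size log-uniformly between dt_min and dt_max and store its logarithm, so that an "exp" transform recovers a positive dt. The sketch below shows that scheme with the standard library; it is an assumption about how these parameters are used, and init_log_dt is a hypothetical name.

```python
import math
import random


def init_log_dt(n, dt_min=0.001, dt_max=0.1, seed=0):
    """Return n values of log(dt), with dt log-uniform in [dt_min, dt_max]."""
    rng = random.Random(seed)
    log_lo, log_hi = math.log(dt_min), math.log(dt_max)
    return [log_lo + rng.random() * (log_hi - log_lo) for _ in range(n)]


log_dt = init_log_dt(4)
# With an "exp" transform, the stored parameter maps back to a positive step size.
dt = [math.exp(v) for v in log_dt]
```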
- forward(x, lengths=None)
Forward pass of the S4 block.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H, L) if self.transposed else (B, L, H).
lengths (torch.Tensor | int, optional) – Lengths of the sequences in the batch for padding masking. Defaults to None.
- Returns:
- A tuple containing:
y : Output tensor of the same shape as x.
state : The next recurrent state, or None.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
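Padding masking with lengths usually means zeroing every timestep past the true end of each sequence before the convolution is applied. A minimal sketch on nested lists in (B, L, H) layout follows; the helper name and the choice to zero the inputs are assumptions, since the actual masking code is not shown on this page.

```python
def mask_padding(x, lengths):
    """x is a B x L x H nested list; every timestep t >= lengths[b] is zeroed."""
    return [
        [row if t < n else [0.0] * len(row) for t, row in enumerate(seq)]
        for seq, n in zip(x, lengths)
    ]


x = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]  # B=2, L=3, H=1
masked = mask_padding(x, lengths=[3, 1])
```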
- step(x: torch.Tensor, inference_cache: dict) -> tuple
Perform a single recurrent step of the S4 model.
- Parameters:
x (torch.Tensor) – Input at current timestep, shape (B, H).
inference_cache (dict) – Cache from allocate_inference_cache().
- Returns:
- A tuple containing:
y_t : Output tensor at the current timestep of shape (B, H).
inference_cache : Updated cache dictionary.
- Return type:
tuple
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) -> dict
Allocate cache for step-by-step inference.
Calls setup_step() to prepare discrete-time matrices (dA, dB, dC), then creates a zero-initialised hidden state.
- Parameters:
batch_size (int) – Batch size for inference.
max_seqlen (int, optional) – Unused, kept for interface consistency. Defaults to 1.
dtype (torch.dtype, optional) – Unused. Defaults to None.
- Returns:
Cache dict with “lrnn_state” key.
- Return type:
dict
- property d_output
- class S4D
Bases: LTI_LRNN
General block design wrapping an inner layer. Currently only layer=FFTConv is supported, but others are easy to incorporate.
Other options are all experimental and should not need to be configured.
Example
>>> model = S4D(d_model=64, d_state=64, l_max=1024)
>>> x = torch.randn(2, 1024, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 1024, 64])
- __init__(d_model, bottleneck=None, gate=None, final_act='glu', postact=None, dropout=0.0, tie_dropout=False, transposed=True, l_max=None, channels=1, d_state=64, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', dt_fast=False, rank=1, n_ssm=None, init='legs', deterministic=False, real_transform='exp', imag_transform='none', is_real=False, lr=None, wd=0.0, verbose=True, disc='zoh')
Initialize S4D block.
- Parameters:
d_model (int) – Model dimension.
bottleneck (int, optional) – Reduce dimension of inner layer (e.g. used in GSS). Defaults to None.
gate (int, optional) – Add multiplicative gating (e.g. used in GSS). Defaults to None.
final_act (str, optional) – Activation after final linear layer. 'id' for no activation, None for no linear layer at all. Defaults to "glu".
postact (str, optional) – Deprecated, use final_act. Defaults to None.
dropout (float, optional) – Standard dropout argument. Defaults to 0.0.
tie_dropout (bool, optional) – Tie dropout mask across sequence length, emulating nn.Dropout1d. Defaults to False.
transposed (bool, optional) – Backbone axis ordering (B, L, H) (False) or (B, H, L) (True). Defaults to True.
l_max (int, optional) – Maximum sequence length for the kernel. Defaults to None.
channels (int, optional) – Number of channels/heads. Defaults to 1.
d_state (int, optional) – State dimension (N). Defaults to 64.
dt_min (float, optional) – Minimum value for dt initialization. Defaults to 0.001.
dt_max (float, optional) – Maximum value for dt initialization. Defaults to 0.1.
dt_tie (bool, optional) – Tie dt across channels. Defaults to True.
dt_transform (str, optional) – Transformation to apply to dt. Defaults to "exp".
dt_fast (bool, optional) – Fast dt initialization. Defaults to False.
rank (int, optional) – Rank of the low-rank correction for DPLR. Defaults to 1.
n_ssm (int, optional) – Number of independent SSMs. Defaults to None.
init (str, optional) – Initialization method for the A matrix (e.g., "legs"). Defaults to "legs".
deterministic (bool, optional) – Use deterministic initialization. Defaults to False.
real_transform (str, optional) – Transformation for the real part of A. Defaults to "exp".
imag_transform (str, optional) – Transformation for the imaginary part of A. Defaults to "none".
is_real (bool, optional) – Whether to use real-valued SSMs. Defaults to False.
lr (float, optional) – Specific learning rate for SSM parameters. Defaults to None.
wd (float, optional) – Specific weight decay for SSM parameters. Defaults to 0.0.
verbose (bool, optional) – Print initialization information. Defaults to True.
disc (str, optional) – S4D-specific discretization method. Defaults to "zoh".
- forward(x, lengths=None)
Forward pass of the S4D block.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H, L) if self.transposed else (B, L, H).
lengths (torch.Tensor | int, optional) – Lengths of the sequences in the batch for padding masking. Defaults to None.
- Returns:
- A tuple containing:
y : Output tensor of the same shape as x.
state : The next recurrent state, or None.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
- step(x: torch.Tensor, inference_cache: dict) -> tuple
Perform a single recurrent step of the S4D model.
- Parameters:
x (torch.Tensor) – Input at current timestep, shape (B, H).
inference_cache (dict) – Cache from allocate_inference_cache().
- Returns:
- A tuple containing:
y_t : Output tensor at the current timestep of shape (B, H).
inference_cache : Updated cache dictionary.
- Return type:
tuple
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) -> dict
Allocate cache for step-by-step inference.
Calls setup_step() to prepare discrete-time matrices (dA, dB, dC), then creates a zero-initialised hidden state.
- Parameters:
batch_size (int) – Batch size for inference.
max_seqlen (int, optional) – Unused, kept for interface consistency. Defaults to 1.
dtype (torch.dtype, optional) – Unused. Defaults to None.
- Returns:
Cache dict with “lrnn_state” key.
- Return type:
dict
- property d_output
- class S5
Bases: LTI_LRNN
Basic S5 State Space Model. Reference: https://openreview.net/forum?id=Ai8Hw3AXqks
Example
>>> model = S5(d_model=64, d_state=64, discretization="zoh")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- __init__(d_model: int, d_state: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'no_discretization'], conj_sym: bool = False)
Initialize S5 model.
- Parameters:
d_model (int) – Model dimension.
d_state (int) – State dimension (P in the original paper).
discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]) – Discretization method to use.
conj_sym (bool, optional) – If True, uses conjugate symmetry for the state space model. Defaults to False.
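The idea behind conjugate symmetry is that when states and output weights come in complex-conjugate pairs, the output over the full set is real and equals twice the real part of the output over one half, so only half of the states need to be stored. The sketch below checks that identity numerically with plain Python complex numbers; how conj_sym is wired inside S5 is not shown here, and the values are arbitrary.

```python
def output_sum(c, x):
    """Linear readout sum_n c[n] * x[n] over complex states."""
    return sum(c_n * x_n for c_n, x_n in zip(c, x))


c_half = [0.3 - 0.2j, 1.0 + 0.5j]
x_half = [0.7 + 0.1j, -0.4 + 0.9j]

# Full system: the stored half plus its complex conjugates.
c_full = c_half + [v.conjugate() for v in c_half]
x_full = x_half + [v.conjugate() for v in x_half]

y_full = output_sum(c_full, x_full)
y_half = 2 * output_sum(c_half, x_half).real
```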
- discretize() -> tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]
Discretizes the continuous-time system matrices A and B using the specified discretization method.
- Returns:
- A tuple containing:
A_bar : Discretized system matrix A, shape (N,).
gamma_bar : Input normalizer, shape (N,) or a float.
C_complex : Complex output matrix C, shape (H, N).
- Return type:
tuple[torch.Tensor, Union[torch.Tensor, float], torch.Tensor]
- compute_kernel(L: int, A_bar: torch.Tensor, gamma_bar: torch.Tensor | float)
Computes the kernel matrices for the S5 model: A^t and B_bar.
- Parameters:
L (int) – Length of the input sequence.
A_bar (torch.Tensor) – Discretized system matrix A, shape (N,).
gamma_bar (Union[torch.Tensor, float]) – Input normalizer, shape (N,) or a float.
- Returns:
- A tuple containing:
A_power : Powers of the discretized system matrix A, shape (N, L).
B_bar : Normalized input projection matrix, shape (N, H).
- Return type:
tuple[torch.Tensor, torch.Tensor]
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) -> torch.Tensor
Forward pass of the S5 SSM using FFT-based convolution.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, L, H).
integration_timesteps (torch.Tensor, optional) – Not used by S5 (LTI model). Kept for interface compatibility with LTV models. Defaults to None.
lengths (torch.Tensor, optional) – Lengths of the input sequences, shape (B,). TODO: Support bidirectional models. Defaults to None.
- Returns:
Output tensor of shape (B, L, H).
- Return type:
torch.Tensor
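FFT-based convolution computes the same causal convolution as a direct sum, but in O(L log L) time by zero-padding both sequences to at least 2L and multiplying their spectra. The sketch below uses a small hand-written radix-2 FFT with a real-valued kernel so that it runs with the standard library only; a real implementation would use a library FFT, and none of these function names come from lrnnx.

```python
import cmath


def fft(a, invert=False):
    """Recursive radix-2 FFT (unnormalized); len(a) must be a power of two."""
    n = len(a)
    if n == 1:
        return list(a)
    even, odd = fft(a[0::2], invert), fft(a[1::2], invert)
    sign = 1 if invert else -1
    out = [0j] * n
    for k in range(n // 2):
        w = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k], out[k + n // 2] = even[k] + w, even[k] - w
    return out


def fft_causal_conv(kernel, u):
    """y[t] = sum_{s<=t} kernel[t-s] * u[s], computed via zero-padded FFT."""
    L = len(u)
    n = 1
    while n < 2 * L:  # pad so the circular convolution does not wrap around
        n *= 2
    fk = fft(list(kernel) + [0.0] * (n - L))
    fu = fft(list(u) + [0.0] * (n - L))
    prod = fft([p * q for p, q in zip(fk, fu)], invert=True)
    return [v.real / n for v in prod[:L]]


def direct_causal_conv(kernel, u):
    """Reference O(L^2) implementation of the same convolution."""
    return [sum(kernel[t - s] * u[s] for s in range(t + 1)) for t in range(len(u))]


kernel = [0.5 ** l for l in range(5)]
u = [1.0, -2.0, 0.5, 3.0, 0.25]
y_fft = fft_causal_conv(kernel, u)
y_ref = direct_causal_conv(kernel, u)
```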
- step(x: torch.Tensor, inference_cache: Dict[str, Any]) -> tuple[torch.Tensor, Dict[str, Any]]
Performs a single recurrent step of the S5 model.
- Parameters:
x (torch.Tensor) – Input at current time step, shape (B, H).
inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and pre-computed matrices.
- Returns:
Output y_t of shape (B, H) and updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) -> Dict[str, Any]
Allocates cache for inference.
- Parameters:
batch_size (int) – The batch size for the input data.
max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.
dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.
- Returns:
Cache dict with “lrnn_state” and pre-computed discrete matrices.
- Return type:
Dict[str, Any]