lrnnx.models.lti.s4d module

Adapted from the original S4 implementation (https://github.com/state-spaces/s4) and modified to fit into the LRNNX framework.

class S4D

Bases: LTI_LRNN

General block design wrapping an inner layer. Currently only layer=FFTConv is supported, but other inner layers would be easy to incorporate.

Other options are all experimental and should not need to be configured.

Example

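>>> import torch
>>> from lrnnx.models.lti.s4d import S4D
>>> # transposed=False selects the (B, L, H) layout used for x below
>>> model = S4D(d_model=64, d_state=64, l_max=1024, transposed=False)
>>> x = torch.randn(2, 1024, 64)
>>> y, state = model(x)  # forward returns (output, next state)
>>> y.shape
torch.Size([2, 1024, 64])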
__init__(d_model, bottleneck=None, gate=None, final_act='glu', postact=None, dropout=0.0, tie_dropout=False, transposed=True, l_max=None, channels=1, d_state=64, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', dt_fast=False, rank=1, n_ssm=None, init='legs', deterministic=False, real_transform='exp', imag_transform='none', is_real=False, lr=None, wd=0.0, verbose=True, disc='zoh')

Initialize S4D block.

Parameters:
  • d_model (int) – Model dimension.

  • bottleneck (int, optional) – Reduce dimension of inner layer (e.g. used in GSS). Defaults to None.

  • gate (int, optional) – Add multiplicative gating (e.g. used in GSS). Defaults to None.

  • final_act (str, optional) – Activation after final linear layer. 'id' for no activation, None for no linear layer at all. Defaults to "glu".

  • postact (str, optional) – Deprecated, use final_act. Defaults to None.

  • dropout (float, optional) – Standard dropout argument. Defaults to 0.0.

  • tie_dropout (bool, optional) – Tie dropout mask across sequence length, emulating nn.Dropout1d. Defaults to False.

  • transposed (bool, optional) – Axis ordering of the input and output: (B, H, L) if True, (B, L, H) if False. Defaults to True.

  • l_max (int, optional) – Maximum sequence length for the kernel. Defaults to None.

  • channels (int, optional) – Number of channels/heads. Defaults to 1.

  • d_state (int, optional) – State dimension (N). Defaults to 64.

  • dt_min (float, optional) – Minimum value for dt initialization. Defaults to 0.001.

  • dt_max (float, optional) – Maximum value for dt initialization. Defaults to 0.1.

  • dt_tie (bool, optional) – Tie dt across channels. Defaults to True.

  • dt_transform (str, optional) – Transformation to apply to dt. Defaults to "exp".

  • dt_fast (bool, optional) – Fast dt initialization. Defaults to False.

  • rank (int, optional) – Rank of the low-rank correction for DPLR. Defaults to 1.

  • n_ssm (int, optional) – Number of independent SSMs. Defaults to None.

  • init (str, optional) – Initialization method for the A matrix (e.g., "legs"). Defaults to "legs".

  • deterministic (bool, optional) – Use deterministic initialization. Defaults to False.

  • real_transform (str, optional) – Transformation for the real part of A. Defaults to "exp".

  • imag_transform (str, optional) – Transformation for the imaginary part of A. Defaults to "none".

  • is_real (bool, optional) – Whether to use real-valued SSMs. Defaults to False.

  • lr (float, optional) – Specific learning rate for SSM parameters. Defaults to None.

  • wd (float, optional) – Specific weight decay for SSM parameters. Defaults to 0.0.

  • verbose (bool, optional) – Print initialization information. Defaults to True.

  • disc (str, optional) – S4D-specific discretization method. Defaults to "zoh".
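The role of dt_min, dt_max and dt_transform can be illustrated with a minimal sketch. It assumes the scheme used in the original S4 code, where log(dt) is drawn uniformly between log(dt_min) and log(dt_max) and an "exp" transform maps it back to a positive step size. The helper name init_log_dt and the seed argument are hypothetical, and LRNNX may differ in the details.

```python
import math
import random


def init_log_dt(n_channels, dt_min=0.001, dt_max=0.1, seed=0):
    """Sample log(dt) uniformly in [log(dt_min), log(dt_max)], one value per channel.

    Hypothetical helper for illustration; not part of the LRNNX API.
    """
    rng = random.Random(seed)
    lo, hi = math.log(dt_min), math.log(dt_max)
    return [lo + rng.random() * (hi - lo) for _ in range(n_channels)]


log_dt = init_log_dt(4)

# With an "exp" transform, the step size is recovered as exp(log_dt),
# so it is positive by construction and lies between dt_min and dt_max.
dt = [math.exp(v) for v in log_dt]
```

Sampling in log space spreads the initial step sizes evenly across orders of magnitude, so both short and long timescales are represented at initialization.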

forward(x, lengths=None)

Forward pass of the S4D block.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, H, L) if self.transposed else (B, L, H).

  • lengths (torch.Tensor | int, optional) – Lengths of the sequences in the batch for padding masking. Defaults to None.

Returns:

A tuple containing:
  • y : Output tensor of the same shape as x.

  • state : The next recurrent state, or None.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
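As a rough illustration of what a convolutional forward pass computes for one channel, the sketch below builds the kernel of a diagonal SSM and applies it as a direct causal convolution. It is not the LRNNX implementation: an FFT-based layer would compute the same sums faster, and the skip connection, gating and output projection are omitted. The function names, the S4D-Lin style values of A, and the factor of 2 (which follows the original S4 convention for conjugate pairs) are assumptions made for the example.

```python
import cmath
import math


def s4d_kernel(dA, dB, C, length):
    """K[l] = 2 * Re(sum_n C[n] * dB[n] * dA[n]**l) for l = 0..length-1."""
    return [
        2 * sum(c * db * da**l for c, db, da in zip(C, dB, dA)).real
        for l in range(length)
    ]


def causal_conv(kernel, u):
    """y[t] = sum over s <= t of kernel[t - s] * u[s] (direct O(L^2) form)."""
    return [
        sum(kernel[t - s] * u[s] for s in range(t + 1))
        for t in range(len(u))
    ]


# Illustrative diagonal SSM with N = 4 modes, discretized with zero-order hold.
dt = 0.05
A = [complex(-0.5, math.pi * n) for n in range(4)]
dA = [cmath.exp(dt * a) for a in A]
dB = [(da - 1) / a for da, a in zip(dA, A)]  # B = 1 for every mode
C = [complex(1.0, 0.0)] * 4

u = [1.0, 0.5, -0.25, 0.0, 2.0]
K = s4d_kernel(dA, dB, C, len(u))
y = causal_conv(K, u)
```

Because the convolution is causal, y[t] depends only on u[0..t], and the output has the same length as the input.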

step(x: torch.Tensor, inference_cache: dict) → tuple

Perform a single recurrent step of the S4D model.

Parameters:
  • x (torch.Tensor) – Input at current timestep, shape (B, H).

  • inference_cache (dict) – Cache from allocate_inference_cache().

Returns:

A tuple containing:
  • y_t : Output tensor at the current timestep of shape (B, H).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, dict]
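A single recurrent step for a diagonal SSM can be sketched as follows for one channel. This is an assumption-laden illustration, not the LRNNX code: recurrent_step is a hypothetical function, the state is a plain list rather than a cache dictionary, and the factor of 2 on the output follows the original S4 convention. Feeding a unit impulse should reproduce the convolution kernel, which the last lines compute for comparison.

```python
import cmath
import math


def recurrent_step(u_t, state, dA, dB, C):
    """Advance one timestep: x_t = dA * x_{t-1} + dB * u_t, y_t = 2 * Re(C . x_t)."""
    new_state = [da * s + db * u_t for da, db, s in zip(dA, dB, state)]
    y_t = 2 * sum(c * s for c, s in zip(C, new_state)).real
    return y_t, new_state


# Illustrative diagonal SSM with N = 4 modes, discretized with zero-order hold.
dt = 0.05
A = [complex(-0.5, math.pi * n) for n in range(4)]
dA = [cmath.exp(dt * a) for a in A]
dB = [(da - 1) / a for da, a in zip(dA, A)]  # B = 1 for every mode
C = [complex(1.0, 0.0)] * 4

# Drive the recurrence with a unit impulse, starting from a zero state.
state = [0j] * 4
outputs = []
for u_t in [1.0, 0.0, 0.0, 0.0]:
    y_t, state = recurrent_step(u_t, state, dA, dB, C)
    outputs.append(y_t)

# The impulse response equals the convolution kernel K[l] = 2 * Re(sum C * dB * dA**l).
expected = [
    2 * sum(c * db * da**l for c, db, da in zip(C, dB, dA)).real
    for l in range(4)
]
```

The state carried between calls is the only memory the model needs, so the per-step cost is independent of how many timesteps have already been processed.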

allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → dict

Allocate cache for step-by-step inference.

Calls setup_step() to prepare discrete-time matrices (dA, dB, dC), then creates a zero-initialised hidden state.

Parameters:
  • batch_size (int) – Batch size for inference.

  • max_seqlen (int, optional) – Unused, kept for interface consistency. Defaults to 1.

  • dtype (torch.dtype, optional) – Unused. Defaults to None.

Returns:

Cache dict with “lrnn_state” key.

Return type:

dict
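The two stages described for allocate_inference_cache() (discretize, then create a zero state) can be sketched in plain Python. This assumes the standard zero-order-hold formulas for a diagonal state matrix, dA = exp(dt * A) and dB = (dA - 1) / A * B. The functions zoh_discretize and allocate_cache are hypothetical stand-ins, and the nested-list state layout is chosen only for illustration. Only the "lrnn_state" key comes from the documentation above.

```python
import cmath
import math


def zoh_discretize(A, B, dt):
    """Zero-order hold for a diagonal SSM: dA = exp(dt * A), dB = (dA - 1) / A * B."""
    dA = [cmath.exp(dt * a) for a in A]
    dB = [(da - 1) / a * b for da, a, b in zip(dA, A, B)]
    return dA, dB


def allocate_cache(batch_size, d_state):
    """Hypothetical stand-in: a zero-initialised complex state per batch element."""
    return {"lrnn_state": [[0j] * d_state for _ in range(batch_size)]}


# Illustrative continuous-time parameters with N = 4 modes (S4D-Lin style values).
A = [complex(-0.5, math.pi * n) for n in range(4)]
B = [complex(1.0, 0.0)] * 4

dA, dB = zoh_discretize(A, B, dt=0.01)
cache = allocate_cache(batch_size=2, d_state=4)
```

Since every entry of A has a negative real part, each entry of dA has magnitude below 1, which keeps the recurrence stable.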

default_state(*batch_shape, device=None)
property d_output