lrnnx.models.lti.s4d module
Taken from the original S4 implementation and modified to fit into the LRNNX framework. https://github.com/state-spaces/s4
- class S4D
Bases: LTI_LRNN
General block design wrapping an inner layer. Currently only layer=FFTConv is supported, but other layers are easy to incorporate.
The remaining options are experimental and should not need to be configured.
Example
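>>> import torch
>>> from lrnnx.models.lti.s4d import S4D
>>> model = S4D(d_model=64, d_state=64, l_max=1024, transposed=False)
>>> x = torch.randn(2, 1024, 64)  # (B, L, H) because transposed=False
>>> y, state = model(x)  # forward returns a (y, state) tuple
>>> y.shape
torch.Size([2, 1024, 64])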
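The block design above includes a final linear layer whose activation is set by final_act (default "glu"). The following is a minimal plain-Python sketch of what a GLU output stage typically computes in S4-style blocks. It assumes, without confirmation from this page, that the linear layer maps H features to 2H and that the GLU gates them back down to H. The helpers linear, glu and glu_output_stage are hypothetical and are not part of LRNNX.

```python
import math
import random


def linear(x, weight, bias):
    """Apply y = W @ x + b to one feature vector x (a list of floats)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def glu(v):
    """Gated linear unit: split v into halves (a, g) and return a * sigmoid(g)."""
    half = len(v) // 2
    a, g = v[:half], v[half:]
    return [ai / (1.0 + math.exp(-gi)) for ai, gi in zip(a, g)]


def glu_output_stage(h, weight, bias):
    """Hypothetical output stage for final_act='glu': project H -> 2H, then gate back to H."""
    return glu(linear(h, weight, bias))


H = 4  # model dimension (d_model)
rng = random.Random(0)
W = [[rng.gauss(0.0, 0.5) for _ in range(H)] for _ in range(2 * H)]  # shape (2H, H)
b = [0.0] * (2 * H)
y = glu_output_stage([1.0, -0.5, 0.25, 2.0], W, b)
print(len(y))  # 4: the GLU halves the 2H projection back to H
```

Doubling the projection width is what lets the GLU gate each output feature with its own learned sigmoid while keeping the block's output dimension equal to d_model.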
- __init__(d_model, bottleneck=None, gate=None, final_act='glu', postact=None, dropout=0.0, tie_dropout=False, transposed=True, l_max=None, channels=1, d_state=64, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', dt_fast=False, rank=1, n_ssm=None, init='legs', deterministic=False, real_transform='exp', imag_transform='none', is_real=False, lr=None, wd=0.0, verbose=True, disc='zoh')
Initialize S4D block.
- Parameters:
d_model (int) – Model dimension.
bottleneck (int, optional) – Reduce dimension of inner layer (e.g. used in GSS). Defaults to None.
gate (int, optional) – Add multiplicative gating (e.g. used in GSS). Defaults to None.
final_act (str, optional) – Activation after the final linear layer: 'id' for no activation, None for no linear layer at all. Defaults to "glu".
postact (str, optional) – Deprecated; use final_act instead. Defaults to None.
dropout (float, optional) – Standard dropout argument. Defaults to 0.0.
tie_dropout (bool, optional) – Tie the dropout mask across the sequence length, emulating nn.Dropout1d. Defaults to False.
transposed (bool, optional) – Backbone axis ordering: (B, L, H) if False, (B, H, L) if True. Defaults to True.
l_max (int, optional) – Maximum sequence length for the kernel. Defaults to None.
channels (int, optional) – Number of channels/heads. Defaults to 1.
d_state (int, optional) – State dimension (N). Defaults to 64.
dt_min (float, optional) – Minimum value for dt initialization. Defaults to 0.001.
dt_max (float, optional) – Maximum value for dt initialization. Defaults to 0.1.
dt_tie (bool, optional) – Tie dt across channels. Defaults to True.
dt_transform (str, optional) – Transformation to apply to dt. Defaults to "exp".
dt_fast (bool, optional) – Fast dt initialization. Defaults to False.
rank (int, optional) – Rank of the low-rank correction for DPLR. Defaults to 1.
n_ssm (int, optional) – Number of independent SSMs. Defaults to None.
init (str, optional) – Initialization method for the A matrix (e.g., "legs"). Defaults to "legs".
deterministic (bool, optional) – Use deterministic initialization. Defaults to False.
real_transform (str, optional) – Transformation for the real part of A. Defaults to "exp".
imag_transform (str, optional) – Transformation for the imaginary part of A. Defaults to "none".
is_real (bool, optional) – Whether to use real-valued SSMs. Defaults to False.
lr (float, optional) – Specific learning rate for SSM parameters. Defaults to None.
wd (float, optional) – Specific weight decay for SSM parameters. Defaults to 0.0.
verbose (bool, optional) – Print initialization information. Defaults to True.
disc (str, optional) – S4D-specific discretization method. Defaults to "zoh".
- forward(x, lengths=None)
Forward pass of the S4D block.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H, L) if self.transposed else (B, L, H).
lengths (torch.Tensor | int, optional) – Lengths of the sequences in the batch, used for padding masking. Defaults to None.
- Returns:
- A tuple containing:
y : Output tensor of the same shape as x.
state : The next recurrent state, or None.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
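Over a full sequence, an FFTConv-style layer builds a convolution kernel from the discretized parameters and convolves it with the input. The sketch below shows that view for a single channel, assuming conjugate-symmetric modes so that K[l] = 2 * Re(sum_n C[n] * dA[n]**l * dB[n]). It uses a direct O(L^2) causal convolution for readability instead of an FFT. mask_input is a hypothetical helper showing one plausible reading of padding masking via lengths (zeroing the input past each sequence's true length); how LRNNX applies lengths internally is an assumption here.

```python
import cmath


def ssm_kernel(dA, dB, C, L):
    """K[l] = 2 * Re(sum_n C[n] * dA[n]**l * dB[n]) for l = 0..L-1.

    The factor 2 * Re adds back the conjugate modes that are not stored.
    """
    return [2.0 * sum(c * da**l * db for c, da, db in zip(C, dA, dB)).real for l in range(L)]


def causal_conv(u, K, D):
    """y[t] = sum_{j=0..t} K[j] * u[t-j] + D * u[t].

    Direct O(L^2) form for readability; an FFT-based convolution computes the same result.
    """
    return [sum(K[j] * u[t - j] for j in range(t + 1)) + D * u[t] for t in range(len(u))]


def mask_input(u, length):
    """Hypothetical padding mask: zero input positions at or beyond the true length."""
    return [x if t < length else 0.0 for t, x in enumerate(u)]


dA = [cmath.exp(0.1 * complex(-0.5, 1.0))]  # one complex mode, dt = 0.1
dB = [0.1 + 0.0j]
C = [1.0 + 0.0j]
L = 6
K = ssm_kernel(dA, dB, C, L)

# With D = 0, an impulse input reproduces the kernel itself.
impulse = [1.0] + [0.0] * (L - 1)
y_impulse = causal_conv(impulse, K, D=0.0)

# Causality: y[t] only reads u[0..t], so outputs before the true length ignore padding.
u_a = [1.0, 2.0, 3.0, 9.0, 9.0, 9.0]
u_b = [1.0, 2.0, 3.0, -7.0, 4.0, 0.5]
y_a, y_b = causal_conv(u_a, K, 0.5), causal_conv(u_b, K, 0.5)
print(y_a[:3] == y_b[:3])  # True
# Masking the padding makes the full outputs agree as well.
print(causal_conv(mask_input(u_a, 3), K, 0.5) == causal_conv(mask_input(u_b, 3), K, 0.5))  # True
```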
- step(x: torch.Tensor, inference_cache: dict) → tuple
Perform a single recurrent step of the S4D model.
- Parameters:
x (torch.Tensor) – Input at the current timestep, of shape (B, H).
inference_cache (dict) – Cache returned by allocate_inference_cache().
- Returns:
- A tuple containing:
y_t : Output tensor at the current timestep, of shape (B, H).
inference_cache : Updated cache dictionary.
- Return type:
tuple
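step() is the recurrent counterpart of the full-sequence forward pass. The sketch below works on a single channel with a hypothetical ssm_step helper and assumes the update h <- dA * h + dB * x_t with readout y_t = 2 * Re(C . h) + D * x_t for conjugate-symmetric diagonal modes. It checks that unrolling the step from a zero state reproduces the convolutional output, which is why step-by-step inference and forward() can agree. It is an illustration, not the LRNNX step() itself.

```python
import cmath


def ssm_step(u_t, h, dA, dB, C, D):
    """One recurrent step of a diagonal SSM with conjugate-symmetric modes.

    h <- dA * h + dB * u_t  (elementwise over the stored modes)
    y_t = 2 * Re(sum_n C[n] * h[n]) + D * u_t
    """
    h = [da * hn + db * u_t for da, hn, db in zip(dA, h, dB)]
    y_t = 2.0 * sum(c * hn for c, hn in zip(C, h)).real + D * u_t
    return y_t, h


def run_recurrent(u, dA, dB, C, D):
    """Unroll ssm_step over a sequence, starting from a zero state."""
    h = [0j] * len(dA)
    ys = []
    for u_t in u:
        y_t, h = ssm_step(u_t, h, dA, dB, C, D)
        ys.append(y_t)
    return ys


def run_convolution(u, dA, dB, C, D):
    """Same SSM in convolutional form: y = K * u + D * u with K[l] = 2 Re(C dA^l dB)."""
    L = len(u)
    K = [2.0 * sum(c * da**l * db for c, da, db in zip(C, dA, dB)).real for l in range(L)]
    return [sum(K[j] * u[t - j] for j in range(t + 1)) + D * u[t] for t in range(L)]


dA = [cmath.exp(0.05 * complex(-0.5, 3.0)), cmath.exp(0.05 * complex(-0.5, 6.0))]
dB = [0.05 + 0.0j, 0.05 + 0.0j]
C = [0.3 - 0.2j, -0.1 + 0.4j]
u = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5]
y_rec = run_recurrent(u, dA, dB, C, D=0.7)
y_conv = run_convolution(u, dA, dB, C, D=0.7)
print(max(abs(a - b) for a, b in zip(y_rec, y_conv)) < 1e-9)  # True
```

The recurrent form costs O(N) per timestep regardless of how much history has been seen, which is what makes step() suitable for autoregressive inference.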
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → dict
Allocate cache for step-by-step inference.
Calls setup_step() to prepare the discrete-time matrices (dA, dB, dC), then creates a zero-initialised hidden state.
- Parameters:
batch_size (int) – Batch size for inference.
max_seqlen (int, optional) – Unused, kept for interface consistency. Defaults to 1.
dtype (torch.dtype, optional) – Unused. Defaults to None.
- Returns:
Cache dict with “lrnn_state” key.
- Return type:
dict
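The documented cache flow (allocate a zero state once, then pass the dict through step() one timestep at a time) can be sketched in plain Python. allocate_cache and step_with_cache are hypothetical stand-ins for allocate_inference_cache() and step(). The state shape (B, H, d_state // 2) and the sharing of dA, dB, C, D across all channels are simplifying assumptions, not documented behavior; only the "lrnn_state" key comes from the description above.

```python
def allocate_cache(batch_size, d_model, d_state):
    """Hypothetical stand-in for allocate_inference_cache(): a zero complex state of
    assumed shape (B, H, d_state // 2), stored under the "lrnn_state" key."""
    n_modes = d_state // 2
    state = [[[0j] * n_modes for _ in range(d_model)] for _ in range(batch_size)]
    return {"lrnn_state": state}


def step_with_cache(x_t, cache, dA, dB, C, D):
    """Advance every (batch, channel) state by one timestep.

    x_t has shape (B, H); returns y_t of shape (B, H) and the updated cache.
    For brevity dA, dB, C, D are shared by all channels (a real layer has per-channel values).
    """
    y_t, new_state = [], []
    for x_b, h_b in zip(x_t, cache["lrnn_state"]):
        y_b, s_b = [], []
        for x_bh, h in zip(x_b, h_b):
            h = [da * hn + db * x_bh for da, hn, db in zip(dA, h, dB)]
            y_b.append(2.0 * sum(c * hn for c, hn in zip(C, h)).real + D * x_bh)
            s_b.append(h)
        y_t.append(y_b)
        new_state.append(s_b)
    cache["lrnn_state"] = new_state
    return y_t, cache


batch, H, N = 2, 3, 4
dA, dB, C, D = [0.9 + 0.1j, 0.5 - 0.2j], [0.1 + 0j, 0.2 + 0j], [1.0 + 0j, 0.5 + 0j], 0.0
cache = allocate_cache(batch, H, N)
xs = [[[1.0, 0.0, -1.0], [0.5, 2.0, 0.0]] for _ in range(4)]  # 4 timesteps, each (B, H)
for x_t in xs:  # stream the sequence one timestep at a time
    y_t, cache = step_with_cache(x_t, cache, dA, dB, C, D)
print(len(y_t), len(y_t[0]), len(cache["lrnn_state"][0][0]))  # 2 3 2
```

Because the cache is a plain dict that step() returns updated, a caller can keep one cache per active batch and resume generation from any point without recomputing earlier timesteps.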
- property d_output