lrnnx.ops.s4_kernel_interface module
- class S4KernelBase
Bases: Module
Base class for S4 kernels; receives its parameters from the parent model.
- Parameters:
d_model (int) – Model dimension.
l_max (int | None) – Maximum sequence length.
channels (int) – Number of channels/heads.
param_config (dict) –
A dictionary containing:
Parameter references: A_real, A_imag, B, C, inv_dt, P (nn.Parameters owned by S4/S4D)
Computed scalars: N, H, channels, rank, repeat
Config flags: dt_fast, real_transform, imag_transform, dt_transform, is_real, deterministic, verbose
S4D-only: disc
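The following NumPy sketch illustrates one plausible way the raw parameter references could be turned into a continuous-time diagonal state matrix A and a per-head step size dt. It makes three assumptions:

- an exponential real_transform and dt_transform;
- an identity imag_transform;
- the forward argument rate rescales dt.

The helper materialize_params and all array shapes are hypothetical and not part of lrnnx.

```python
import numpy as np

def materialize_params(A_real, A_imag, inv_dt, rate=1.0):
    """Hypothetical helper: map raw parameters to (A, dt).

    Assumed transforms: exp for the real part of A and for dt, identity
    for the imaginary part of A. Assumed: rate rescales dt.
    """
    A = -np.exp(A_real) + 1j * A_imag   # Re(A) < 0 for any A_real, so modes decay
    dt = np.exp(inv_dt) * rate          # dt > 0 for any inv_dt
    return A, dt

H, N2 = 2, 4                                      # heads, complex modes per head (assumed)
A_real = np.full((H, N2), np.log(0.5))            # gives Re(A) = -0.5
A_imag = np.tile(np.pi * np.arange(N2), (H, 1))   # evenly spaced imaginary parts
inv_dt = np.log(np.array([1e-2, 1e-1]))           # one step size per head
A, dt = materialize_params(A_real, A_imag, inv_dt)
print(A.shape)  # (2, 4)
```

Going through exp keeps Re(A) negative and dt positive, whatever unconstrained values the optimizer produces.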
- class S4Kernel
Bases: S4KernelBase
SSM kernel for diagonal plus low-rank (DPLR) state matrices, implemented as a pure convolution operation.
- forward(state=None, rate=1.0, L=None)
Compute the SSM convolution kernel (the core operation).
- Parameters:
state (torch.Tensor, optional) – State tensor. Defaults to None.
rate (float, optional) – Sampling rate. Defaults to 1.0.
L (int, optional) – Sequence length. Defaults to None.
- Returns:
- A tuple containing:
k_B : Convolution kernel.
k_state : Kernel state if state is provided, otherwise None.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
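At small sizes, a DPLR kernel can be computed naively, which shows what it represents. The sketch below takes these steps:

1. Materialize the state matrix densely.
2. Discretize it.
3. Unroll K[k] = Re(C A_bar^k B_bar).

It rests on several assumptions:

- one head;
- rank 1;
- the symmetric form A = diag(Lambda) - P P^H;
- the bilinear discretization used in the S4 paper.

This is an O(N^3 + L N^2) reference for checking semantics, not the Cauchy/Woodbury algorithm a production S4 kernel would use. The helper dplr_kernel_naive is hypothetical.

```python
import numpy as np

def dplr_kernel_naive(Lambda, P, B, C, dt, L):
    """Slow reference kernel for one DPLR head (hypothetical helper).

    A = diag(Lambda) - P P^H, discretized with the bilinear transform:
        A_bar = (I - dt/2 A)^-1 (I + dt/2 A),  B_bar = (I - dt/2 A)^-1 dt B,
    then K[k] = Re(C A_bar^k B_bar) for k = 0, ..., L-1.
    """
    N = Lambda.shape[0]
    A = np.diag(Lambda) - np.outer(P, P.conj())
    I = np.eye(N)
    inv = np.linalg.inv(I - dt / 2 * A)
    A_bar = inv @ (I + dt / 2 * A)
    B_bar = inv @ (dt * B)
    K = np.empty(L)
    x = B_bar                      # holds A_bar^k B_bar
    for k in range(L):
        K[k] = (C @ x).real
        x = A_bar @ x
    return K, A_bar, B_bar

rng = np.random.default_rng(0)
N, L, dt = 4, 16, 0.1
Lambda = -0.5 + 1j * np.pi * np.arange(N)
P = rng.standard_normal(N) + 1j * rng.standard_normal(N)
B = rng.standard_normal(N) + 0j
C = rng.standard_normal(N) + 1j * rng.standard_normal(N)
K, A_bar, B_bar = dplr_kernel_naive(Lambda, P, B, C, dt, L)

# Check: causal convolution with K reproduces the recurrence x_t = A_bar x_{t-1} + B_bar u_t.
u = rng.standard_normal(L)
y_conv = np.convolve(u, K)[:L]
x = np.zeros(N, dtype=complex)
y_rec = np.empty(L)
for t in range(L):
    x = A_bar @ x + B_bar * u[t]
    y_rec[t] = (C @ x).real
print(np.allclose(y_conv, y_rec))  # True
```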
- double_length()
Double the sequence length representation.
- default_state(*batch_shape)
Create a default state.
- Parameters:
*batch_shape – Variable-length argument list of batch dimensions.
- Returns:
A zero-initialized state tensor.
- Return type:
torch.Tensor
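The shape convention of the default state can be sketched as follows. The sketch assumes a zero tensor with the same trailing (H, N) layout that forward_state expects, and a complex dtype because diagonal and DPLR modes are complex. default_state_sketch and its explicit H and N arguments are hypothetical.

```python
import numpy as np

def default_state_sketch(H, N, *batch_shape):
    """Hypothetical stand-in for default_state(*batch_shape).

    Assumes a zero state laid out as (*batch_shape, H, N) with a complex dtype.
    """
    return np.zeros((*batch_shape, H, N), dtype=np.complex64)

state = default_state_sketch(8, 32, 4)     # one batch dimension of size 4
print(state.shape)                         # (4, 8, 32)
print(default_state_sketch(8, 32).shape)   # (8, 32): no batch dimensions
```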
- step(u, state)
Perform a single recurrent step.
- Parameters:
u (torch.Tensor) – Input tensor.
state (torch.Tensor) – Current state tensor.
- Returns:
- A tuple containing:
y.real : Output tensor.
new_state : Updated state tensor.
- Return type:
tuple[torch.Tensor, torch.Tensor]
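A batched recurrent step is sketched below in NumPy. It assumes the kernel has already been discretized into a dense per-head matrix A_bar and input vector B_bar; storing the DPLR matrix densely is a simplification made for clarity. The output is the real part of C x, matching y.real in the return description. dplr_step_sketch, A_bar and B_bar are hypothetical names.

```python
import numpy as np

def dplr_step_sketch(u, state, A_bar, B_bar, C):
    """One recurrent step for a batch of heads (hypothetical helper).

    u:      (B, H) real input at the current time step
    state:  (B, H, N) complex state
    A_bar:  (H, N, N) discretized state matrix per head
    B_bar:  (H, N), C: (H, N)
    Returns (y, new_state) with y of shape (B, H).
    """
    new_state = np.einsum("hmn,bhn->bhm", A_bar, state) + B_bar * u[..., None]
    y = np.einsum("hn,bhn->bh", C, new_state)
    return y.real, new_state

rng = np.random.default_rng(0)
Bsz, H, N = 2, 3, 4
A_bar = 0.5 * (rng.standard_normal((H, N, N)) + 1j * rng.standard_normal((H, N, N)))
B_bar = rng.standard_normal((H, N)) + 0j
C = rng.standard_normal((H, N)) + 1j * rng.standard_normal((H, N))
state = np.zeros((Bsz, H, N), dtype=complex)
for u in rng.standard_normal((5, Bsz, H)):   # 5 time steps, each of shape (B, H)
    y, state = dplr_step_sketch(u, state, A_bar, B_bar, C)
print(y.shape, state.shape)  # (2, 3) (2, 3, 4)
```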
- forward_state(u, state)
Advance the state through an input sequence.
- Parameters:
u (torch.Tensor) – Input sequence tensor of shape (B, H, L).
state (torch.Tensor) – State tensor of shape (B, H, N).
- Returns:
The updated state tensor.
- Return type:
torch.Tensor
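Advancing the state over a whole sequence is equivalent to applying a single step L times. For a linear recurrence this has the closed form x_L = A_bar^L x_0 + sum_k A_bar^(L-1-k) B_bar u_k. The NumPy sketch below computes that closed form with dense per-head matrices and checks it against a step-by-step loop. forward_state_sketch and the dense A_bar/B_bar inputs are assumptions, not the lrnnx implementation.

```python
import numpy as np

def forward_state_sketch(u, state, A_bar, B_bar):
    """Advance the state over a whole sequence (hypothetical helper).

    u: (B, H, L) real inputs; state: (B, H, N) complex
    A_bar: (H, N, N); B_bar: (H, N)
    Computes x_L = A_bar^L x_0 + sum_k A_bar^(L-1-k) B_bar u_k.
    """
    L = u.shape[-1]
    # AB[j] = A_bar^j B_bar for j = 0..L-1, shape (L, H, N)
    AB = np.empty((L,) + B_bar.shape, dtype=complex)
    AB[0] = B_bar
    for j in range(1, L):
        AB[j] = np.einsum("hmn,hn->hm", A_bar, AB[j - 1])
    # Input contribution: AB[::-1][k] = A_bar^(L-1-k) B_bar, weighted by u_k
    drive = np.einsum("khn,bhk->bhn", AB[::-1], u)
    # Propagation of the initial state: A_bar^L x_0 per head
    A_pow = np.stack([np.linalg.matrix_power(A_bar[h], L) for h in range(A_bar.shape[0])])
    return np.einsum("hmn,bhn->bhm", A_pow, state) + drive

rng = np.random.default_rng(1)
Bsz, H, N, L = 2, 3, 4, 10
A_bar = 0.3 * (rng.standard_normal((H, N, N)) + 1j * rng.standard_normal((H, N, N)))
B_bar = rng.standard_normal((H, N)) + 0j
u = rng.standard_normal((Bsz, H, L))
x0 = rng.standard_normal((Bsz, H, N)) + 0j

x = x0.copy()
for t in range(L):  # reference: one step at a time
    x = np.einsum("hmn,bhn->bhm", A_bar, x) + B_bar * u[..., t, None]
print(np.allclose(forward_state_sketch(u, x0, A_bar, B_bar), x))  # True
```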
- class S4DKernel
Bases: S4KernelBase
SSM kernel using a diagonal state matrix (S4D model), implemented as a pure convolution operation.
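The kernel classes only produce the convolution kernel, so the parent model presumably applies it to the input. A common way to do that is an FFT-based causal convolution, sketched below in NumPy for an input of shape (B, H, L) and a kernel of shape (H, L). fft_causal_conv is a hypothetical helper, and whether lrnnx applies its kernels this way is an assumption.

```python
import numpy as np

def fft_causal_conv(u, K):
    """Causal convolution y[..., t] = sum_{s<=t} K[..., t-s] u[..., s] via FFT.

    u: (B, H, L) input, K: (H, L) kernel (broadcast over the batch).
    Zero-padding to 2L avoids circular wrap-around.
    """
    L = u.shape[-1]
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(K, n=n), n=n)
    return y[..., :L]

rng = np.random.default_rng(0)
u = rng.standard_normal((2, 3, 8))
K = rng.standard_normal((3, 8))
y = fft_causal_conv(u, K)
direct = np.convolve(u[1, 2], K[2])[:8]   # direct check for one (batch, head) pair
print(np.allclose(y[1, 2], direct))  # True
```

The FFT route costs O(L log L) per channel instead of O(L^2) for a direct convolution.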
- forward(L, state=None, rate=1.0)
Compute the SSM convolution kernel (the core operation).
- Parameters:
L (int) – Sequence length.
state (torch.Tensor, optional) – State tensor. Defaults to None.
rate (float, optional) – Sampling rate. Defaults to 1.0.
- Returns:
- A tuple containing:
K : Convolution kernel.
K_state : Kernel state if state is provided, otherwise None.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
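For a diagonal state matrix the kernel reduces to a Vandermonde product. The sketch below relies on two assumptions:

- zero-order-hold (ZOH) discretization, presumably one of the options the S4D-only disc flag selects;
- the common S4D convention of storing only one member of each complex-conjugate pair of modes, which is why the result is 2 Re(...).

s4d_kernel_sketch is a hypothetical helper, and dt stands in for the transformed inv_dt.

```python
import numpy as np

def s4d_kernel_sketch(A, B, C, dt, L):
    """ZOH-discretized S4D convolution kernel for one head (hypothetical helper).

    A, B, C: (N,) complex, one member of each conjugate pair.
    K[l] = 2 * Re(sum_n C_n * dB_n * exp(dt * A_n * l)), with dB = (exp(dt*A) - 1) / A * B.
    The factor 2 accounts for the omitted conjugate modes.
    """
    dtA = dt * A
    dB = (np.exp(dtA) - 1.0) / A * B
    vandermonde = np.exp(dtA[:, None] * np.arange(L))   # (N, L)
    return 2 * ((C * dB) @ vandermonde).real

rng = np.random.default_rng(0)
N, L, dt = 4, 32, 0.05
A = -0.5 + 1j * np.pi * np.arange(N)
B = np.ones(N, dtype=complex)
C = rng.standard_normal(N) + 1j * rng.standard_normal(N)
K = s4d_kernel_sketch(A, B, C, dt, L)

# Sanity check: expand to the full 2N-mode system with explicit conjugates;
# its kernel is real and equals K without any explicit factor of 2.
A2, B2, C2 = (np.concatenate([v, v.conj()]) for v in (A, B, C))
dB2 = (np.exp(dt * A2) - 1.0) / A2 * B2
K_full = (C2 * dB2) @ np.exp(dt * A2[:, None] * np.arange(L))
print(np.allclose(K, K_full.real), np.allclose(K_full.imag, 0))  # True True
```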
- default_state(*batch_shape)
Create a default state.
- Parameters:
*batch_shape – Variable-length argument list of batch dimensions.
- Returns:
A zero-initialized state tensor.
- Return type:
torch.Tensor
- step(u, state)
Perform a single recurrent step.
- Parameters:
u (torch.Tensor) – Input tensor of shape (B, H).
state (torch.Tensor) – Current state tensor of shape (B, H, N).
- Returns:
- A tuple containing:
y.real : Output tensor (scaled by 2).
next_state : Updated state tensor.
- Return type:
tuple[torch.Tensor, torch.Tensor]
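The "scaled by 2" in the output is consistent with storing only one member of each complex-conjugate pair of modes. The diagonal step below assumes that convention together with ZOH discretization and B = 1. It checks that running the step over a sequence reproduces a causal convolution with the kernel K[h, l] = 2 Re(sum_n C dB dA^l). s4d_step_sketch, dA and dB are hypothetical names.

```python
import numpy as np

def s4d_step_sketch(u, state, dA, dB, C):
    """One recurrent step of a diagonal SSM (hypothetical helper).

    u: (B, H) real input; state: (B, H, N) complex
    dA, dB, C: (H, N) complex, one member of each conjugate pair.
    The real part is doubled to account for the omitted conjugate modes.
    """
    next_state = dA * state + dB * u[..., None]
    y = 2 * np.einsum("hn,bhn->bh", C, next_state).real
    return y, next_state

rng = np.random.default_rng(0)
Bsz, H, N, L, dt = 2, 3, 4, 20, 0.1
A = -0.5 + 1j * np.pi * np.arange(N) * np.ones((H, 1))   # (H, N)
dA = np.exp(dt * A)                                      # ZOH state transition
dB = (dA - 1.0) / A                                      # ZOH input map with B = 1
C = rng.standard_normal((H, N)) + 1j * rng.standard_normal((H, N))

u = rng.standard_normal((Bsz, H, L))
state = np.zeros((Bsz, H, N), dtype=complex)
ys = []
for t in range(L):
    y, state = s4d_step_sketch(u[..., t], state, dA, dB, C)
    ys.append(y)
y_rec = np.stack(ys, axis=-1)                            # (B, H, L)

# Same outputs from the convolution kernel K[h, l] = 2 Re(sum_n C dB dA^l)
K = 2 * np.einsum("hn,hnl->hl", C * dB, dA[..., None] ** np.arange(L)).real
y_conv = np.array([[np.convolve(u[b, h], K[h])[:L] for h in range(H)] for b in range(Bsz)])
print(np.allclose(y_rec, y_conv))  # True
```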
- forward_state(u, state)
Advance the state through an input sequence.
- Parameters:
u (torch.Tensor) – Input sequence tensor of shape (B, H, L).
state (torch.Tensor) – Initial state tensor of shape (B, H, N).
- Returns:
The updated state tensor.
- Return type:
torch.Tensor
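For a diagonal SSM, passing the state through a sequence has an exact closed form that needs no loop over time: x_L = dA^L x_0 + sum_k dA^(L-1-k) dB u_k, applied elementwise per mode. The NumPy sketch below assumes ZOH discretization with B = 1 and checks the closed form against stepping one element at a time. s4d_forward_state_sketch is a hypothetical helper.

```python
import numpy as np

def s4d_forward_state_sketch(u, state, dA, dB):
    """Advance a diagonal SSM state over a sequence in one shot (hypothetical helper).

    u: (B, H, L) real; state: (B, H, N) complex; dA, dB: (H, N) complex.
    Uses x_L = dA^L * x_0 + sum_k dA^(L-1-k) * dB * u_k, exact for diagonal A.
    """
    L = u.shape[-1]
    powers = dA[..., None] ** np.arange(L - 1, -1, -1)   # (H, N, L): dA^(L-1-k)
    drive = np.einsum("hnl,bhl->bhn", dB[..., None] * powers, u)
    return dA**L * state + drive

rng = np.random.default_rng(0)
Bsz, H, N, L, dt = 2, 3, 4, 16, 0.1
A = -0.5 + 1j * np.pi * np.arange(N) * np.ones((H, 1))
dA = np.exp(dt * A)
dB = (dA - 1.0) / A
u = rng.standard_normal((Bsz, H, L))
x0 = rng.standard_normal((Bsz, H, N)) + 1j * rng.standard_normal((Bsz, H, N))

x = x0.copy()
for t in range(L):  # reference: one diagonal step at a time
    x = dA * x + dB * u[..., t, None]
print(np.allclose(s4d_forward_state_sketch(u, x0, dA, dB), x))  # True
```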