lrnnx.ops.s4_kernel_interface module

class S4KernelBase

Bases: Module

Base class for S4 kernels. The kernel does not own its parameters; they belong to the parent model (S4 or S4D) and are passed in through param_config.

Parameters:
  • d_model (int) – Model dimension.

  • l_max (int | None) – Maximum sequence length.

  • channels (int) – Number of channels/heads.

  • param_config (dict) –

    A dictionary containing:

    • Parameter references: A_real, A_imag, B, C, inv_dt, P (nn.Parameter objects owned by S4/S4D)

    • Computed scalars: N, H, channels, rank, repeat

    • Config flags: dt_fast, real_transform, imag_transform, dt_transform, is_real, deterministic, verbose

    • S4D-only: disc (discretization method)
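The key list above can be checked mechanically before a kernel is constructed. The sketch below is a hypothetical helper (`missing_param_config_keys` is not part of lrnnx) that only tests key presence. It assumes every documented key appears in the dictionary (possibly as None when unused) and that `disc` is required only for S4D.

```python
# Hypothetical helper, not part of lrnnx: reports which documented
# param_config keys are missing. It checks presence only, not values.
PARAM_KEYS = {"A_real", "A_imag", "B", "C", "inv_dt", "P"}
SCALAR_KEYS = {"N", "H", "channels", "rank", "repeat"}
FLAG_KEYS = {
    "dt_fast", "real_transform", "imag_transform", "dt_transform",
    "is_real", "deterministic", "verbose",
}


def missing_param_config_keys(param_config: dict, s4d: bool = False) -> set:
    """Return the documented keys that are absent from param_config."""
    required = PARAM_KEYS | SCALAR_KEYS | FLAG_KEYS
    if s4d:
        required = required | {"disc"}  # 'disc' is documented as S4D-only
    return required - set(param_config)
```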

class S4Kernel

Bases: S4KernelBase

SSM kernel for diagonal plus low-rank (DPLR) state matrices. This class performs only the convolution-kernel computation.

Parameters:
  • d_model (int) – Model dimension.

  • l_max (int | None) – Maximum sequence length.

  • channels (int) – Number of channels/heads.

  • param_config (dict) – Configuration dictionary containing parameter references and flags.

forward(state=None, rate=1.0, L=None)

Compute the SSM convolution kernel (the core operation of this class).

Parameters:
  • state (torch.Tensor, optional) – State tensor. Defaults to None.

  • rate (float, optional) – Sampling rate. Defaults to 1.0.

  • L (int, optional) – Sequence length. Defaults to None.

Returns:

A tuple containing:
  • k_B : Convolution kernel.

  • k_state : State kernel if state is provided, otherwise None.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
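To make the DPLR kernel concrete, the sketch below materializes it naively for a single SSM with a rank-1 low-rank term, A = diag(Λ) - P P*, discretized with the bilinear transform, and returns the real taps K[l] = Re(C Ā^l B̄). Everything here is an assumption for illustration: the function name `dplr_kernel_naive`, single-channel shapes, NumPy in place of torch, and the choice of discretization. It costs O(L·N²) and is not how lrnnx (or the S4 paper, which uses a Cauchy-kernel and FFT formulation) computes `k_B`.

```python
import numpy as np


def dplr_kernel_naive(Lambda, P, B, C, dt, L):
    """Materialize K[l] = Re(C dA^l dB) for A = diag(Lambda) - P P^*.

    Lambda, P, B, C: complex arrays of shape (N,); P is a rank-1 factor.
    dt: step size (float); L: number of kernel taps.
    """
    N = Lambda.shape[0]
    A = np.diag(Lambda) - np.outer(P, P.conj())  # DPLR state matrix
    I = np.eye(N, dtype=complex)
    # Bilinear (Tustin) discretization: dA = (I - dt/2 A)^-1 (I + dt/2 A)
    left = I - (dt / 2.0) * A
    dA = np.linalg.solve(left, I + (dt / 2.0) * A)
    dB = np.linalg.solve(left, dt * B)
    K = np.empty(L)
    x = dB
    for l in range(L):
        K[l] = (C @ x).real  # tap l is C dA^l dB
        x = dA @ x
    return K
```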

double_length()

Double the sequence length covered by the kernel's internal representation.

default_state(*batch_shape)

Create the default (zero-initialized) state.

Parameters:

*batch_shape – Batch dimensions, passed as separate positional arguments.

Returns:

A zero-initialized state tensor.

Return type:

torch.Tensor
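One plausible layout for the zero state, consistent with the (B, H, N) state shape documented for forward_state, is the batch dimensions followed by H and N. The sketch below is hypothetical: `default_state_sketch`, its keyword-only `H` and `N` arguments, and the complex dtype are assumptions, and NumPy stands in for torch.

```python
import numpy as np


def default_state_sketch(*batch_shape, H, N, dtype=np.complex64):
    """Zero state of shape (*batch_shape, H, N); dtype is an assumption."""
    return np.zeros((*batch_shape, H, N), dtype=dtype)
```

Passing no batch dimensions yields a state of shape (H, N).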

step(u, state)

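Perform a single recurrent step.

Parameters:
  • u (torch.Tensor) – Input at the current time step.

  • state (torch.Tensor) – Current state tensor.

Returns:

A tuple containing:
  • y.real : Output tensor.

  • new_state : Updated state tensor.

Return type:

tuple[torch.Tensor, torch.Tensor]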
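A single recurrent step can be sketched as x' = Ā x + B̄ u followed by the readout y = Re(C x'). The sketch assumes the kernel has already produced discretized matrices `dA` of shape (H, N, N) and `dB` of shape (H, N), that u has shape (B, H), and that the output is read from the updated state. The function name and shapes are hypothetical, and NumPy stands in for torch.

```python
import numpy as np


def dense_step(u, state, dA, dB, C):
    """One recurrent step: x' = dA x + dB u, then y = Re(C x').

    u: (B, H); state: (B, H, N); dA: (H, N, N); dB, C: (H, N).
    """
    # Per-feature matrix-vector product plus the input injection
    new_state = np.einsum("hnm,bhm->bhn", dA, state) + dB[None] * u[..., None]
    # Contract the state dimension with C to get one output per feature
    y = np.einsum("hn,bhn->bh", C, new_state)
    return y.real, new_state
```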

forward_state(u, state)

Advance the state through an input sequence.

Parameters:
  • u (torch.Tensor) – Input sequence tensor of shape (B, H, L).

  • state (torch.Tensor) – State tensor of shape (B, H, N).

Returns:

The updated state tensor.

Return type:

torch.Tensor
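forward_state only needs the final state, which for a linear recurrence equals Ā^L x_0 + Σ_l Ā^(L-1-l) B̄ u_l. The hypothetical sketch below computes it by iterating the state update over the length-L input, using the (B, H, L) and (B, H, N) shapes documented above and assuming discretized `dA` (H, N, N) and `dB` (H, N) are available. NumPy stands in for torch.

```python
import numpy as np


def dense_forward_state(u, state, dA, dB):
    """Advance the state through a sequence without producing outputs.

    u: (B, H, L); state: (B, H, N); dA: (H, N, N); dB: (H, N).
    Returns dA^L x0 + sum_l dA^(L-1-l) dB u_l with shape (B, H, N).
    """
    for l in range(u.shape[-1]):
        # Same update as a single step; the readout is skipped
        state = np.einsum("hnm,bhm->bhn", dA, state) + dB[None] * u[..., l, None]
    return state
```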

class S4DKernel

Bases: S4KernelBase

SSM kernel for a diagonal state matrix (the S4D model). This class performs only the convolution-kernel computation.

Parameters:
  • d_model (int) – Model dimension.

  • l_max (int | None) – Maximum sequence length.

  • channels (int) – Number of channels/heads.

  • param_config (dict) – Configuration dictionary containing parameter references and flags, including the S4D-specific ‘disc’ key.

forward(L, state=None, rate=1.0)

Compute the SSM convolution kernel (the core operation of this class).

Parameters:
  • L (int) – Sequence length.

  • state (torch.Tensor, optional) – State tensor. Defaults to None.

  • rate (float, optional) – Sampling rate. Defaults to 1.0.

Returns:

A tuple containing:
  • K : Convolution kernel.

  • K_state : State kernel if state is provided, otherwise None.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
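For the diagonal case the kernel reduces to a Vandermonde product. The sketch assumes zero-order-hold discretization (Ā = exp(ΔΛ), B̄ = (exp(ΔΛ) - 1)/Λ · B) as one possible `disc` setting, per-feature step sizes Δ of shape (H,), and the 2·Re convention for conjugate-pair states, which matches the factor of 2 noted for step. The name `s4d_kernel_zoh` and the shapes are illustrative, not lrnnx's API, and NumPy stands in for torch.

```python
import numpy as np


def s4d_kernel_zoh(Lambda, B, C, dt, L):
    """S4D-style kernel K[l] = 2 Re(sum_n C_n dB_n dA_n^l) with ZOH.

    Lambda, B, C: complex arrays of shape (H, N); dt: array of shape (H,).
    Returns a real array of shape (H, L).
    """
    dtA = dt[:, None] * Lambda              # (H, N)
    dA = np.exp(dtA)                        # ZOH: exp(dt * A)
    dB = (dA - 1.0) / Lambda * B            # ZOH: A^-1 (exp(dt A) - I) B
    powers = dA[..., None] ** np.arange(L)  # Vandermonde matrix, (H, N, L)
    K = np.einsum("hn,hnl->hl", C * dB, powers)
    return 2.0 * K.real                     # factor 2: conjugate-pair convention
```

Because Λ is diagonal, the whole kernel is one broadcasted power table and a contraction, with no matrix products.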

default_state(*batch_shape)

Create the default (zero-initialized) state.

Parameters:

*batch_shape – Batch dimensions, passed as separate positional arguments.

Returns:

A zero-initialized state tensor.

Return type:

torch.Tensor

step(u, state)

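Perform a single recurrent step.

Parameters:
  • u (torch.Tensor) – Input at the current time step.

  • state (torch.Tensor) – Current state tensor.

Returns: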

A tuple containing:
  • y.real : Output tensor (scaled by 2).

  • next_state : Updated state tensor.

Return type:

tuple[torch.Tensor, torch.Tensor]
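With a diagonal state matrix the step becomes elementwise. The hypothetical sketch below mirrors the documented return values, including the factor of 2 on the real output, assuming discretized `dA`, `dB` and `C` of shape (H, N) and an input u of shape (B, H). NumPy stands in for torch.

```python
import numpy as np


def diag_step(u, state, dA, dB, C):
    """One diagonal SSM step: x' = dA * x + dB * u, y = 2 Re(sum_n C_n x'_n).

    u: (B, H); state: (B, H, N); dA, dB, C: (H, N).
    """
    next_state = dA * state + dB * u[..., None]  # elementwise update
    y = 2.0 * np.einsum("hn,bhn->bh", C, next_state)
    return y.real, next_state


# Usage: scalar system with dA = 0.5, dB = 1, C = 1 and input 2
dA = np.array([[0.5 + 0j]])
dB = np.array([[1.0 + 0j]])
C = np.array([[1.0 + 0j]])
y, s = diag_step(np.array([[2.0]]), np.zeros((1, 1, 1), dtype=complex), dA, dB, C)
# y is [[4.0]] and s is [[[2.0]]]
```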

forward_state(u, state)

Advance the state through an input sequence.

Parameters:
  • u (torch.Tensor) – Input sequence tensor of shape (B, H, L).

  • state (torch.Tensor) – Initial state tensor of shape (B, H, N).

Returns:

The updated state tensor.

Return type:

torch.Tensor
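Because the state matrix is diagonal, the state after L steps has a closed form that needs no sequential loop: x_L = Ā^L x_0 + Σ_l Ā^(L-1-l) B̄ u_l, with all powers taken elementwise. The sketch below (hypothetical name `diag_forward_state`, NumPy in place of torch, discretized `dA` and `dB` of shape (H, N) assumed) evaluates it with one einsum over the (B, H, L) input.

```python
import numpy as np


def diag_forward_state(u, state, dA, dB):
    """Closed-form state after L diagonal steps.

    x_L = dA^L * x_0 + sum_l dA^(L-1-l) * dB * u_l
    u: (B, H, L); state: (B, H, N); dA, dB: (H, N).
    """
    L = u.shape[-1]
    decay = dA[..., None] ** np.arange(L - 1, -1, -1)   # (H, N, L): dA^(L-1-l)
    contrib = np.einsum("hnl,bhl->bhn", decay, u) * dB  # sum over time steps
    return dA ** L * state + contrib
```

Unrolling the recurrence in a loop gives the same result; the closed form trades the sequential loop for an (H, N, L) table of powers.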