lrnnx.ops.simplified_scan module

Simplified SSM scan for S5-style models. This module implements the scan (and the complete S5 inner function) using a CUDA kernel, alongside pure-PyTorch reference implementations.

class SimplifiedScanFn

Bases: Function

Autograd function for the simplified SSM scan with complex input.

B projects the input from H dimensions to P (the state dimension), the kernel runs a diagonal SSM in state space with identity B/C, and C projects the result from P back to H.

Forward (L = seqlen): u (batch, H, L) -> Bu = B @ u (batch, P, L) -> kernel -> x (batch, P, L) -> y = C @ x (batch, H, L)
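The data flow above can be traced in a dependency-free sketch. It handles a single batch element with nested Python lists instead of tensors, assumes a zero initial state, and uses the default 'bilinear' discretization with the textbook formulas A_bar = (1 + delta*A/2) / (1 - delta*A/2) and B_bar = delta / (1 - delta*A/2). These formulas and the helper name `scan_sketch` are assumptions for illustration, not code taken from the lrnnx kernel.

```python
def scan_sketch(u, delta, A, B, C):
    """Forward data flow for ONE batch element, using nested lists.

    u: (H, L) complex, delta: (P, L) real, A: (P,) complex,
    B: (P, H) complex, C: (H, P) complex. Returns y with shape (H, L).
    Assumes a zero initial state and bilinear discretization (illustrative only).
    """
    H, L, P = len(u), len(u[0]), len(A)

    # Step 1: project the input into state space, Bu = B @ u -> (P, L).
    Bu = [[sum(B[p][h] * u[h][t] for h in range(H)) for t in range(L)]
          for p in range(P)]

    # Step 2: diagonal recurrence; each of the P channels evolves independently
    # (the part the kernel runs with identity B/C).
    x = [[0j] * L for _ in range(P)]
    for p in range(P):
        state = 0j
        for t in range(L):
            half = delta[p][t] * A[p] / 2
            a_bar = (1 + half) / (1 - half)
            b_bar = delta[p][t] / (1 - half)
            state = a_bar * state + b_bar * Bu[p][t]
            x[p][t] = state

    # Step 3: project back to the input dimension, y = C @ x -> (H, L).
    return [[sum(C[h][p] * x[p][t] for p in range(P)) for t in range(L)]
            for h in range(H)]
```

With H = P = 1, A = -1, B = C = 1 and delta = 1, the bilinear factors are A_bar = 1/3 and B_bar = 2/3, so an impulse input [1, 0, 0] produces [2/3, 2/9, 2/27].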

static forward(ctx, u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None, discretization: str, return_last_state=False) → torch.Tensor

Forward pass for the simplified SSM scan.

Parameters:
  • ctx (Any) – Autograd context.

  • u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen). H is the hidden/input dimension.

  • delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen). P is the state dimension.

  • A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1).

  • B (torch.Tensor) – Complex projection matrix of shape (P, H). Projects input to state space.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P). Projects state to output.

  • deltaA (torch.Tensor | None) – Optional separate timestep for A discretization of shape (batch, P, seqlen).

  • discretization (str) – Discretization method (‘bilinear’, ‘zoh’, ‘dirac’).

  • return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.

Returns:

Output tensor, and optionally the last state.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]

static backward(ctx, dout: torch.Tensor)

Backward pass for the simplified SSM scan.

Parameters:
  • ctx (Any) – Autograd context.

  • dout (torch.Tensor) – Gradient of the output tensor.

Returns:

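Gradients with respect to the inputs (du, ddelta, dA, dB, dC, ddeltaA, None, None); the two None entries correspond to the non-tensor arguments discretization and return_last_state.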

Return type:

tuple
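This page does not describe how the gradients are computed. For a linear recurrence, a common approach is a second scan run in reverse time over the adjoint state. The sketch below shows that idea on a deliberately simplified case (one real channel, scalar coefficients a, b, c standing in for A_bar, B_bar and the output projection, zero initial state); the function names are hypothetical and this is not the lrnnx backward kernel.

```python
def scan_forward(a, b, c, v):
    """x[t] = a*x[t-1] + b*v[t] with x[-1] = 0; y[t] = c*x[t]. Returns (y, x)."""
    xs, x = [], 0.0
    for vt in v:
        x = a * x + b * vt
        xs.append(x)
    return [c * xt for xt in xs], xs


def scan_backward(a, b, c, v, dout):
    """Gradients of loss = sum_t dout[t] * y[t] via a reverse-time scan.

    g holds dL/dx[t]; it satisfies g[t] = c*dout[t] + a*g[t+1] with g[L] = 0.
    Returns (dv, da, db, dc).
    """
    _, xs = scan_forward(a, b, c, v)
    dv = [0.0] * len(v)
    da = db = dc = 0.0
    g = 0.0
    for t in reversed(range(len(v))):
        g = c * dout[t] + a * g                    # adjoint recurrence, backwards in time
        dv[t] = b * g                              # x[t] depends on v[t] through b
        db += g * v[t]
        da += g * (xs[t - 1] if t > 0 else 0.0)    # x[t] depends on a through x[t-1]
        dc += dout[t] * xs[t]
    return dv, da, db, dc


def loss(a, b, c, v, dout):
    """Scalar loss used to define the upstream gradient dout."""
    y, _ = scan_forward(a, b, c, v)
    return sum(d * yt for d, yt in zip(dout, y))
```

The backward pass needs the forward states xs (here recomputed) to form da, which is why scan implementations typically either store or recompute them.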

simplified_scan_fn(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None = None, return_last_state: bool = False, discretization: str = 'bilinear') → torch.Tensor | tuple[torch.Tensor, torch.Tensor]

Simplified SSM scan using a CUDA kernel.

S5-style scan where B projects the input into state space and C projects it back. Internally, the kernel operates with identity B/C (diagonal SSM).

Forward (L = seqlen): u (batch, H, L) -> Bu = B @ u (batch, P, L) -> kernel -> x (batch, P, L) -> y = C @ x (batch, H, L)

Parameters:
  • u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.

  • delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.

  • A (torch.Tensor) – Complex state matrix eigenvalues of shape (P,) or (P, 1), dtype=complex64.

  • B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64. Projects input to state.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64. Projects state to output.

  • deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen), dtype=float32. If provided, A is discretized using deltaA while B uses delta. Defaults to None.

  • return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.

  • discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, ‘dirac’). Defaults to “bilinear”.

Returns:

  • out : Complex output tensor of shape (batch, H, seqlen), dtype=complex64.

  • last_state : If return_last_state=True, returns state of shape (batch, P), dtype=complex64.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
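The three discretization options, and the split between delta and deltaA, can be illustrated with a hypothetical helper `discretize` for one diagonal state channel. The 'bilinear' and 'zoh' formulas below are the standard ones for a diagonal SSM; treating 'dirac' as A_bar = exp(deltaA * A) with B_bar = 1 (the input acting as an impulse) is an assumption, as is applying deltaA only inside A_bar and delta only inside B_bar. None of this is read from the lrnnx kernel, so check the source before relying on exact values.

```python
import cmath


def discretize(A, delta, method="bilinear", delta_a=None):
    """Return (A_bar, B_bar) for one diagonal channel with eigenvalue A.

    Hypothetical helper. If delta_a is given, it is used for A_bar only and
    delta is used for B_bar (an assumed reading of the deltaA parameter).
    """
    dt_a = delta if delta_a is None else delta_a
    if method == "bilinear":
        a_bar = (1 + dt_a * A / 2) / (1 - dt_a * A / 2)
        b_bar = delta / (1 - delta * A / 2)
    elif method == "zoh":
        a_bar = cmath.exp(dt_a * A)
        b_bar = (cmath.exp(delta * A) - 1) / A    # requires A != 0
    elif method == "dirac":
        a_bar = cmath.exp(dt_a * A)
        b_bar = 1 + 0j                             # assumed impulse input
    else:
        raise ValueError(f"unknown discretization: {method!r}")
    return a_bar, b_bar
```

For A = -1 and delta = 1, 'bilinear' gives A_bar = 1/3 and B_bar = 2/3, while 'zoh' gives A_bar = exp(-1) ≈ 0.368 and B_bar = 1 - exp(-1) ≈ 0.632.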

simplified_scan_ref(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None = None, return_last_state: bool = False, discretization: str = 'bilinear') → torch.Tensor | tuple[torch.Tensor, torch.Tensor]

Reference implementation of simplified scan (pure PyTorch).

S5-style scan where B projects input to state space and C projects back.

Forward (L = seqlen): u (batch, H, L) -> Bu = B @ u (batch, P, L) -> diagonal scan -> x (batch, P, L) -> y = C @ x (batch, H, L)

Parameters:
  • u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.

  • delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.

  • A (torch.Tensor) – Complex state matrix eigenvalues of shape (P,) or (P, 1), dtype=complex64.

  • B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64. Projects input to state.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64. Projects state to output.

  • deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen). If provided, A is discretized using deltaA while B uses delta. Defaults to None.

  • return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.

  • discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, ‘dirac’). Defaults to “bilinear”.

Returns:

  • out : Complex output tensor of shape (batch, H, seqlen), dtype=complex64.

  • last_state : If return_last_state=True, also returns state of shape (batch, P), dtype=complex64.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
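Assuming last_state is the state x at the final time step, the last output column satisfies y[:, :, -1] = C @ last_state. A plain-Python sketch for one batch element makes the return convention and this identity concrete; the name `scan_with_state`, the zero initial state, and bilinear discretization are assumptions. This version walks time in the outer loop and keeps only a P-sized state between steps.

```python
def scan_with_state(u, delta, A, B, C, return_last_state=False):
    """One batch element: u (H, L), delta (P, L), A (P,), B (P, H), C (H, P).

    Returns y (H, L), or (y, last_state) with last_state of shape (P,).
    Zero initial state and bilinear discretization are assumed.
    """
    H, L, P = len(u), len(u[0]), len(A)
    state = [0j] * P
    y = [[0j] * L for _ in range(H)]
    for t in range(L):
        for p in range(P):
            half = delta[p][t] * A[p] / 2
            bu = sum(B[p][h] * u[h][t] for h in range(H))   # (B @ u)[p, t]
            state[p] = ((1 + half) / (1 - half)) * state[p] + (delta[p][t] / (1 - half)) * bu
        for h in range(H):
            y[h][t] = sum(C[h][p] * state[p] for p in range(P))  # (C @ x)[h, t]
    if return_last_state:
        return y, list(state)
    return y
```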

class S5InnerFn

Bases: Function

Autograd function for the complete S5 inner function, with custom forward and backward passes.

This wraps the simplified scan kernel and adds:
  1. Conjugate symmetry handling (2 * Re(y) if conj_sym, else Re(y)).
  2. A skip connection through D, a real per-channel vector of shape (H,).

The scan kernel computes x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t] and y = C @ x. This function then applies out = (2 if conj_sym else 1) * Re(y) + D * u.real.
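Once y is known, the post-processing step is a few lines. A minimal sketch, assuming y and u are nested lists of shape (H, L) for one batch element and using the hypothetical name `s5_output`:

```python
def s5_output(y, u, D, conj_sym=True):
    """out[h][t] = scale * Re(y[h][t]) + D[h] * Re(u[h][t]).

    y, u: (H, L) complex values; D: (H,) real skip weights.
    scale is 2 when conj_sym is True, else 1. The result is real.
    """
    scale = 2.0 if conj_sym else 1.0
    return [[scale * y[h][t].real + D[h] * u[h][t].real for t in range(len(y[h]))]
            for h in range(len(y))]
```

Only the real part of u reaches the skip path; the imaginary part of u affects the output only through the scan.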

static forward(ctx, u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None, discretization: str, conj_sym: bool)

Forward pass for the complete S5 model inner function.

Parameters:
  • ctx (Any) – Autograd context.

  • u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.

  • delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.

  • A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64.

  • B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64.

  • D (torch.Tensor) – Real skip connection tensor of shape (H,), dtype=float32.

  • deltaA (torch.Tensor | None) – Optional separate timestep for A discretization of shape (batch, P, seqlen).

  • discretization (str) – Discretization method (‘bilinear’, ‘zoh’, or ‘dirac’).

  • conj_sym (bool) – If True, output is 2 * Re(y), else Re(y).

Returns:

Real output tensor of shape (batch, H, seqlen), dtype=float32.

Return type:

torch.Tensor

static backward(ctx, dout: torch.Tensor)

Backward pass computing gradients for all inputs.

Parameters:
  • ctx (Any) – Autograd context.

  • dout (torch.Tensor) – Real-valued gradient of the loss with respect to the output, of shape (batch, H, seqlen).

Returns:

Gradients for u, delta, A, B, C, D, deltaA, None, None; the two None entries correspond to the non-tensor arguments discretization and conj_sym.

Return type:

tuple
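Two of these gradients follow directly from the forward formula out = scale * Re(y) + D * u.real, with scale = 2 if conj_sym else 1: the gradient reaching Re(y) is scale * dout, and because D is shared across batch and time, dD[h] sums dout * u.real over both. A plain-Python sketch with nested lists of shape (batch, H, L); the name `skip_and_scale_grads` is hypothetical and the sketch does not cover the scan gradients.

```python
def skip_and_scale_grads(dout, u, conj_sym=True):
    """Gradients implied by out = scale * Re(y) + D * Re(u).

    dout: (batch, H, L) real; u: (batch, H, L) complex.
    Returns (dD, dy_real): dD has shape (H,), dy_real has dout's shape.
    """
    scale = 2.0 if conj_sym else 1.0
    n_batch, H = len(dout), len(dout[0])
    # D[h] multiplies u.real at every batch index and time step, so sum over both.
    dD = [sum(dout[b][h][t] * u[b][h][t].real
              for b in range(n_batch) for t in range(len(dout[b][h])))
          for h in range(H)]
    # Re(y) is only scaled, so its gradient is the scaled upstream gradient.
    dy_real = [[[scale * g for g in row] for row in sample] for sample in dout]
    return dD, dy_real
```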

s5_inner_fn(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) → torch.Tensor

S5 inner function using a CUDA kernel.

Computes the complete S5 forward pass:
  1. SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x.
  2. Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y).
  3. Skip connection: out = y_real + D * u.real.

Parameters:
  • u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.

  • delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.

  • A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64.

  • B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64.

  • D (torch.Tensor) – Real skip connection tensor of shape (H,), dtype=float32.

  • deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen). Defaults to None.

  • discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, or ‘dirac’). Defaults to “bilinear”.

  • conj_sym (bool, optional) – If True, output is 2 * Re(y), else Re(y). Defaults to True.

Returns:

Real output tensor of shape (batch, H, seqlen), dtype=float32.

Return type:

torch.Tensor
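conj_sym defaults to True. A common motivation in S5-style models (an assumption here, not stated on this page) is that the P stored eigenvalues represent one half of complex-conjugate pairs. For a real input, a channel with eigenvalue lam and coefficients b, c and its partner with conj(lam), conj(b), conj(c) produce conjugate outputs, so their sum is exactly 2 * Re of the stored channel's output. The sketch checks this for one channel under bilinear discretization; the name `channel_output` is hypothetical.

```python
def channel_output(lam, b, c, v, delta=1.0):
    """Outputs c * x[t] of one diagonal channel driven by input v.

    x[t] = A_bar * x[t-1] + B_bar * b * v[t], zero initial state,
    bilinear discretization (assumed, as in the other sketches).
    """
    half = delta * lam / 2
    a_bar = (1 + half) / (1 - half)
    b_bar = delta / (1 - half)
    x, out = 0j, []
    for vt in v:
        x = a_bar * x + b_bar * b * vt
        out.append(c * x)
    return out


lam, b, c = -0.3 + 2.0j, 0.7 - 0.4j, 1.1 + 0.6j
v = [1.0, -0.5, 0.25, 2.0]                      # real input sequence
stored = channel_output(lam, b, c, v)           # the channel that is kept
partner = channel_output(lam.conjugate(), b.conjugate(), c.conjugate(), v)
full = [p + q for p, q in zip(stored, partner)]
# The pair sums to a real signal equal to 2 * Re(stored channel).
max_err = max(abs(f - 2 * s.real) for f, s in zip(full, stored))
print(max_err)  # expected to be at rounding-error level
```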

s5_inner_ref(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) → torch.Tensor

Reference implementation of S5 inner function (pure PyTorch).

Computes the complete S5 forward pass:
  1. SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x.
  2. Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y).
  3. Skip connection: out = y_real + D * u.real.

Parameters:
  • u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.

  • delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.

  • A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64.

  • B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64.

  • D (torch.Tensor) – Real skip connection tensor of shape (H,), dtype=float32.

  • deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen). Defaults to None.

  • discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, or ‘dirac’). Defaults to “bilinear”.

  • conj_sym (bool, optional) – If True, output is 2 * Re(y), else Re(y). Defaults to True.

Returns:

Real output tensor of shape (batch, H, seqlen), dtype=float32.

Return type:

torch.Tensor