lrnnx.ops.simplified_scan module¶
Simplified SSM scan for S5-style models. This module implements the scan using a CUDA kernel.
- class SimplifiedScanFn[source]¶
Bases: Function
Autograd function for simplified SSM scan with complex input.
B projects input from H-dim to P-dim (state dimension). The kernel operates in state space with identity B/C (diagonal SSM). C projects output from P-dim back to H-dim.
Forward: u (B,H,L) -> Bu = B @ u -> kernel -> x (B,P,L) -> y = C @ x -> (B,H,L)
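The shape bookkeeping in the forward line above can be followed with a short pure-PyTorch sketch. The sizes are illustrative, and the scan step is replaced by a stand-in so that only the two projections are shown; this is not the kernel itself.

```python
import torch

batch, H, P, L = 2, 4, 8, 16  # illustrative sizes, not library defaults

u = torch.randn(batch, H, L, dtype=torch.complex64)  # input, (batch, H, L)
B = torch.randn(P, H, dtype=torch.complex64)         # input projection, (P, H)
C = torch.randn(H, P, dtype=torch.complex64)         # output projection, (H, P)

# Step 1: project the input into state space, (batch, H, L) -> (batch, P, L).
Bu = torch.einsum("ph,bhl->bpl", B, u)

# Step 2: stand-in for the diagonal scan. The real kernel turns Bu into states x;
# here x is just Bu so the shapes can be followed end to end.
x = Bu

# Step 3: project the states back to the hidden dimension, (batch, P, L) -> (batch, H, L).
y = torch.einsum("hp,bpl->bhl", C, x)
```

Note that in the forward line `B` denotes both the batch axis in `(B,H,L)` and the projection matrix in `B @ u`; the sketch uses `batch` for the former to keep them apart.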
- static forward(ctx, u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None, discretization: str, return_last_state=False) torch.Tensor¶
Forward pass for the simplified SSM scan.
- Parameters:
ctx (Any) – Autograd context.
u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen). H is hidden/input dimension.
delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen). P is state dimension.
A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1).
B (torch.Tensor) – Complex projection matrix of shape (P, H). Projects input to state space.
C (torch.Tensor) – Complex projection matrix of shape (H, P). Projects state to output.
deltaA (torch.Tensor | None) – Optional separate timestep for A discretization of shape (batch, P, seqlen).
discretization (str) – Discretization method (‘bilinear’, ‘zoh’, ‘dirac’).
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
- Returns:
Output tensor, and optionally the last state.
- Return type:
- static backward(ctx, dout: torch.Tensor)¶
Backward pass for the simplified SSM scan.
- Parameters:
ctx (Any) – Autograd context.
dout (torch.Tensor) – Gradient of the output tensor.
- Returns:
Gradients with respect to inputs (du, ddelta, dA, dB, dC, ddeltaA, None, None).
- Return type:
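The eight values returned by backward line up one-to-one with the arguments of forward after ctx: the six tensor inputs receive gradients, while discretization (a str) and return_last_state (a bool) receive None. A minimal sketch of that PyTorch convention on a toy Function follows; `ScaleFn` and its arguments are hypothetical and unrelated to the kernel.

```python
import torch
from torch.autograd import Function


class ScaleFn(Function):
    """Toy Function (hypothetical): y = x * w, optionally negated."""

    @staticmethod
    def forward(ctx, x, w, mode, negate=False):
        # mode is an unused str argument, present only to mirror a non-tensor input.
        ctx.save_for_backward(x, w)
        ctx.sign = -1.0 if negate else 1.0
        return ctx.sign * x * w

    @staticmethod
    def backward(ctx, dout):
        x, w = ctx.saved_tensors
        # One return value per forward argument after ctx:
        # tensors get gradients, the str and bool arguments get None.
        return ctx.sign * dout * w, ctx.sign * dout * x, None, None


x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
# Function.apply takes positional arguments only.
ScaleFn.apply(x, w, "unused", False).sum().backward()
```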
- simplified_scan_fn(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None = None, return_last_state: bool = False, discretization: str = 'bilinear') torch.Tensor | tuple[torch.Tensor, torch.Tensor][source]¶
Simplified SSM scan using CUDA kernel.
S5-style scan where B projects input to state space and C projects back. The kernel internally operates with identity B/C (diagonal SSM).
Forward: u (B,H,L) -> Bu = B @ u -> kernel -> x (B,P,L) -> y = C @ x -> (B,H,L)
- Parameters:
u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.
delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.
A (torch.Tensor) – Complex state matrix eigenvalues of shape (P,) or (P, 1), dtype=complex64.
B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64. Projects input to state.
C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64. Projects state to output.
deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen), dtype=float32. If provided, A is discretized using deltaA while B uses delta. Defaults to None.
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, ‘dirac’). Defaults to “bilinear”.
- Returns:
out : Complex output tensor of shape (batch, H, seqlen), dtype=complex64.
last_state : If return_last_state=True, also returns state of shape (batch, P), dtype=complex64.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
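The role of delta and deltaA can be illustrated with the textbook discretization formulas for a diagonal SSM. This is a sketch under stated assumptions: `discretize_sketch` is a hypothetical helper, the formulas are the standard zero-order-hold and bilinear forms and may differ from what the kernel computes, and ‘dirac’ is left out because this page does not define it.

```python
import torch


def discretize_sketch(A, delta, deltaA=None, method="zoh"):
    """Hypothetical helper: textbook discretization of a diagonal SSM.

    A: (P,) or (P, 1) complex; delta, deltaA: (batch, P, seqlen) real.
    Returns A_bar and B_bar, each of shape (batch, P, seqlen), complex.
    """
    dA = delta if deltaA is None else deltaA  # timestep used for A
    A = A.reshape(1, -1, 1)                   # broadcast over batch and time
    if method == "zoh":
        A_bar = torch.exp(dA * A)
        B_bar = (torch.exp(delta * A) - 1) / A
    elif method == "bilinear":
        A_bar = (1 + dA * A / 2) / (1 - dA * A / 2)
        B_bar = delta / (1 - delta * A / 2)
    else:
        raise ValueError(f"method {method!r} is not covered by this sketch")
    return A_bar, B_bar
```

In both forms, eigenvalues with negative real part give |A_bar| < 1, so the recurrence stays stable. Passing deltaA changes A_bar only; B_bar keeps using delta.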
- simplified_scan_ref(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None = None, return_last_state: bool = False, discretization: str = 'bilinear') torch.Tensor | tuple[torch.Tensor, torch.Tensor][source]¶
Reference implementation of simplified scan (pure PyTorch).
S5-style scan where B projects input to state space and C projects back.
Forward: u (B,H,L) -> Bu = B @ u -> kernel -> x (B,P,L) -> y = C @ x -> (B,H,L)
- Parameters:
u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.
delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.
A (torch.Tensor) – Complex state matrix eigenvalues of shape (P,) or (P, 1), dtype=complex64.
B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64. Projects input to state.
C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64. Projects state to output.
deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen). If provided, A is discretized using deltaA while B uses delta. Defaults to None.
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, ‘dirac’). Defaults to “bilinear”.
- Returns:
out : Complex output tensor of shape (batch, H, seqlen), dtype=complex64.
last_state : If return_last_state=True, also returns state of shape (batch, P), dtype=complex64.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
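For orientation, a sequential scan of the kind a pure-PyTorch reference might perform can be sketched as below. `scan_sketch` is a hypothetical function, not the library's implementation: it takes already-discretized coefficients `A_bar` and `B_bar` as inputs and assumes a zero initial state.

```python
import torch


def scan_sketch(u, A_bar, B_bar, B, C, return_last_state=False):
    """Hypothetical sequential scan over already-discretized coefficients.

    u: (batch, H, L) complex; A_bar, B_bar: (batch, P, L) complex;
    B: (P, H) complex; C: (H, P) complex.
    """
    Bu = torch.einsum("ph,bhl->bpl", B, u)   # project input to state space
    x = torch.zeros_like(Bu[..., 0])         # assumed zero initial state, (batch, P)
    states = []
    for t in range(u.shape[-1]):
        # x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], elementwise over P
        x = A_bar[..., t] * x + B_bar[..., t] * Bu[..., t]
        states.append(x)
    xs = torch.stack(states, dim=-1)         # (batch, P, L)
    y = torch.einsum("hp,bpl->bhl", C, xs)   # project states to output
    return (y, x) if return_last_state else y
```

The loop is linear in seqlen and easy to check by hand, which makes this form useful as a correctness baseline even though it is slow.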
- class S5InnerFn[source]¶
Bases: Function
The complete S5 model inner function with custom forward and backward.
This wraps the simplified scan kernel and adds:
1. Conjugate symmetry handling (2 * real if conj_sym else real)
2. Skip connection with D matrix
The scan kernel computes: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x
This function then applies: out = (2 if conj_sym else 1) * Re(y) + D * u.real
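The step applied after the scan is small enough to write out directly. The helper below is a hypothetical sketch of the formula above; it assumes D of shape (H,) is broadcast over the batch and seqlen axes.

```python
import torch


def s5_epilogue_sketch(y, u, D, conj_sym=True):
    """Hypothetical helper: real output from the complex scan output y.

    y, u: (batch, H, seqlen) complex; D: (H,) real.
    """
    scale = 2.0 if conj_sym else 1.0
    # D has shape (H,), so add a trailing axis to broadcast over seqlen.
    return scale * y.real + D[:, None] * u.real
```

The factor of 2 is the usual consequence of keeping only one eigenvalue from each complex-conjugate pair: a term plus its conjugate equals twice its real part.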
- static forward(ctx, u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None, discretization: str, conj_sym: bool)¶
Forward pass for the complete S5 model inner function.
- Parameters:
ctx (Any) – Autograd context.
u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.
delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.
A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64.
B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64.
C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64.
D (torch.Tensor) – Real skip connection tensor of shape (H,), dtype=float32.
deltaA (torch.Tensor | None) – Optional separate timestep for A discretization of shape (batch, P, seqlen).
discretization (str) – Discretization method (‘bilinear’, ‘zoh’, or ‘dirac’).
conj_sym (bool) – If True, output is 2 * Re(y), else Re(y).
- Returns:
Real output tensor of shape (batch, H, seqlen), dtype=float32.
- Return type:
- static backward(ctx, dout: torch.Tensor)¶
Backward pass computing gradients for all inputs.
- Parameters:
ctx (Any) – Autograd context.
dout (torch.Tensor) – Gradient of loss w.r.t. output tensor of shape (batch, H, seqlen), real.
- Returns:
Gradients for u, delta, A, B, C, D, deltaA, None, None.
- Return type:
- s5_inner_fn(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) torch.Tensor[source]¶
S5 inner function using CUDA kernel.
Computes the complete S5 forward pass:
1. SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x
2. Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y)
3. Skip connection: out = y_real + D * u.real
- Parameters:
u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.
delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.
A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64.
B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64.
C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64.
D (torch.Tensor) – Real skip connection tensor of shape (H,), dtype=float32.
deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen). Defaults to None.
discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, or ‘dirac’). Defaults to “bilinear”.
conj_sym (bool, optional) – If True, output is 2 * Re(y), else Re(y). Defaults to True.
- Returns:
Real output tensor of shape (batch, H, seqlen), dtype=float32.
- Return type:
torch.Tensor
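A usage sketch that builds inputs with the documented shapes and dtypes and, when the library and a CUDA device are available, compares the kernel against the reference. The sizes, the choice of ‘zoh’, and the negative-real-part eigenvalues are illustrative assumptions; the import path is taken from the module name on this page.

```python
import torch

batch, H, P, L = 2, 4, 8, 16  # illustrative sizes
device = "cuda" if torch.cuda.is_available() else "cpu"

u = torch.randn(batch, H, L, dtype=torch.complex64, device=device)
delta = torch.rand(batch, P, L, dtype=torch.float32, device=device) * 0.1
# Eigenvalues with negative real part (an assumption here, chosen for stability).
A = torch.complex(-torch.rand(P) - 0.1, torch.randn(P)).to(device)
B = torch.randn(P, H, dtype=torch.complex64, device=device)
C = torch.randn(H, P, dtype=torch.complex64, device=device)
D = torch.randn(H, dtype=torch.float32, device=device)

try:
    from lrnnx.ops.simplified_scan import s5_inner_fn, s5_inner_ref
except ImportError:
    s5_inner_fn = s5_inner_ref = None  # library not installed; inputs only

if s5_inner_fn is not None and device == "cuda":
    out = s5_inner_fn(u, delta, A, B, C, D, discretization="zoh", conj_sym=True)
    ref = s5_inner_ref(u, delta, A, B, C, D, discretization="zoh", conj_sym=True)
    # Report the output shape and the largest kernel/reference discrepancy.
    print(out.shape, (out - ref).abs().max().item())
```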
- s5_inner_ref(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) torch.Tensor[source]¶
Reference implementation of S5 inner function (pure PyTorch).
Computes the complete S5 forward pass:
1. SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x
2. Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y)
3. Skip connection: out = y_real + D * u.real
- Parameters:
u (torch.Tensor) – Complex input tensor of shape (batch, H, seqlen), dtype=complex64.
delta (torch.Tensor) – Real timestep tensor of shape (batch, P, seqlen), dtype=float32.
A (torch.Tensor) – Complex eigenvalues tensor of shape (P,) or (P, 1), dtype=complex64.
B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype=complex64.
C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype=complex64.
D (torch.Tensor) – Real skip connection tensor of shape (H,), dtype=float32.
deltaA (torch.Tensor | None, optional) – Optional separate timestep for A discretization of shape (batch, P, seqlen). Defaults to None.
discretization (str, optional) – Discretization method (‘bilinear’, ‘zoh’, or ‘dirac’). Defaults to “bilinear”.
conj_sym (bool, optional) – If True, output is 2 * Re(y), else Re(y). Defaults to True.
- Returns:
Real output tensor of shape (batch, H, seqlen), dtype=float32.
- Return type:
torch.Tensor