Source code for lrnnx.ops.simplified_scan

"""
Simplified SSM Scan for S5-style models.
This module implements connects model using a CUDA kernel
"""

from __future__ import annotations

import simplified_scan_cuda
import torch

from lrnnx.ops.torch import custom_bwd, custom_fwd


[docs] class SimplifiedScanFn(torch.autograd.Function): """ Autograd function for simplified SSM scan with complex input. B projects input from H-dim to P-dim (state dimension). The kernel operates in state space with identity B/C (diagonal SSM). C projects output from P-dim back to H-dim. Forward: u (B,H,L) -> Bu = B @ u -> kernel -> x (B,P,L) -> y = C @ x -> (B,H,L) """ @staticmethod @custom_fwd def forward( ctx, u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None, discretization: str, return_last_state=False, ) -> torch.Tensor: """ Forward pass for the simplified SSM scan. Args: ctx (Any): Autograd context. u (torch.Tensor): Complex input tensor of shape ``(batch, H, seqlen)``. H is hidden/input dimension. delta (torch.Tensor): Real timestep tensor of shape ``(batch, P, seqlen)``. P is state dimension. A (torch.Tensor): Complex eigenvalues tensor of shape ``(P,)`` or ``(P, 1)``. B (torch.Tensor): Complex projection matrix of shape ``(P, H)``. Projects input to state space. C (torch.Tensor): Complex projection matrix of shape ``(H, P)``. Projects state to output. deltaA (torch.Tensor | None): Optional separate timestep for A discretization of shape ``(batch, P, seqlen)``. discretization (str): Discretization method ('bilinear', 'zoh', 'dirac'). return_last_state (bool, optional): Whether to return the last hidden state. Defaults to False. Returns: torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Output tensor, and optionally the last state. """ # Ensure contiguous if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: delta = delta.contiguous() if deltaA is not None and deltaA.stride(-1) != 1: deltaA = deltaA.contiguous() P = A.shape[0] # State dimension # Ensure A is (P, 1) for the kernel if A.dim() == 1: A_kernel = A.unsqueeze(-1) else: A_kernel = A dstate = A_kernel.shape[1] # Should be 1 for diagonal SSM # Modify input w/ projection (B, P, L) # this is needed because the kernel works in a SISO manner Bu = torch.einsum("ph,bhl->bpl", B, u) # Create identity B and C for the kernel (diagonal SSM) B_identity = torch.ones(P, dstate, dtype=A.dtype, device=A.device) C_identity = torch.ones(P, dstate, dtype=A.dtype, device=A.device) # Kernel forward pass # The naming is kept for consistency with the paper, the scan is not really # simplified, and works with complex inputs out, xs = simplified_scan_cuda.fwd( Bu, delta, A_kernel, B_identity, C_identity, deltaA, discretization, ) # Get the actual output by projecting back with C y = torch.einsum("hp,bpl->bhl", C, out) # Save for backward ctx.discretization = discretization ctx.has_deltaA = deltaA is not None ctx.save_for_backward( u, delta, A_kernel, B, C, Bu, out, xs, B_identity, C_identity, deltaA if deltaA is not None else torch.empty(0, device=u.device), ) last_state = ( out[:, -1, 1::2] + 1j * out[:, -1, 0::2] ) # (batch, P, dstate) if return_last_state: return y, last_state.squeeze(-1) # type: ignore[return-value] # (batch, P) return y @staticmethod @custom_bwd def backward(ctx, dout: torch.Tensor): """ Backward pass for the simplified SSM scan. Args: ctx (Any): Autograd context. dout (torch.Tensor): Gradient of the output tensor. Returns: tuple: Gradients with respect to inputs (du, ddelta, dA, dB, dC, ddeltaA, None, None). """ ( u, delta, A, B, C, Bu, out, xs, B_identity, C_identity, deltaA_saved, ) = ctx.saved_tensors # Restore deltaA (None if not provided) deltaA = deltaA_saved if ctx.has_deltaA else None # We assume the gradient is complex # in practice, it will not be. dx_out = torch.einsum("hp,bhl->bpl", C.conj(), dout) # dC = dout @ out^H summed over batch and seqlen dC = torch.einsum("bhl,bpl->hp", dout, out.conj()) if dx_out.stride(-1) != 1: dx_out = dx_out.contiguous() # Backward through kernel (with identity B/C) result = simplified_scan_cuda.bwd( Bu, delta, A, B_identity, C_identity, deltaA, dx_out, xs, ctx.discretization, ) # The CUDA function returns: dBu, ddelta, dA, dB_identity, dC_identity, [ddeltaA] # ddeltaA is present if deltaA was not None dBu, ddelta, dA = result[0], result[1], result[2] ddeltaA = None if ctx.has_deltaA and len(result) > 5: ddeltaA = result[5] # du = B^H @ dBu: (batch, H, seqlen) du = torch.einsum("ph,bpl->bhl", B.conj(), dBu) # dB = dBu @ u^H summed over batch and seqlen dB = torch.einsum("bpl,bhl->ph", dBu, u.conj()) # Squeeze dA if it was originally 1D if dA.dim() == 2 and dA.shape[1] == 1: dA = dA.squeeze(-1) return ( du, ddelta, dA, dB, dC, ddeltaA, None, None, )
[docs] def simplified_scan_fn( u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None = None, return_last_state: bool = False, discretization: str = "bilinear", ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Simplified SSM scan using CUDA kernel. S5-style scan where B projects input to state space and C projects back. The kernel internally operates with identity B/C (diagonal SSM). Forward: u (B,H,L) -> Bu = B @ u -> kernel -> x (B,P,L) -> y = C @ x -> (B,H,L) Args: u (torch.Tensor): Complex input tensor of shape ``(batch, H, seqlen)``, dtype=complex64. delta (torch.Tensor): Real timestep tensor of shape ``(batch, P, seqlen)``, dtype=float32. A (torch.Tensor): Complex state matrix eigenvalues of shape ``(P,)`` or ``(P, 1)``, dtype=complex64. B (torch.Tensor): Complex projection matrix of shape ``(P, H)``, dtype=complex64. Projects input to state. C (torch.Tensor): Complex projection matrix of shape ``(H, P)``, dtype=complex64. Projects state to output. deltaA (torch.Tensor | None, optional): Optional separate timestep for A discretization of shape ``(batch, P, seqlen)``, dtype=float32. If provided, A is discretized using deltaA while B uses delta. Defaults to None. return_last_state (bool, optional): Whether to return the last hidden state. Defaults to False. discretization (str, optional): Discretization method ('bilinear', 'zoh', 'dirac'). Defaults to "bilinear". Returns: torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - Complex output tensor of shape ``(batch, H, seqlen)``, dtype=complex64. - last_state : If return_last_state=True, returns state of shape ``(batch, P)``, dtype=complex64. """ # Validate discretization method valid_methods = ("bilinear", "zoh", "dirac") if discretization not in valid_methods: raise ValueError( f"discretization must be one of {valid_methods}, got '{discretization}'" ) # Use autograd function for gradient support return SimplifiedScanFn.apply( u, delta, A, B, C, deltaA, discretization, return_last_state )
[docs] def simplified_scan_ref( u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, deltaA: torch.Tensor | None = None, return_last_state: bool = False, discretization: str = "bilinear", ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Reference implementation of simplified scan (pure PyTorch). S5-style scan where B projects input to state space and C projects back. Forward: u (B,H,L) -> Bu = B @ u -> kernel -> x (B,P,L) -> y = C @ x -> (B,H,L) Args: u (torch.Tensor): Complex input tensor of shape ``(batch, H, seqlen)``, dtype=complex64. delta (torch.Tensor): Real timestep tensor of shape ``(batch, P, seqlen)``, dtype=float32. A (torch.Tensor): Complex state matrix eigenvalues of shape ``(P,)`` or ``(P, 1)``, dtype=complex64. B (torch.Tensor): Complex projection matrix of shape ``(P, H)``, dtype=complex64. Projects input to state. C (torch.Tensor): Complex projection matrix of shape ``(H, P)``, dtype=complex64. Projects state to output. deltaA (torch.Tensor | None, optional): Optional separate timestep for A discretization of shape ``(batch, P, seqlen)``. If provided, A is discretized using deltaA while B uses delta. Defaults to None. return_last_state (bool, optional): Whether to return the last hidden state. Defaults to False. discretization (str, optional): Discretization method ('bilinear', 'zoh', 'dirac'). Defaults to "bilinear". Returns: torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - out : Complex output tensor of shape ``(batch, H, seqlen)``, dtype=complex64. - last_state : If return_last_state=True, also returns state of shape ``(batch, P)``, dtype=complex64. """ dtype_in = u.dtype assert u.is_complex(), "Input u must be complex (complex64)" assert A.is_complex(), "State matrix A must be complex (complex64)" assert not delta.is_complex(), "Delta must be real (float32)" batch, H, seqlen = u.shape P = A.shape[0] # State dimension # Ensure A is 1D for this reference if A.dim() == 2: A = A.squeeze(-1) # Ensure delta is float delta = delta.float() # Modify input w/ projection Bu = torch.einsum("ph,bhl->bpl", B, u) # (batch, P, L) # Discretize A and B for each timestep Bu_expanded = Bu.unsqueeze(-1) # (b, P, L, 1) delta_expanded = delta.unsqueeze(-1) # (b, P, L, 1) # Use deltaA for A discretization if provided, otherwise use delta delta_for_A = ( deltaA.unsqueeze(-1) if deltaA is not None else delta_expanded ) A_expanded = ( A.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) ) # (P,) -> (1, P, 1, 1) one = torch.tensor(1.0, dtype=A.dtype, device=A.device) if discretization == "bilinear": # A_bar = (1 + 0.5*delta_for_A*A) / (1 - 0.5*delta_for_A*A) # B_bar = delta / (1 - 0.5*delta*A) (identity B inside kernel) a_half_A = 0.5 * delta_for_A * A_expanded # For A discretization a_half_B = 0.5 * delta_expanded * A_expanded # For B discretization denom_A = one - a_half_A denom_B = one - a_half_B A_bar = (one + a_half_A) / denom_A # (b, P, L, 1) deltaB_u = (delta_expanded / denom_B) * Bu_expanded # (b, P, L, 1) elif discretization == "zoh": # A_bar = exp(delta_for_A * A) # B_bar = (exp(delta * A) - 1) / A (identity B inside kernel) A_bar = torch.exp(delta_for_A * A_expanded) # (b, P, L, 1) A_disc_B = torch.exp(delta_expanded * A_expanded) deltaB_u = ( (A_disc_B - one) / A_expanded ) * Bu_expanded # (b, P, L, 1) elif discretization == "dirac": # A_bar = exp(delta_for_A * A) # B_bar = 1 (identity, no delta scaling) A_bar = torch.exp(delta_for_A * A_expanded) # (b, P, L, 1) deltaB_u = Bu_expanded # (b, P, L, 1) - no delta scaling else: raise ValueError(f"Unknown discretization method: {discretization}") # Sequential scan in state space x = torch.zeros((batch, P, 1), dtype=A.dtype, device=A.device) xs = [] for i in range(seqlen): x = A_bar[:, :, i] * x + deltaB_u[:, :, i] xs.append(x.squeeze(-1)) # (b, P) last_state = x.squeeze(-1) # (batch, P) x_seq = torch.stack(xs, dim=2) # (b, P, L) # Step 2: Project state to output: y = C @ x # x_seq: (batch, P, seqlen), C: (H, P) -> y: (batch, H, seqlen) y = torch.einsum("hp,bpl->bhl", C, x_seq) out = y.to(dtype_in) if return_last_state: return out, last_state return out
[docs] class S5InnerFn(torch.autograd.Function): """ The complete S5 model inner function with custom forward and backward. This wraps the simplified scan kernel and adds: 1. Conjugate symmetry handling (2 * real if conj_sym else real) 2. Skip connection with D matrix The scan kernel computes: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x This function then applies: out = (2 if conj_sym else 1) * Re(y) + D * u.real """ @staticmethod @custom_fwd def forward( ctx, u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None, discretization: str, conj_sym: bool, ): """ Forward pass for the complete S5 model inner function. Args: ctx (Any): Autograd context. u (torch.Tensor): Complex input tensor of shape ``(batch, H, seqlen)``, dtype=complex64. delta (torch.Tensor): Real timestep tensor of shape ``(batch, P, seqlen)``, dtype=float32. A (torch.Tensor): Complex eigenvalues tensor of shape ``(P,)`` or ``(P, 1)``, dtype=complex64. B (torch.Tensor): Complex projection matrix of shape ``(P, H)``, dtype=complex64. C (torch.Tensor): Complex projection matrix of shape ``(H, P)``, dtype=complex64. D (torch.Tensor): Real skip connection tensor of shape ``(H,)``, dtype=float32. deltaA (torch.Tensor | None): Optional separate timestep for A discretization of shape ``(batch, P, seqlen)``. discretization (str): Discretization method ('bilinear', 'zoh', or 'dirac'). conj_sym (bool): If True, output is 2 * Re(y), else Re(y). Returns: torch.Tensor: Real output tensor of shape ``(batch, H, seqlen)``, dtype=float32. """ # Ensure contiguous if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: delta = delta.contiguous() if deltaA is not None and deltaA.stride(-1) != 1: deltaA = deltaA.contiguous() P = A.shape[0] # Ensure A is (P, 1) for the kernel if A.dim() == 1: A_kernel = A.unsqueeze(-1) else: A_kernel = A dstate = A_kernel.shape[1] # Project input to state space: Bu = B @ u -> (batch, P, L) Bu = torch.einsum("ph,bhl->bpl", B, u) # Create identity B and C for the kernel (diagonal SSM) B_identity = torch.ones(P, dstate, dtype=A.dtype, device=A.device) C_identity = torch.ones(P, dstate, dtype=A.dtype, device=A.device) # Kernel forward pass scan_out, xs = simplified_scan_cuda.fwd( Bu, delta, A_kernel, B_identity, C_identity, deltaA, discretization, ) # Project state to output: y_complex = C @ scan_out -> (batch, H, L) y_complex = torch.einsum("hp,bpl->bhl", C, scan_out) # Apply conjugate symmetry if conj_sym: y_real = 2 * y_complex.real else: y_real = y_complex.real # Apply skip connection: y = y_real + D * u.real # Note: u is complex, D operates on real part u_real = u.real if u.is_complex() else u y = y_real + D.unsqueeze(0).unsqueeze(-1) * u_real # Save for backward ctx.discretization = discretization ctx.conj_sym = conj_sym ctx.has_deltaA = deltaA is not None ctx.save_for_backward( u, delta, A_kernel, B, C, D, Bu, scan_out, xs, B_identity, C_identity, deltaA if deltaA is not None else torch.empty(0, device=u.device), ) return y @staticmethod @custom_bwd def backward(ctx, dout: torch.Tensor): """ Backward pass computing gradients for all inputs. Args: ctx (Any): Autograd context. dout (torch.Tensor): Gradient of loss w.r.t. output tensor of shape ``(batch, H, seqlen)``, real. Returns: tuple: Gradients for u, delta, A, B, C, D, deltaA, None, None. """ ( u, delta, A, B, C, D, Bu, scan_out, xs, B_identity, C_identity, deltaA_saved, ) = ctx.saved_tensors discretization = ctx.discretization conj_sym = ctx.conj_sym deltaA = deltaA_saved if ctx.has_deltaA else None # dD = sum(dout * u_real) over batch and seqlen u_real = u.real if u.is_complex() else u dD = torch.einsum("bhl,bhl->h", dout, u_real) du_skip = D.unsqueeze(0).unsqueeze(-1) * dout # (batch, H, L) # Gradient through conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y_complex) # dy_complex = (2 if conj_sym else 1) * dout (as complex with zero imaginary) scale = 2.0 if conj_sym else 1.0 dy_complex = torch.complex(scale * dout, torch.zeros_like(dout)) # dscan_out = C^H @ dy_complex dscan_out = torch.einsum("hp,bhl->bpl", C.conj(), dy_complex) # dC = dy_complex @ scan_out^H summed over batch and seqlen dC = torch.einsum("bhl,bpl->hp", dy_complex, scan_out.conj()) if dscan_out.stride(-1) != 1: dscan_out = dscan_out.contiguous() # Backward through kernel (with identity B/C) result = simplified_scan_cuda.bwd( Bu, delta, A, B_identity, C_identity, deltaA, dscan_out, xs, discretization, ) dBu, ddelta, dA = result[0], result[1], result[2] ddeltaA = None if ctx.has_deltaA and len(result) > 5: ddeltaA = result[5] # Gradient through B projection: Bu = B @ u # du_scan = B^H @ dBu du_scan = torch.einsum("ph,bpl->bhl", B.conj(), dBu) # dB = dBu @ u^H summed over batch and seqlen dB = torch.einsum("bpl,bhl->ph", dBu, u.conj()) # Total gradient for u (complex): du_scan + du_skip (as complex) if u.is_complex(): du = du_scan + torch.complex(du_skip, torch.zeros_like(du_skip)) else: du = du_scan.real + du_skip # Squeeze dA if it was originally 1D if dA.dim() == 2 and dA.shape[1] == 1: dA = dA.squeeze(-1) return ( du, ddelta, dA, dB, dC, dD, ddeltaA, None, # discretization None, # conj_sym )
[docs] def s5_inner_fn( u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None = None, discretization: str = "bilinear", conj_sym: bool = True, ) -> torch.Tensor: """ S5 inner function using CUDA kernel. Computes the complete S5 forward pass: 1. SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x 2. Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y) 3. Skip connection: out = y_real + D * u.real Args: u (torch.Tensor): Complex input tensor of shape ``(batch, H, seqlen)``, dtype=complex64. delta (torch.Tensor): Real timestep tensor of shape ``(batch, P, seqlen)``, dtype=float32. A (torch.Tensor): Complex eigenvalues tensor of shape ``(P,)`` or ``(P, 1)``, dtype=complex64. B (torch.Tensor): Complex projection matrix of shape ``(P, H)``, dtype=complex64. C (torch.Tensor): Complex projection matrix of shape ``(H, P)``, dtype=complex64. D (torch.Tensor): Real skip connection tensor of shape ``(H,)``, dtype=float32. deltaA (torch.Tensor | None, optional): Optional separate timestep for A discretization of shape ``(batch, P, seqlen)``. Defaults to None. discretization (str, optional): Discretization method ('bilinear', 'zoh', or 'dirac'). Defaults to "bilinear". conj_sym (bool, optional): If True, output is 2 * Re(y), else Re(y). Defaults to True. Returns: torch.Tensor: Real output tensor of shape ``(batch, H, seqlen)``, dtype=float32. """ return S5InnerFn.apply( u, delta, A, B, C, D, deltaA, discretization, conj_sym )
[docs] def s5_inner_ref( u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, deltaA: torch.Tensor | None = None, discretization: str = "bilinear", conj_sym: bool = True, ) -> torch.Tensor: """ Reference implementation of S5 inner function (pure PyTorch). Computes the complete S5 forward pass: 1. SSM scan: x[t] = A_bar * x[t-1] + B_bar * (B @ u)[t], y = C @ x 2. Conjugate symmetry: y_real = (2 if conj_sym else 1) * Re(y) 3. Skip connection: out = y_real + D * u.real Args: u (torch.Tensor): Complex input tensor of shape ``(batch, H, seqlen)``, dtype=complex64. delta (torch.Tensor): Real timestep tensor of shape ``(batch, P, seqlen)``, dtype=float32. A (torch.Tensor): Complex eigenvalues tensor of shape ``(P,)`` or ``(P, 1)``, dtype=complex64. B (torch.Tensor): Complex projection matrix of shape ``(P, H)``, dtype=complex64. C (torch.Tensor): Complex projection matrix of shape ``(H, P)``, dtype=complex64. D (torch.Tensor): Real skip connection tensor of shape ``(H,)``, dtype=float32. deltaA (torch.Tensor | None, optional): Optional separate timestep for A discretization of shape ``(batch, P, seqlen)``. Defaults to None. discretization (str, optional): Discretization method ('bilinear', 'zoh', or 'dirac'). Defaults to "bilinear". conj_sym (bool, optional): If True, output is 2 * Re(y), else Re(y). Defaults to True. Returns: torch.Tensor: Real output tensor of shape ``(batch, H, seqlen)``, dtype=float32. """ # Use simplified_scan_ref for the SSM computation y_complex = simplified_scan_ref( u, delta, A, B, C, deltaA=deltaA, return_last_state=False, discretization=discretization, ) # Apply conjugate symmetry if conj_sym: y_real = 2 * y_complex.real # type: ignore[union-attr] else: y_real = y_complex.real # type: ignore[union-attr] # Apply skip connection u_real = u.real if u.is_complex() else u out = y_real + D.unsqueeze(0).unsqueeze(-1) * u_real return out