# Source code for lrnnx.ops.s4_kernel_interface

import math
from collections import defaultdict
from functools import partial
from typing import Mapping, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from lrnnx.ops.s4_utils import (
    combination,
    get_cauchy_kernel,
    get_vandermonde_kernel,
    get_vandermonde_transpose_kernel,
    inv_transform,
    param_transform,
    power,
    process_dplr_params,
    process_ssm_params,
    setup_default_state,
)

# Function aliases
contract = torch.einsum
_c2r = torch.view_as_real
_r2c = torch.view_as_complex


def _conj(x):
    """Append the complex conjugate of ``x`` along the last dimension."""
    return torch.cat([x, x.conj()], dim=-1)


# ``Tensor.resolve_conj`` is only available from torch 1.10 onward; older
# versions fall back to plain ``conj()``.
if tuple(int(part) for part in torch.__version__.split(".")[:2]) >= (1, 10):

    def _resolve_conj(x):
        """Return the complex conjugate of ``x`` as a materialized tensor."""
        return x.conj().resolve_conj()

else:

    def _resolve_conj(x):
        """Return the complex conjugate of ``x``."""
        return x.conj()


# Kernel implementations, resolved once at import time via s4_utils.
cauchy_k = get_cauchy_kernel()
vandermonde_k = get_vandermonde_kernel()
vandermonde_transpose_k = get_vandermonde_transpose_kernel()


class S4KernelBase(nn.Module):
    """Base class for S4 kernels; receives its parameters from the parent model.

    Args:
        d_model (int): Model dimension.
        l_max (int | None): Maximum sequence length.
        channels (int): Number of channels/heads. NOTE(review): this argument
            is not read; ``self.channels`` comes from ``param_config``.
        param_config (dict): A dictionary containing:

            * Parameter references: A_real, A_imag, B, C, inv_dt, P
              (nn.Parameters owned by S4/S4D)
            * Computed scalars: N, H, channels, rank, repeat
            * Config flags: dt_fast, real_transform, imag_transform,
              dt_transform, is_real, deterministic, verbose
            * S4D-only: disc

    Raises:
        KeyError: If a required key is missing from ``param_config``
            (``A_imag`` and ``P`` are optional and default to ``None``).
    """

    def __init__(
        self,
        d_model: int,
        l_max: Optional[int],
        channels: int,
        param_config: dict,
    ):
        super().__init__()
        cfg = param_config

        # Parameter handles shared with the parent S4/S4D model. Assigning an
        # nn.Parameter here also registers it on this module.
        self.A_real = cfg["A_real"]
        self.A_imag = cfg.get("A_imag")  # absent/None when is_real=True
        self.B = cfg["B"]
        self.C = cfg["C"]
        self.P = cfg.get("P")  # absent/None for the diagonal S4D kernel
        self.inv_dt = cfg["inv_dt"]

        # Dimensions precomputed by the parent model.
        self.N = cfg["N"]  # state size after halving for conjugate symmetry
        self.H = cfg["H"]
        self.channels = cfg["channels"]
        self.rank = cfg["rank"]
        self.repeat = cfg["repeat"]  # broadcast factor H // n_ssm

        # Transform choices and behavioural flags, copied verbatim.
        for flag in (
            "dt_fast",
            "real_transform",
            "imag_transform",
            "dt_transform",
            "is_real",
            "deterministic",
            "verbose",
        ):
            setattr(self, flag, cfg[flag])

        # Geometry of the surrounding model.
        self.d_model = d_model
        self.L = l_max
        self.l_max = l_max
class S4Kernel(S4KernelBase):
    """SSM kernel for diagonal plus low-rank (DPLR) state matrices.

    The convolution kernel is computed in the frequency domain. A Cauchy
    kernel is evaluated at bilinear-transformed FFT nodes, a Woodbury
    correction handles the low-rank term, and an inverse real FFT returns
    the time-domain kernel. The class also provides recurrent-stepping
    helpers (``_setup_step``, ``default_state``, ``step``) and
    ``forward_state`` for advancing the state over a whole chunk.

    Args:
        d_model (int): Model dimension.
        l_max (int | None): Maximum sequence length.
        channels (int): Number of channels/heads.
        param_config (dict): Configuration dictionary containing parameter
            references and flags.

    Note:
        ``_setup_C`` rewrites ``self.C`` in place into the "C~" form for the
        current kernel length, which is tracked by the ``l_kernel`` buffer
        (0 means C has not been transformed yet).
    """

    def __init__(self, d_model, l_max, channels, param_config):
        super().__init__(d_model, l_max, channels, param_config)
        # Length that self.C (as C~) currently represents; 0 = untransformed.
        self.register_buffer("l_kernel", torch.tensor(0))

    def _maybe_init_C(self):
        """Transform C into C~ for ``l_max`` on first use, if ``l_max`` is set."""
        if (
            self.l_kernel.item() == 0
            and self.l_max is not None
            and self.l_max > 0
        ):
            self._setup_C(self.l_max)

    def forward(self, state=None, rate=1.0, L=None):
        """Compute the SSM convolution kernel (the core operation).

        Args:
            state (torch.Tensor, optional): State tensor. When given, a
                kernel for the state is also returned. Defaults to None.
            rate (float, optional): Sampling rate. Defaults to 1.0.
            L (int, optional): Output kernel length. Defaults to
                ``round(l_kernel / rate)``.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]: A tuple containing:
                - k_B : Convolution kernel.
                - k_state : Kernel state, if state is provided.
        """
        # Initialize C~ if necessary
        self._maybe_init_C()

        # Handle sampling rate logic: make C~ cover the continuous length
        # rate * L (initializing or doubling as often as needed).
        if L is None:
            L = round(self.l_kernel.item() / rate)
        continuous_L = round(rate * L)
        while continuous_L > self.l_kernel.item():
            self._setup_C(continuous_L)
        discrete_L = round(self.l_kernel.item() / rate)

        # Process parameters
        dt, A, B, C, P, Q = process_dplr_params(
            self.A_real,
            self.A_imag if not self.is_real else None,
            self.B,
            self.C,
            self.P,
            self.inv_dt,
            self.real_transform,
            self.imag_transform,
            self.dt_transform,
            self.dt_fast,
            self.is_real,
            self.repeat,
            rate,
        )

        # Get FFT nodes (cached only for the default rate)
        omega, z = self._omega(
            discrete_L, dtype=A.dtype, device=A.device, cache=(rate == 1.0)
        )

        # Augment B with the state. The state is converted so it can be
        # batched alongside B: s <- (1/dt) * (I + dt/2 * A_full) @ s, where
        # A_full = diag(A) - P Q^* acts on the conjugate-expanded state.
        if state is not None:
            s = _conj(state) if state.size(-1) == self.N else state
            sA = s * _conj(A) - contract(
                "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P)
            )
            s = s / dt + sA / 2
            s = s[..., : self.N]
            B = torch.cat([s, B], dim=-3)

        # Incorporate dt into A
        A = A * dt

        # Stack B and P, C and Q so that all resolvent terms are batched
        B = torch.cat([B, P], dim=-3)
        C = torch.cat([C, Q], dim=-3)

        # Incorporate B and C batch dimensions
        v = B.unsqueeze(-3) * C.unsqueeze(-4)
        v = v * dt

        # Calculate resolvent at omega
        r = cauchy_k(v, z, A)

        # Low-rank Woodbury correction
        k_f = self._woodbury_correction(r)

        # Final correction for bilinear transform
        k_f = k_f * 2 / (1 + omega)

        # Move from frequency to coefficients, then truncate to length L
        k = torch.fft.irfft(k_f, n=discrete_L)
        k = k[..., :L]

        k_state = k[:-1, :, :, :] if state is not None else None
        k_B = k[-1, :, :, :]
        return k_B, k_state

    @torch.no_grad()
    def _setup_C(self, L):
        """Construct C~ from C (in place).

        On the first call (``l_kernel == 0``), this initializes to length
        ``L``. On later calls with ``L > l_kernel``, it doubles the current
        length. Otherwise it does nothing.
        """
        if self.l_kernel.item() == 0:
            if self.verbose:
                print(f"S4: Initializing kernel to length {L}")
            double_length = False
        elif L > self.l_kernel.item():
            if self.verbose:
                print(
                    f"S4: Doubling length from {self.l_kernel.item()} to {2*self.l_kernel.item()}"
                )
            double_length = True
            L = self.l_kernel.item()  # the math below uses the current length
        else:
            return

        C = _r2c(self.C)
        dA, _ = self._setup_state()
        dA_L = power(L, dA)
        # Multiply C by (I - dA^L), or by (I + dA^L) when doubling
        C_ = _conj(C)
        prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
        if double_length:
            prod = -prod
        C_ = C_ - prod
        C_ = C_[..., : self.N]  # keep one element of each conjugate pair
        self.C.copy_(_c2r(C_))

        # Arithmetic on the buffer preserves its dtype/device
        self.l_kernel = (
            2 * self.l_kernel if double_length else self.l_kernel + L
        )

    def _omega(self, L, dtype, device, cache=True):
        """Calculate (and optionally cache) the FFT nodes and their bilinear image.

        The cache lives in plain attributes rather than registered buffers,
        so ``Module.to()`` does not move it. A cached value is therefore
        reused only when its length, dtype and device all match the request.
        """
        if (
            cache
            and hasattr(self, "omega")
            and self.omega.size(-1) == L // 2 + 1
            and self.omega.dtype == dtype
            and self.omega.device == torch.device(device)
        ):
            return self.omega, self.z

        omega = torch.tensor(
            np.exp(-2j * np.pi / (L)), dtype=dtype, device=device
        )
        omega = omega ** torch.arange(0, L // 2 + 1, device=device)
        z = 2 * (1 - omega) / (1 + omega)

        if cache:
            self.omega = omega
            self.z = z
        return omega, z

    def _woodbury_correction(self, r):
        """Apply the low-rank Woodbury correction to the resolvent ``r``.

        The last ``self.rank`` rows/columns of ``r`` hold the low-rank terms.
        Ranks 1 and 2 use closed forms; larger ranks invert a
        (rank x rank) matrix per (h, n).
        """
        if self.rank == 1:
            k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (
                1 + r[-1:, -1:, :, :]
            )
        elif self.rank == 2:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            # Explicit 2x2 inverse of (I + r11)
            det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[
                :1, 1:, :, :
            ] * r11[1:, :1, :, :]
            s = (
                r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :]
                + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :]
                - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :]
                - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :]
            )
            s = s / det
            k_f = r00 - s
        else:
            r00 = r[: -self.rank, : -self.rank, :, :]
            r01 = r[: -self.rank, -self.rank :, :, :]
            r10 = r[-self.rank :, : -self.rank, :, :]
            r11 = r[-self.rank :, -self.rank :, :, :]
            r11 = rearrange(r11, "a b h n -> h n a b")
            r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11)
            r11 = rearrange(r11, "h n a b -> a b h n")
            k_f = r00 - torch.einsum(
                "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10
            )
        return k_f

    @torch.no_grad()
    def double_length(self):
        """Double the sequence length represented by C~.

        Raises:
            RuntimeError: If the kernel has not been initialized yet
                (``l_kernel == 0``). In that state ``_setup_C(0)`` would
                take the initialization branch with length 0 and multiply C
                by ``I - dA^0``, which (assuming ``power(0, dA)`` is the
                identity) zeroes C in place.
        """
        current = self.l_kernel.item()
        if current == 0:
            raise RuntimeError(
                "S4: cannot double an uninitialized kernel (l_kernel == 0); "
                "run forward() or _setup_C(L) first"
            )
        self._setup_C(2 * current)

    @torch.no_grad()
    def _setup_linear(self):
        """Precompute ``self.step_params`` for fast linear-time stepping."""
        dt, A, B, C, P, Q = process_dplr_params(
            self.A_real,
            self.A_imag if not self.is_real else None,
            self.B,
            self.C,
            self.P,
            self.inv_dt,
            self.real_transform,
            self.imag_transform,
            self.dt_transform,
            self.dt_fast,
            self.is_real,
            self.repeat,
            rate=1.0,
        )

        D = (2.0 / dt - A).reciprocal()
        R = (
            torch.eye(self.rank, dtype=A.dtype, device=A.device)
            + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real
        )
        Q_D = rearrange(Q * D, "r h n -> h r n")
        try:
            R = torch.linalg.solve(R, Q_D)
        except RuntimeError:
            # Fall back to NumPy on CPU when torch cannot solve this
            # dtype/device combination (torch linalg errors are RuntimeErrors).
            R = torch.tensor(
                np.linalg.solve(
                    R.to(Q_D).contiguous().detach().cpu(),
                    Q_D.contiguous().detach().cpu(),
                )
            ).to(Q_D)
        R = rearrange(R, "h r n -> r h n")

        self.step_params = {
            "D": D,
            "R": R,
            "P": P,
            "Q": Q,
            "B": B,
            "E": 2.0 / dt + A,
        }

    def _step_state_linear(self, u=None, state=None):
        """Advance ``state`` by one step in linear time using ``step_params``.

        ``u=None`` and ``state=None`` default to zeros. The first is used by
        ``_setup_state`` to extract dA, the second to extract dB. A state
        whose last dimension is ``N`` stores only half of each conjugate
        pair; one of size ``2 * N`` is the full state.
        """
        C = _r2c(self.C)  # only used for dtype/device

        if u is None:
            u = torch.zeros(self.H, dtype=C.dtype, device=C.device)
        if state is None:
            state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device)

        step_params = self.step_params.copy()
        if state.size(-1) == self.N:

            def contract_fn(p, x, y):
                # Expand conjugate pairs, contract, then keep the first half
                return contract(
                    "r h n, r h m, ... h m -> ... h n",
                    _conj(p),
                    _conj(x),
                    _conj(y),
                )[..., : self.N]

        else:
            assert state.size(-1) == 2 * self.N
            step_params = {k: _conj(v) for k, v in step_params.items()}

            def contract_fn(p, x, y):
                return contract("r h n, r h m, ... h m -> ... h n", p, x, y)

        D = step_params["D"]
        E = step_params["E"]
        R = step_params["R"]
        P = step_params["P"]
        Q = step_params["Q"]
        B = step_params["B"]

        new_state = E * state - contract_fn(P, Q, state)
        new_state = new_state + 2.0 * B * u.unsqueeze(-1)
        new_state = D * (new_state - contract_fn(P, R, new_state))
        return new_state

    def _setup_state(self):
        """Construct dense ``dA`` and ``dB`` for the discretized state equation.

        Both are obtained by stepping the linear-time recurrence: identity
        states give the columns of dA, and a unit input gives dB.
        """
        self._setup_linear()
        C = _r2c(self.C)  # only used for dtype/device
        state = torch.eye(
            2 * self.N, dtype=C.dtype, device=C.device
        ).unsqueeze(-2)
        dA = self._step_state_linear(state=state)
        dA = rearrange(dA, "n h m -> h m n")

        u = C.new_ones(self.H)
        dB = self._step_state_linear(u=u)
        dB = _conj(dB)
        dB = rearrange(dB, "1 h n -> h n")
        return dA, dB

    def _step_state(self, u, state):
        """Quadratic step; needs ``default_state`` to have set the contraction strings."""
        next_state = torch.einsum(
            self.state_contraction, self.dA, state
        ) + torch.einsum(self.input_contraction, self.dB, u)
        return next_state

    def _setup_step(self, mode="dense"):
        """Set up ``dA``, ``dB``, ``dC`` for stepping in the given mode.

        Args:
            mode (str): One of ``'dense'``, ``'linear'`` or ``'diagonal'``.

        Raises:
            NotImplementedError: For any other mode. This is checked before
                any state is modified.
        """
        if mode not in ("dense", "linear", "diagonal"):
            raise NotImplementedError(
                "Step mode must be {'dense' | 'linear' | 'diagonal'}"
            )

        # C must be in C~ form before it is read. forward() does this too,
        # but _setup_step may be called without a prior forward pass.
        self._maybe_init_C()

        self.dA, self.dB = self._setup_state()

        # Recover the original C from C~ = C (I - dA^L)
        C = _conj(_r2c(self.C))
        if self.l_kernel.item() == 0:
            dC = C
        else:
            dA_L = power(self.l_kernel.item(), self.dA)
            eye = torch.eye(self.dA.size(-1)).to(dA_L)
            dC = torch.linalg.solve(
                eye - dA_L.transpose(-1, -2), C.unsqueeze(-1)
            ).squeeze(-1)
        self.dC = dC

        self._step_mode = mode
        if mode == "linear":
            # Linear stepping keeps half of each conjugate pair
            self.dC = 2 * self.dC[:, :, : self.N]
        elif mode == "diagonal":
            eigvals, V = torch.linalg.eig(self.dA)
            V_inv = torch.linalg.inv(V)
            if self.verbose:
                print(
                    "Diagonalization error:",
                    torch.dist(V @ torch.diag_embed(eigvals) @ V_inv, self.dA),
                )
            self.dA = eigvals
            self.dB = contract("h n m, h m -> h n", V_inv, self.dB)
            self.dC = contract("h n m, c h n -> c h m", V, self.dC)
        # mode == "dense": dA, dB, dC are used as computed above.

    def default_state(self, *batch_shape):
        """Create a zero state and cache the einsum strings used by ``step``.

        Args:
            *batch_shape: Variable length argument list for batch dimensions.

        Returns:
            torch.Tensor: A zero-initialized state tensor of shape
            ``(*batch_shape, H, N')``. ``N'`` is the stored half-size for
            ``'linear'`` mode and twice that otherwise.
        """
        C = _r2c(self.C)
        N = C.size(-1)
        H = C.size(-2)

        step_mode = getattr(self, "_step_mode", "dense")
        if step_mode != "linear":
            N *= 2
            if step_mode == "diagonal":
                self.state_contraction = "h n, ... h n -> ... h n"
            else:
                self.state_contraction = "h m n, ... h n -> ... h m"
            self.input_contraction = "h n, ... h -> ... h n"
        self.output_contraction = "c h n, ... h n -> ... c h"

        state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device)
        return state

    def step(self, u, state):
        """Perform a single recurrent step.

        Requires ``_setup_step`` and ``default_state`` to have been called.

        Args:
            u (torch.Tensor): Input tensor.
            state (torch.Tensor): Current state tensor.

        Returns:
            A tuple containing:
                - y.real (torch.Tensor): Output tensor.
                - new_state (torch.Tensor): Updated state tensor.
        """
        if self._step_mode == "linear":
            new_state = self._step_state_linear(u, state)
        else:
            new_state = self._step_state(u, state)
        y = torch.einsum(self.output_contraction, self.dC, new_state)
        return y.real, new_state

    def forward_state(self, u, state):
        """Forward the state through a sequence.

        Args:
            u (torch.Tensor): Input sequence tensor of shape ``(B, H, L)``.
            state (torch.Tensor): State tensor of shape ``(B, H, N)``; a
                half-size state is conjugate-expanded internally.

        Returns:
            torch.Tensor: The updated state tensor, in the same size
            convention as ``state``.
        """
        dA, dB = self._setup_state()

        conj = state.size(-1) != dA.size(-1)
        if conj:
            state = _conj(state)

        v = contract("h n, b h l -> b h n l", dB, u.flip(-1))
        AL, v = power(u.size(-1), dA, v)
        next_state = contract("h m n, b h n -> b h m", AL, state)
        next_state = next_state + v

        if conj:
            next_state = next_state[..., : next_state.size(-1) // 2]
        return next_state
class S4DKernel(S4KernelBase):
    """SSM kernel using a diagonal state matrix (S4D) - pure convolution operation.

    Args:
        d_model (int): Model dimension.
        l_max (int | None): Maximum sequence length.
        channels (int): Number of channels/heads.
        param_config (dict): Configuration dictionary containing parameter
            references and flags, including the S4D-specific ``'disc'`` key
            (``"zoh"`` or ``"bilinear"``; defaults to ``"zoh"``).
    """

    def __init__(self, d_model, l_max, channels, param_config):
        # Initialize nn.Module internals before assigning further attributes.
        super().__init__(d_model, l_max, channels, param_config)
        # Discretization method used by forward() and _setup_step().
        self.disc = param_config.get("disc", "zoh")

    def forward(self, L, state=None, rate=1.0):
        """Compute the SSM convolution kernel (the core operation).

        Args:
            L (int): Sequence length.
            state (torch.Tensor, optional): State tensor. Defaults to None.
            rate (float, optional): Sampling rate. Defaults to 1.0.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]: A tuple containing:
                - K : Convolution kernel of shape ``(channels, H, L)``.
                - K_state : Kernel state, if state is provided.

        Raises:
            ValueError: If ``self.disc`` is not ``"zoh"`` or ``"bilinear"``.
        """
        dt, A, B, C, dtA = process_ssm_params(
            self.A_real,
            self.A_imag if not self.is_real else None,
            self.B,
            self.C,
            self.inv_dt,
            self.real_transform,
            self.imag_transform,
            self.dt_transform,
            self.dt_fast,
            self.is_real,
            self.repeat,
            rate,
        )

        # Augment B with the state, rescaled so it can be batched with B
        # under the chosen discretization.
        if state is not None:
            s = state / dt
            if self.disc == "bilinear":
                s = s * (1.0 + dtA / 2)
            elif self.disc == "zoh":
                s = s * dtA * dtA.exp() / (dtA.exp() - 1.0)
            B = torch.cat([s, B], dim=-3)

        # Combine B and C into one batch of (H, N) coefficient tensors
        C = (B[:, None, :, :] * C).view(-1, self.H, self.N)

        # Main kernel computation via the Vandermonde kernel
        if self.disc == "zoh":
            C = C * (torch.exp(dtA) - 1.0) / A
            K = vandermonde_k(C, dtA, L)
        elif self.disc == "bilinear":
            C = C * (1.0 - dtA / 2).reciprocal() * dt
            dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            K = vandermonde_k(C, dA.log(), L)
        else:
            raise ValueError(f"Discretization {self.disc} not supported")

        K = K.view(-1, self.channels, self.H, L)
        K_state = K[:-1, :, :, :] if state is not None else None
        K = K[-1, :, :, :]
        return K, K_state

    def _setup_step(self):
        """Set up the discretized ``dA``, ``dB``, ``dC`` used by ``step``.

        Raises:
            ValueError: If ``self.disc`` is not ``"zoh"`` or ``"bilinear"``.
                Previously this fell through and read a missing or stale
                ``self.dB``.
        """
        dt, A, B, C, dtA = process_ssm_params(
            self.A_real,
            self.A_imag if not self.is_real else None,
            self.B,
            self.C,
            self.inv_dt,
            self.real_transform,
            self.imag_transform,
            self.dt_transform,
            self.dt_fast,
            self.is_real,
            self.repeat,
            rate=1.0,
        )

        if self.disc == "zoh":
            self.dA = torch.exp(dtA)
            self.dB = B * (torch.exp(dtA) - 1.0) / A
        elif self.disc == "bilinear":
            self.dA = (1.0 + dtA / 2) / (1.0 - dtA / 2)
            self.dB = B * (1.0 - dtA / 2).reciprocal() * dt
        else:
            raise ValueError(f"Discretization {self.disc} not supported")
        # B carries a leading singleton axis; drop it for stepping.
        self.dB = rearrange(self.dB, "1 h n -> h n")
        self.dC = C

    def default_state(self, *batch_shape):
        """Create a zero-initialized state.

        Args:
            *batch_shape: Variable length argument list for batch dimensions.

        Returns:
            torch.Tensor: A zero-initialized state tensor of shape
            ``(*batch_shape, H, N)``, with C's dtype and device.
        """
        C = _r2c(self.C)  # only used for dtype/device
        # The diagonal S4D state is stored at size N (not doubled).
        state = torch.zeros(
            *batch_shape, self.H, self.N, dtype=C.dtype, device=C.device
        )
        return state

    def step(self, u, state):
        """Perform a single recurrent step (requires ``_setup_step``).

        Args:
            u (torch.Tensor): Input tensor of shape ``(B, H)``.
            state (torch.Tensor): Current state tensor of shape ``(B, H, N)``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - y.real : Output tensor, scaled by 2. This presumably
                  accounts for the conjugate halves that are not stored;
                  confirm against the parameterization.
                - next_state : Updated state tensor.
        """
        next_state = contract(
            "h n, b h n -> b h n", self.dA, state
        ) + contract("h n, b h -> b h n", self.dB, u)
        y = contract("c h n, b h n -> b c h", self.dC, next_state)
        return 2 * y.real, next_state

    def forward_state(self, u, state):
        """Pass the state forward through a whole sequence.

        Args:
            u (torch.Tensor): Input sequence tensor of shape ``(B, H, L)``.
            state (torch.Tensor): Initial state tensor of shape ``(B, H, N)``.

        Returns:
            torch.Tensor: The updated state tensor.
        """
        self._setup_step()
        AL = self.dA ** u.size(-1)
        # Reverse time and cast to dA's (complex) dtype/device
        u = u.flip(-1).to(self.dA).contiguous()
        v = vandermonde_transpose_k(u, self.dB, self.dA.log(), u.size(-1))
        next_state = AL * state + v
        return next_state
# Maps a kernel-type key to the kernel class that implements it.
kernel_registry = dict(
    s4=S4Kernel,
    s4d=S4DKernel,
)