"""
Triton kernel for single-step state update of S5-style SSMs.
This is to simplified_scan what selective_state_update is to selective_scan.
"""
from __future__ import annotations
import torch
import triton
import triton.language as tl
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_DELTAA": lambda args: args["deltaA_ptr"] is not None})
@triton.jit
def _simplified_state_update_kernel(
# Pointers
state_re_ptr,
state_im_ptr,
x_ptr,
dt_ptr,
A_re_ptr,
A_im_ptr,
B_re_ptr,
B_im_ptr,
C_re_ptr,
C_im_ptr,
D_ptr,
deltaA_ptr,
out_ptr,
# Dimensions
batch,
H, # input / output dim
P, # state dim
# Strides - state (batch, P)
stride_state_batch,
stride_state_p,
# Strides - x (batch, H)
stride_x_batch,
stride_x_h,
# Strides - dt (batch, P) or (P,)
stride_dt_batch,
stride_dt_p,
# Strides - A (P,)
stride_A_p,
# Strides - B (P, H)
stride_B_p,
stride_B_h,
# Strides - C (H, P)
stride_C_h,
stride_C_p,
# Strides - D (H, H) or None
stride_D_h_out,
stride_D_h_in,
# Strides - deltaA (batch, P) or None
stride_deltaA_batch,
stride_deltaA_p,
# Strides - out (batch, H)
stride_out_batch,
stride_out_h,
# Meta
CONJ_SYM: tl.constexpr,
BLOCK_SIZE_P: tl.constexpr,
BLOCK_SIZE_H: tl.constexpr,
HAS_D: tl.constexpr,
HAS_DELTAA: tl.constexpr,
DISCRETIZATION: tl.constexpr, # 0=bilinear, 1=zoh, 2=dirac
):
"""
Triton JIT kernel for the simplified state update.
Each program handles one batch element and a tile of H output dims.
"""
pid_b = tl.program_id(axis=0)
pid_h = tl.program_id(axis=1)
offs_p = tl.arange(0, BLOCK_SIZE_P) # state dim
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(
0, BLOCK_SIZE_H
) # output dim tile
mask_p = offs_p < P
mask_h = offs_h < H
state_re_ptrs = (
state_re_ptr + pid_b * stride_state_batch + offs_p * stride_state_p
)
state_im_ptrs = (
state_im_ptr + pid_b * stride_state_batch + offs_p * stride_state_p
)
s_re = tl.load(state_re_ptrs, mask=mask_p, other=0.0).to(tl.float32)
s_im = tl.load(state_im_ptrs, mask=mask_p, other=0.0).to(tl.float32)
x_ptrs = x_ptr + pid_b * stride_x_batch + offs_h * stride_x_h
x_val = tl.load(x_ptrs, mask=mask_h, other=0.0).to(
tl.float32
) # (BLOCK_SIZE_H,)
A_re = tl.load(A_re_ptr + offs_p * stride_A_p, mask=mask_p, other=0.0).to(
tl.float32
)
A_im = tl.load(A_im_ptr + offs_p * stride_A_p, mask=mask_p, other=0.0).to(
tl.float32
)
dt = tl.load(
dt_ptr + pid_b * stride_dt_batch + offs_p * stride_dt_p,
mask=mask_p,
other=0.0,
).to(tl.float32)
if HAS_DELTAA:
dtA = tl.load(
deltaA_ptr
+ pid_b * stride_deltaA_batch
+ offs_p * stride_deltaA_p,
mask=mask_p,
other=0.0,
).to(tl.float32)
else:
dtA = dt
# Discretize A: A_bar = discretize(A_complex, dtA)
if DISCRETIZATION == 0: # bilinear
# A_bar = (1 + 0.5*dtA*A) / (1 - 0.5*dtA*A)
half_dtA = 0.5 * dtA
num_re = 1.0 + half_dtA * A_re
num_im = half_dtA * A_im
den_re = 1.0 - half_dtA * A_re
den_im = -half_dtA * A_im
# complex division: (num_re + j*num_im) / (den_re + j*den_im)
den_mag_sq = den_re * den_re + den_im * den_im
den_mag_sq = tl.where(den_mag_sq == 0.0, 1e-12, den_mag_sq)
A_bar_re = (num_re * den_re + num_im * den_im) / den_mag_sq
A_bar_im = (num_im * den_re - num_re * den_im) / den_mag_sq
elif DISCRETIZATION == 1: # zoh
# A_bar = exp(dtA * A) where A is complex
# exp(a + jb) = exp(a) * (cos(b) + j*sin(b))
exp_real = tl.exp(dtA * A_re)
angle = dtA * A_im
A_bar_re = exp_real * tl.cos(angle)
A_bar_im = exp_real * tl.sin(angle)
else: # dirac (DISCRETIZATION == 2)
exp_real = tl.exp(dtA * A_re)
angle = dtA * A_im
A_bar_re = exp_real * tl.cos(angle)
A_bar_im = exp_real * tl.sin(angle)
# Discretize B: B_bar scalar per state dim
if DISCRETIZATION == 0: # bilinear
# gamma_bar = dt / (1 - 0.5*dt*A) (complex)
half_dt = 0.5 * dt
gden_re = 1.0 - half_dt * A_re
gden_im = -half_dt * A_im
gden_mag_sq = gden_re * gden_re + gden_im * gden_im
gden_mag_sq = tl.where(gden_mag_sq == 0.0, 1e-12, gden_mag_sq)
# dt is real, so numerator is (dt, 0)
gamma_re = (dt * gden_re) / gden_mag_sq
gamma_im = (-dt * gden_im) / gden_mag_sq # 0*gden_re - dt*gden_im
elif DISCRETIZATION == 1: # zoh
# gamma_bar = (exp(dt*A) - 1) / A (complex)
exp_re = tl.exp(dt * A_re)
ang = dt * A_im
expm1_re = exp_re * tl.cos(ang) - 1.0
expm1_im = exp_re * tl.sin(ang)
# divide by A (complex): (expm1) / (A_re + j*A_im)
A_mag_sq = A_re * A_re + A_im * A_im
A_mag_sq = tl.where(A_mag_sq == 0.0, 1e-12, A_mag_sq)
gamma_re = (expm1_re * A_re + expm1_im * A_im) / A_mag_sq
gamma_im = (expm1_im * A_re - expm1_re * A_im) / A_mag_sq
else: # dirac
gamma_re = tl.full(offs_p.shape, 1.0, dtype=tl.float32)
gamma_im = tl.full(offs_p.shape, 0.0, dtype=tl.float32)
# Compute Bu = B @ x (complex (P,) result)
# We accumulate over H tiles
Bu_re = tl.zeros((BLOCK_SIZE_P,), dtype=tl.float32)
Bu_im = tl.zeros((BLOCK_SIZE_P,), dtype=tl.float32)
for h_start in range(0, H, BLOCK_SIZE_H):
h_offs = h_start + tl.arange(0, BLOCK_SIZE_H)
h_mask = h_offs < H
x_tile = tl.load(
x_ptr + pid_b * stride_x_batch + h_offs * stride_x_h,
mask=h_mask,
other=0.0,
).to(
tl.float32
) # (BLOCK_SIZE_H,)
# B_re/im: (P, H) -> load tile (BLOCK_SIZE_P, BLOCK_SIZE_H)
B_re_tile = tl.load(
B_re_ptr
+ offs_p[:, None] * stride_B_p
+ h_offs[None, :] * stride_B_h,
mask=mask_p[:, None] & h_mask[None, :],
other=0.0,
).to(tl.float32)
B_im_tile = tl.load(
B_im_ptr
+ offs_p[:, None] * stride_B_p
+ h_offs[None, :] * stride_B_h,
mask=mask_p[:, None] & h_mask[None, :],
other=0.0,
).to(tl.float32)
Bu_re += tl.sum(B_re_tile * x_tile[None, :], axis=1)
Bu_im += tl.sum(B_im_tile * x_tile[None, :], axis=1)
Bbaru_re = gamma_re * Bu_re - gamma_im * Bu_im
Bbaru_im = gamma_re * Bu_im + gamma_im * Bu_re
new_s_re = A_bar_re * s_re - A_bar_im * s_im + Bbaru_re
new_s_im = A_bar_re * s_im + A_bar_im * s_re + Bbaru_im
tl.store(state_re_ptrs, new_s_re, mask=mask_p)
tl.store(state_im_ptrs, new_s_im, mask=mask_p)
C_re_tile = tl.load(
C_re_ptr + offs_h[:, None] * stride_C_h + offs_p[None, :] * stride_C_p,
mask=mask_h[:, None] & mask_p[None, :],
other=0.0,
).to(tl.float32)
C_im_tile = tl.load(
C_im_ptr + offs_h[:, None] * stride_C_h + offs_p[None, :] * stride_C_p,
mask=mask_h[:, None] & mask_p[None, :],
other=0.0,
).to(tl.float32)
y_tile = tl.sum(
C_re_tile * new_s_re[None, :] - C_im_tile * new_s_im[None, :], axis=1
)
if CONJ_SYM:
y_tile = 2.0 * y_tile
if HAS_D:
Dx = tl.zeros((BLOCK_SIZE_H,), dtype=tl.float32)
for h_start in range(0, H, BLOCK_SIZE_H):
h_offs = h_start + tl.arange(0, BLOCK_SIZE_H)
h_mask_inner = h_offs < H
x_tile = tl.load(
x_ptr + pid_b * stride_x_batch + h_offs * stride_x_h,
mask=h_mask_inner,
other=0.0,
).to(tl.float32)
D_tile = tl.load(
D_ptr
+ offs_h[:, None] * stride_D_h_out
+ h_offs[None, :] * stride_D_h_in,
mask=mask_h[:, None] & h_mask_inner[None, :],
other=0.0,
).to(tl.float32)
Dx += tl.sum(D_tile * x_tile[None, :], axis=1)
y_tile += Dx
out_ptrs = out_ptr + pid_b * stride_out_batch + offs_h * stride_out_h
tl.store(out_ptrs, y_tile, mask=mask_h)
[docs]
def simplified_state_update(
state: torch.Tensor,
x: torch.Tensor,
dt: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
D: torch.Tensor | None = None,
deltaA: torch.Tensor | None = None,
discretization: str = "bilinear",
conj_sym: bool = True,
) -> torch.Tensor:
"""
Triton-accelerated single-step state update for S5-style (simplified) SSMs.
Args:
state (torch.Tensor): Complex hidden state of shape ``(batch, P)``, dtype ``complex64``. **Modified in-place**.
x (torch.Tensor): Real input at current timestep of shape ``(batch, H)``, dtype ``float32``.
dt (torch.Tensor): Real timestep of shape ``(batch, P)`` or ``(P,)``, dtype ``float32``.
A (torch.Tensor): Complex eigenvalues of shape ``(P,)``, dtype ``complex64``.
B (torch.Tensor): Complex projection matrix of shape ``(P, H)``, dtype ``complex64``.
C (torch.Tensor): Complex projection matrix of shape ``(H, P)``, dtype ``complex64``.
D (torch.Tensor, optional): Real skip connection matrix of shape ``(H, H)``, dtype ``float32``. Defaults to None.
deltaA (torch.Tensor, optional): Optional separate timestep for A discretization of shape ``(batch, P)``, dtype ``float32``. If None, ``dt`` is used for both A and B. Defaults to None.
discretization (str, optional): Discretization method ('bilinear', 'zoh', or 'dirac'). Defaults to "bilinear".
conj_sym (bool, optional): If True, output is 2 * Re(...), else Re(...). Defaults to True.
Returns:
torch.Tensor: Real output tensor of shape ``(batch, H)``, dtype ``float32``.
"""
assert state.is_complex(), "state must be complex64"
assert A.is_complex(), "A must be complex64"
assert B.is_complex(), "B must be complex64"
assert C.is_complex(), "C must be complex64"
assert not x.is_complex(), "x must be real (float32)"
batch, P = state.shape
H = x.shape[1]
assert x.shape == (batch, H)
assert A.shape == (P,)
assert B.shape == (P, H)
assert C.shape == (H, P)
if D is not None:
assert D.shape == (H, H)
if dt.dim() == 1:
assert dt.shape == (P,)
dt = dt.unsqueeze(0).expand(batch, -1).contiguous()
assert dt.shape == (batch, P)
if deltaA is not None:
if deltaA.dim() == 1:
deltaA = deltaA.unsqueeze(0).expand(batch, -1).contiguous()
assert deltaA.shape == (batch, P)
disc_map = {"bilinear": 0, "zoh": 1, "dirac": 2}
disc_int = disc_map.get(discretization.lower())
if disc_int is None:
raise ValueError(
f"discretization must be one of {list(disc_map)}, got '{discretization}'"
)
# View complex tensors as their underlying float storage (real, imag contiguous pairs)
# complex64 -> view_as_real -> (..., 2) float32
state_ri = torch.view_as_real(state) # (batch, P, 2)
A_ri = torch.view_as_real(A) # (P, 2)
B_ri = torch.view_as_real(B) # (P, H, 2)
C_ri = torch.view_as_real(C) # (H, P, 2)
state_re = state_ri[..., 0].contiguous() # (batch, P)
state_im = state_ri[..., 1].contiguous() # (batch, P)
A_re = A_ri[..., 0].contiguous() # (P,)
A_im = A_ri[..., 1].contiguous() # (P,)
B_re = B_ri[..., 0].contiguous() # (P, H)
B_im = B_ri[..., 1].contiguous() # (P, H)
C_re = C_ri[..., 0].contiguous() # (H, P)
C_im = C_ri[..., 1].contiguous() # (H, P)
x = x.contiguous()
dt = dt.contiguous()
out = torch.empty(batch, H, device=x.device, dtype=torch.float32)
# Determine block sizes
BLOCK_SIZE_P = triton.next_power_of_2(P)
BLOCK_SIZE_H = min(triton.next_power_of_2(H), 128)
grid = (batch, triton.cdiv(H, BLOCK_SIZE_H))
with torch.cuda.device(x.device.index):
_simplified_state_update_kernel[grid](
state_re,
state_im,
x,
dt,
A_re,
A_im,
B_re,
B_im,
C_re,
C_im,
D,
deltaA,
out,
# dims
batch,
H,
P,
# state strides
state_re.stride(0),
state_re.stride(1),
# x strides
x.stride(0),
x.stride(1),
# dt strides
dt.stride(0),
dt.stride(1),
# A strides
A_re.stride(0),
# B strides
B_re.stride(0),
B_re.stride(1),
# C strides
C_re.stride(0),
C_re.stride(1),
# D strides
*(D.stride(0), D.stride(1)) if D is not None else (0, 0),
# deltaA strides
*(
(deltaA.stride(0), deltaA.stride(1))
if deltaA is not None
else (0, 0)
),
# out strides
out.stride(0),
out.stride(1),
# meta
CONJ_SYM=conj_sym,
BLOCK_SIZE_P=BLOCK_SIZE_P,
BLOCK_SIZE_H=BLOCK_SIZE_H,
DISCRETIZATION=disc_int,
)
# Write updated state back into the original complex tensor
state.real.copy_(state_re)
state.imag.copy_(state_im)
return out
[docs]
def simplified_state_update_ref(
state: torch.Tensor,
x: torch.Tensor,
dt: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
D: torch.Tensor | None = None,
deltaA: torch.Tensor | None = None,
discretization: str = "bilinear",
conj_sym: bool = True,
) -> torch.Tensor:
"""
Pure-PyTorch reference for a single-step S5 state update.
Args:
state (torch.Tensor): Complex hidden state of shape ``(batch, P)``, dtype ``complex64``. **Modified in-place**.
x (torch.Tensor): Real input at current timestep of shape ``(batch, H)``, dtype ``float32``.
dt (torch.Tensor): Real timestep of shape ``(batch, P)`` or ``(P,)``, dtype ``float32``.
A (torch.Tensor): Complex eigenvalues of shape ``(P,)``, dtype ``complex64``.
B (torch.Tensor): Complex projection matrix of shape ``(P, H)``, dtype ``complex64``.
C (torch.Tensor): Complex projection matrix of shape ``(H, P)``, dtype ``complex64``.
D (torch.Tensor, optional): Real skip connection matrix of shape ``(H, H)``, dtype ``float32``. Defaults to None.
deltaA (torch.Tensor, optional): Optional separate timestep for A discretization of shape ``(batch, P)``, dtype ``float32``. Defaults to None.
discretization (str, optional): Discretization method ('bilinear', 'zoh', or 'dirac'). Defaults to "bilinear".
conj_sym (bool, optional): If True, output is 2 * Re(...), else Re(...). Defaults to True.
Returns:
torch.Tensor: Real output tensor of shape ``(batch, H)``, dtype ``float32``.
"""
assert state.is_complex()
assert A.is_complex()
assert B.is_complex()
assert C.is_complex()
batch, P = state.shape
H = x.shape[1]
# Expand dt / deltaA to (batch, P) if needed
if dt.dim() == 1:
dt = dt.unsqueeze(0).expand(batch, -1)
dtA = deltaA if deltaA is not None else dt
if dtA.dim() == 1:
dtA = dtA.unsqueeze(0).expand(batch, -1)
A_complex = A # (P,)
if discretization.lower() == "bilinear":
half_dtA_A = 0.5 * dtA * A_complex # (batch, P)
A_bar = (1.0 + half_dtA_A) / (1.0 - half_dtA_A)
elif discretization.lower() == "zoh":
A_bar = torch.exp(dtA * A_complex)
elif discretization.lower() == "dirac":
A_bar = torch.exp(dtA * A_complex)
else:
raise ValueError(f"Unknown discretization: {discretization}")
if discretization.lower() == "bilinear":
gamma_bar = dt / (1.0 - 0.5 * dt * A_complex) # (batch, P) complex
elif discretization.lower() == "zoh":
gamma_bar = (torch.exp(dt * A_complex) - 1.0) / A_complex
elif discretization.lower() == "dirac":
gamma_bar = torch.ones_like(dt, dtype=A.dtype)
else:
raise ValueError(f"Unknown discretization: {discretization}")
Bu = torch.einsum("ph,bh->bp", B, x.to(B.dtype))
state.copy_(A_bar * state + gamma_bar * Bu)
y_complex = torch.einsum("hp,bp->bh", C, state)
y = y_complex.real
if conj_sym:
y = 2.0 * y
if D is not None:
y = y + x @ D.T # (batch, H) @ (H, H)^T -> (batch, H)
return y