"""
S7 Scan Operation.
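This module exposes two levels of the scan, mirroring Mamba's API: the
standalone S7 scan (``s7_scan_fn`` / ``s7_scan_ref``) and the fused S7 inner
block (``s7_inner_fn`` / ``s7_inner_ref``).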
"""
from __future__ import annotations
import selective_scan_cuda
import torch
import torch.nn.functional as F
from einops import rearrange
from lrnnx.ops.torch import custom_bwd, custom_fwd
[docs]
class S7ScanFn(torch.autograd.Function):
    """Autograd function for S7 scan with time-varying A, B, C.

    The state-space projections ``B @ u`` and ``C @ x`` are computed here with
    ``einsum``. Only the recurrence over the ``dstate`` channels is delegated
    to ``selective_scan_cuda`` in its ``"s7"`` mode. The kernel is called with
    the ``dstate`` channels in its channel slot and placeholder A/B/C tensors
    of shape ``(dstate, 1)``.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        u,
        A,
        B,
        C,
        bias=None,
        return_last_state=False,
    ):
        """
        Forward pass for the S7 Scan CUDA kernel.

        Args:
            ctx (Any): Autograd context.
            u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)`` in float32.
            A (torch.Tensor): Time-varying state transition tensor of shape ``(batch, dstate, seqlen)`` in float32.
            B (torch.Tensor): Time-varying input projection tensor of shape ``(batch, dstate, dim, seqlen)`` in float32.
            C (torch.Tensor): Time-varying output projection tensor of shape ``(batch, dim, dstate, seqlen)`` in float32.
            bias (torch.Tensor | None, optional): Optional LTV bias tensor of shape ``(batch, dstate, seqlen)`` in float32. Defaults to None.
            return_last_state (bool, optional): Whether to return the last hidden state. Defaults to False.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The output tensor
            of shape ``(batch, dim, seqlen)`` cast to ``u.dtype``. When
            ``return_last_state`` is True, the last hidden state of shape
            ``(batch, dstate)`` is returned as well. The last state is marked
            non-differentiable, so gradients flow only through the output.
        """
        if not u.is_contiguous():
            u = u.contiguous()
        # Input projection into state space:
        # Bu[b, n, l] = sum_h B[b, n, h, l] * u[b, h, l]
        Bu = torch.einsum("bnhl,bhl->bnl", B.float(), u.float())
        if bias is not None:
            Bu = Bu + bias.float()
        Bu = Bu.contiguous()
        # The time-varying A is fed to the kernel through its ``delta`` slot.
        # NOTE(review): the "s7" kernel mode is assumed to discretise it the
        # same way s7_scan_ref does (1 - 1 / (A**2 + 0.5)). Confirm against
        # the kernel source.
        delta = A.contiguous()
        # Placeholder kernel parameters: zeros for A, ones for B and C.
        # ``out_kernel`` is then used directly as the hidden-state sequence.
        # It is contracted with C below, and its last step is the last state.
        A_kernel = torch.zeros(
            dstate_channels := A.shape[1], 1, dtype=Bu.dtype, device=Bu.device
        ).contiguous()
        B_kernel = torch.ones(
            dstate_channels, 1, dtype=Bu.dtype, device=Bu.device
        ).contiguous()
        C_kernel = torch.ones(
            dstate_channels, 1, dtype=Bu.dtype, device=Bu.device
        ).contiguous()
        out_kernel, x_kernel, *_ = selective_scan_cuda.fwd(
            Bu,
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
            None,
            None,
            None,
            None,
            False,
            "s7",
        )
        # Output projection: y[b, h, l] = sum_n C[b, h, n, l] * x[b, n, l]
        y = torch.einsum("bhnl,bnl->bhl", C.float(), out_kernel.float())
        ctx.save_for_backward(
            u,
            B,
            C,
            bias,
            Bu,
            out_kernel,
            x_kernel,
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
        )
        out = y.to(u.dtype)
        if not return_last_state:
            return out
        last_state = out_kernel[:, :, -1]
        # backward() has no gradient formula for the last state. Mark it
        # explicitly rather than silently dropping gradients that reach it.
        ctx.mark_non_differentiable(last_state)
        return out, last_state

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        """
        Backward pass for the S7 Scan CUDA kernel.

        Args:
            ctx (Any): Autograd context.
            dout (torch.Tensor): Gradient of the output tensor.
            *args: Gradient slot for ``last_state``, present when forward was
                called with ``return_last_state=True``. It is ignored because
                the last state is non-differentiable. Accepting it is still
                required: autograd passes one gradient per forward output.

        Returns:
            tuple: Gradients with respect to inputs ``(du, dA, dB, dC, dbias, None)``.
        """
        (
            u,
            B,
            C,
            bias,
            Bu,
            out_kernel,
            x_kernel,
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
        ) = ctx.saved_tensors
        dout = dout.float()
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # Gradients of y = C @ x with respect to C and the state sequence x.
        dC = torch.einsum("bhl,bnl->bhnl", dout, out_kernel.float())
        dout_kernel = torch.einsum(
            "bhnl,bhl->bnl", C.float(), dout
        ).contiguous()
        dBu, ddelta, *_ = selective_scan_cuda.bwd(
            Bu.contiguous(),
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
            None,
            None,
            None,
            None,
            dout_kernel,
            x_kernel,
            out_kernel,
            None,
            False,
            False,
            "s7",
        )
        # A entered the kernel as ``delta``, so its gradient is ``ddelta``.
        dA = ddelta
        # Gradients of Bu = B @ u (+ bias).
        dB = torch.einsum("bnl,bhl->bnhl", dBu, u.float())
        du = torch.einsum("bnhl,bnl->bhl", B.float(), dBu)
        dbias = dBu if bias is not None else None
        return (
            du.to(u.dtype),
            dA,
            dB,
            dC,
            dbias,
            None,  # return_last_state
        )
[docs]
def s7_scan_fn(
    u: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    bias: torch.Tensor | None = None,
    return_last_state: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Run the S7 scan through the CUDA-backed autograd function ``S7ScanFn``.

    Args:
        u (torch.Tensor): Input sequence, ``(batch, dim, seqlen)``, float32.
        A (torch.Tensor): Per-step state transition parameters, ``(batch, dstate, seqlen)``, float32.
        B (torch.Tensor): Per-step input projection, ``(batch, dstate, dim, seqlen)``, float32.
        C (torch.Tensor): Per-step output projection, ``(batch, dim, dstate, seqlen)``, float32.
        bias (torch.Tensor | None, optional): Optional per-step state bias, ``(batch, dstate, seqlen)``, float32. Defaults to None.
        return_last_state (bool, optional): Also return the final hidden state. Defaults to False.

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            - out : Scan output, ``(batch, dim, seqlen)``, float32.
            - last_state : Only when ``return_last_state=True``; the final hidden state, ``(batch, dstate)``, float32.
    """
    # Arguments are passed positionally in the same order as the forward
    # signature (after ctx).
    scan_inputs = (u, A, B, C, bias, return_last_state)
    return S7ScanFn.apply(*scan_inputs)
[docs]
def s7_scan_ref(
    u: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    bias: torch.Tensor | None = None,
    return_last_state: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Reference implementation of S7 scan (pure PyTorch).

    Starting from a zero state, computes for every time step ``t``::

        A_bar[t] = 1 - 1 / (A[t] ** 2 + 0.5)          # elementwise, (batch, dstate)
        x[t]     = A_bar[t] * x[t-1] + B[t] @ u[t] + bias[t]
        y[t]     = C[t] @ x[t]

    Args:
        u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)`` in float32.
        A (torch.Tensor): Time-varying state transition tensor of shape ``(batch, dstate, seqlen)`` in float32.
        B (torch.Tensor): Time-varying projection tensor of shape ``(batch, dstate, dim, seqlen)`` in float32.
        C (torch.Tensor): Time-varying projection tensor of shape ``(batch, dim, dstate, seqlen)`` in float32.
        bias (torch.Tensor | None, optional): Optional LTV bias tensor of shape ``(batch, dstate, seqlen)``. It is cast to float32 like the other inputs. Defaults to None.
        return_last_state (bool, optional): Whether to return the last hidden state. Defaults to False.

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            - out : Output tensor of shape ``(batch, dim, seqlen)``, cast back to ``u.dtype``.
            - last_state : If ``return_last_state=True``, the last state of shape ``(batch, dstate)`` in float32. For ``seqlen == 0`` this is the zero initial state.
    """
    dtype_in = u.dtype
    batch, _, seqlen = u.shape
    dstate = A.shape[1]
    u = u.float()
    A = A.float()
    B = B.float()
    C = C.float()
    # Stable reparameterisation: A**2 + 0.5 >= 0.5, so A_bar lies in [-1, 1).
    A_bar = 1.0 - 1.0 / (A * A + 0.5)  # (batch, dstate, seqlen)
    Bu = torch.einsum("bnhl,bhl->bnl", B, u)
    if bias is not None:
        # Cast like every other input. A bias of another dtype (e.g. float64)
        # would otherwise promote the state and break the einsum against C.
        Bu = Bu + bias.float()
    x = torch.zeros((batch, dstate), dtype=torch.float32, device=u.device)
    xs = []
    for t in range(seqlen):
        x = A_bar[:, :, t] * x + Bu[:, :, t]
        xs.append(x)
    if xs:
        x_seq = torch.stack(xs, dim=-1)  # (batch, dstate, seqlen)
    else:
        # torch.stack rejects an empty list; an empty sequence has no states.
        x_seq = Bu.new_zeros((batch, dstate, 0))
    y = torch.einsum("bhnl,bnl->bhl", C, x_seq)
    out = y.to(dtype_in)
    if return_last_state:
        # After the loop, x holds the final state (or the zero initial state).
        return out, x
    return out
[docs]
class S7InnerFn(torch.autograd.Function):
    """Fused S7 inner function with custom forward and backward.

    Pipeline: ``in_proj`` -> ``x_proj`` (split into A, B, C, D, bias) -> S7
    scan -> ``+ D * u`` -> ``sigmoid(gate_proj(gelu(y))) * y`` -> ``+ residual``.
    The whole backward pass is derived by hand. Only the recurrence is
    delegated to ``selective_scan_cuda`` (``"s7"`` mode).
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        hidden_states,
        in_proj_weight,
        x_proj_weight,
        gate_proj_weight,
        d_state,
        base_params,
    ):
        """
        Forward pass for the fused S7 inner function.

        Args:
            ctx (Any): Autograd context.
            hidden_states (torch.Tensor): Input hidden states of shape ``(batch, seqlen, d_model)``.
            in_proj_weight (torch.Tensor): Input projection weight, ``(d_model, d_model)``.
            x_proj_weight (torch.Tensor): X projection weight of shape
                ``(2 * d_state + 2 * d_model * d_state + d_model, d_model)``.
                Its output is split into A, B, C, D and bias.
            gate_proj_weight (torch.Tensor): Gate projection weight, ``(d_model, d_model)``.
            d_state (int): State dimension.
            base_params (torch.Tensor): Base HiPPO initialization parameters of
                shape ``(d_state,)``, added to the projected A at every step.

        Returns:
            torch.Tensor: The projected and gated output tensor of shape ``(batch, seqlen, d_model)``.
        """
        batch, seqlen, d_model = hidden_states.shape
        if hidden_states.stride(-1) != 1:
            hidden_states = hidden_states.contiguous()
        # in_proj: (B, L, D) @ (D, D)^T -> (B, L, D)
        x = F.linear(hidden_states, in_proj_weight)
        # x_proj: (BL, D) @ W^T with W of shape (2*N + 2*D*N + D, D)
        #         -> (BL, 2*N + 2*D*N + D)
        x_dbl = F.linear(rearrange(x, "b l d -> (b l) d"), x_proj_weight)
        # Split into A (N), B (D*N), C (D*N), D (D), bias (N)
        A, B, C, D, bias = torch.split(
            x_dbl,
            [d_state, d_model * d_state, d_model * d_state, d_model, d_state],
            dim=-1,
        )
        A = rearrange(A, "(b l) n -> b n l", l=seqlen)
        # Add HiPPO initialization to A
        A = A + base_params.unsqueeze(0).unsqueeze(-1)
        B = rearrange(B, "(b l) (h n) -> b n h l", l=seqlen, n=d_state)
        C = rearrange(
            C, "(b l) (h n) -> b h n l", l=seqlen, n=d_state
        ).contiguous()
        D_tv = rearrange(D, "(b l) h -> b h l", l=seqlen).contiguous()
        bias = rearrange(bias, "(b l) n -> b n l", l=seqlen)
        u = rearrange(x, "b l h -> b h l")
        Bu = torch.einsum("bnhl,bhl->bnl", B.float(), u.float())
        if bias is not None:
            Bu = Bu + bias.float()
        Bu = Bu.contiguous()
        # Run CUDA kernel with dim=N, dstate=1
        delta = A.contiguous()
        A_kernel = torch.zeros(d_state, 1, dtype=Bu.dtype, device=Bu.device)
        B_kernel = torch.ones(
            batch, 1, 1, seqlen, dtype=Bu.dtype, device=Bu.device
        ).contiguous()
        C_kernel = torch.ones(
            batch, 1, 1, seqlen, dtype=Bu.dtype, device=Bu.device
        ).contiguous()
        out_kernel, x_kernel, *_ = selective_scan_cuda.fwd(
            Bu,
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
            None,
            None,
            None,
            None,
            False,
            "s7",
        )
        # Project output: y = C @ out_kernel
        y = torch.einsum("bhnl,bnl->bhl", C.float(), out_kernel.float())
        # Apply D (time-varying skip)
        y = y + D_tv.float() * u.float()
        # Gating: y_gated = sigmoid(gate_proj(gelu(y))) * y
        y_t = rearrange(y, "b h l -> b l h")
        gelu_y_t = F.gelu(y_t)  # default F.gelu is the exact (erf) form
        gate = torch.sigmoid(F.linear(gelu_y_t, gate_proj_weight))
        y_gated = gate * y_t
        # Residual
        out = y_gated + hidden_states
        ctx.save_for_backward(
            hidden_states,
            in_proj_weight,
            x_proj_weight,
            gate_proj_weight,
            x,
            x_dbl,
            A,
            B,
            C,
            D_tv,
            u,
            Bu,
            out_kernel,
            x_kernel,
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
            y,
            y_t,
            gate,
            gelu_y_t,
        )
        ctx.d_state = d_state
        return out.to(hidden_states.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """
        Backward pass for the fused S7 inner function.

        Args:
            ctx (Any): Autograd context.
            dout (torch.Tensor): Gradient of the output tensor.

        Returns:
            tuple: Gradients with respect to ``(hidden_states, in_proj_weight,
            x_proj_weight, gate_proj_weight, d_state, base_params)``. The
            ``d_state`` gradient is None.
        """
        (
            hidden_states,
            in_proj_weight,
            x_proj_weight,
            gate_proj_weight,
            x,
            x_dbl,
            A,
            B,
            C,
            D_tv,
            u,
            Bu,
            out_kernel,
            x_kernel,
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
            y,
            y_t,
            gate,
            gelu_y_t,
        ) = ctx.saved_tensors
        _, seqlen, d_model = hidden_states.shape
        dout = dout.float()
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        dy_gated = dout
        # Residual branch: out = y_gated + hidden_states
        dhidden_states = dout.clone()
        # Gradient through gating: y_gated = gate * y_t
        dgate = dy_gated * y_t
        dy_t = dy_gated * gate
        # Gradient through gate = sigmoid(gate_proj(gelu(y_t)))
        dsigmoid = gate * (1.0 - gate)
        dgate_pre = dgate * dsigmoid
        # Gradient through gate_proj
        dgate_proj_weight = dgate_pre.reshape(
            -1, d_model
        ).t() @ gelu_y_t.reshape(-1, d_model)
        dgelu_y_t = F.linear(dgate_pre, gate_proj_weight.t())
        # Gradient through GELU. Forward uses F.gelu's default exact form
        # gelu(y) = y * Phi(y), so gelu'(y) = Phi(y) + y * phi(y), where Phi and
        # phi are the standard normal CDF and PDF. The tanh approximation's
        # derivative would not match the forward pass.
        cdf = 0.5 * (1.0 + torch.erf(y_t * 0.7071067811865476))  # 1/sqrt(2)
        pdf = torch.exp(-0.5 * y_t * y_t) * 0.3989422804014327  # 1/sqrt(2*pi)
        dgelu = cdf + y_t * pdf
        dy_t_from_gate = dgelu_y_t * dgelu + dy_t
        dy = rearrange(dy_t_from_gate, "b l h -> b h l")
        # Gradient through D skip: y = y_inner + D_tv * u
        dD_tv = dy * u.float()
        dy_inner = dy
        du_from_D = D_tv.float() * dy
        # Gradient through C projection: y_inner = C @ out_kernel
        dC = torch.einsum("bhl,bnl->bhnl", dy_inner, out_kernel.float())
        dout_kernel = torch.einsum(
            "bhnl,bhl->bnl", C.float(), dy_inner
        ).contiguous()
        # Backward through CUDA kernel
        dBu, ddelta, *_ = selective_scan_cuda.bwd(
            Bu.contiguous(),
            delta,
            A_kernel,
            B_kernel,
            C_kernel,
            None,
            None,
            None,
            None,
            dout_kernel,
            x_kernel,
            out_kernel,
            None,
            False,
            False,
            "s7",
        )
        # A entered the kernel as ``delta``, so its gradient is ``ddelta``.
        dA = ddelta
        # base_params was broadcast over batch and time, so sum those axes.
        dbase_params = torch.sum(dA, dim=(0, 2))
        dB = torch.einsum("bnl,bhl->bnhl", dBu, u.float())
        dbias = dBu
        du_from_B = torch.einsum("bnhl,bnl->bhl", B.float(), dBu)
        du = du_from_B + du_from_D
        dx = rearrange(du, "b h l -> b l h")
        # Reconstruct x_dbl gradients (inverse of the forward rearranges, same
        # A, B, C, D, bias order as the split)
        dA_flat = rearrange(dA, "b n l -> (b l) n")
        dB_flat = rearrange(dB, "b n h l -> (b l) (h n)")
        dC_flat = rearrange(dC, "b h n l -> (b l) (h n)")
        dD_flat = rearrange(dD_tv, "b h l -> (b l) h")
        dbias_flat = rearrange(dbias, "b n l -> (b l) n")
        dx_dbl = torch.cat(
            [dA_flat, dB_flat, dC_flat, dD_flat, dbias_flat], dim=-1
        )
        # Gradient through x_proj
        dx_proj_weight = dx_dbl.t() @ rearrange(x, "b l d -> (b l) d")
        dx_from_proj = F.linear(dx_dbl, x_proj_weight.t())
        dx_from_proj = rearrange(dx_from_proj, "(b l) d -> b l d", l=seqlen)
        dx = dx + dx_from_proj
        # Gradient through in_proj (no activation here)
        din_proj_weight = dx.reshape(-1, d_model).t() @ hidden_states.reshape(
            -1, d_model
        )
        dhidden_states = dhidden_states + F.linear(dx, in_proj_weight.t())
        return (
            dhidden_states.to(hidden_states.dtype),
            din_proj_weight,
            dx_proj_weight,
            dgate_proj_weight,
            None,  # d_state
            dbase_params,
        )
[docs]
def s7_inner_fn(
    hidden_states: torch.Tensor,
    in_proj_weight: torch.Tensor,
    x_proj_weight: torch.Tensor,
    gate_proj_weight: torch.Tensor,
    d_state: int,
    base_params: torch.Tensor,
) -> torch.Tensor:
    """
    Apply the fused S7 inner block via the CUDA-backed ``S7InnerFn``.

    Args:
        hidden_states (torch.Tensor): Block input, ``(batch, seqlen, d_model)``.
        in_proj_weight (torch.Tensor): Weight of the input projection.
        x_proj_weight (torch.Tensor): Weight producing the per-token A, B, C, D and bias.
        gate_proj_weight (torch.Tensor): Weight of the gating projection.
        d_state (int): Size of the SSM state.
        base_params (torch.Tensor): HiPPO base added to the projected A.

    Returns:
        torch.Tensor: Block output, ``(batch, seqlen, d_model)``.
    """
    fused_block = S7InnerFn.apply
    return fused_block(
        hidden_states, in_proj_weight, x_proj_weight, gate_proj_weight, d_state, base_params
    )
[docs]
def s7_inner_ref(
    hidden_states: torch.Tensor,
    in_proj_weight: torch.Tensor,
    x_proj_weight: torch.Tensor,
    gate_proj_weight: torch.Tensor,
    d_state: int,
    base_params: torch.Tensor,
) -> torch.Tensor:
    """
    Pure-PyTorch reference for the fused S7 inner block.

    Mirrors ``S7InnerFn.forward``: input projection, a per-token projection
    split into A, B, C, D and bias, the S7 scan (``s7_scan_ref``), a
    time-varying skip ``D * u``, sigmoid gating of ``gelu(y)``, and a
    residual connection.

    Args:
        hidden_states (torch.Tensor): Block input, ``(batch, seqlen, d_model)``.
        in_proj_weight (torch.Tensor): Weight of the input projection.
        x_proj_weight (torch.Tensor): Weight producing the per-token A, B, C, D and bias.
        gate_proj_weight (torch.Tensor): Weight of the gating projection.
        d_state (int): Size of the SSM state.
        base_params (torch.Tensor): HiPPO base added to the projected A.

    Returns:
        torch.Tensor: Block output, ``(batch, seqlen, d_model)``.
    """
    _, seqlen, d_model = hidden_states.shape
    x = F.linear(hidden_states, in_proj_weight)
    projected = F.linear(rearrange(x, "b l d -> (b l) d"), x_proj_weight)
    widths = [d_state, d_model * d_state, d_model * d_state, d_model, d_state]
    a_raw, b_raw, c_raw, d_raw, bias_raw = torch.split(projected, widths, dim=-1)

    hippo = base_params.unsqueeze(0).unsqueeze(-1)  # broadcast over batch and time
    A = rearrange(a_raw, "(b l) n -> b n l", l=seqlen) + hippo
    B = rearrange(b_raw, "(b l) (h n) -> b n h l", l=seqlen, n=d_state)
    C = rearrange(c_raw, "(b l) (h n) -> b h n l", l=seqlen, n=d_state)
    D_tv = rearrange(d_raw, "(b l) h -> b h l", l=seqlen)
    bias = rearrange(bias_raw, "(b l) n -> b n l", l=seqlen)
    u = rearrange(x, "b l h -> b h l")

    # Scan, then add the time-varying skip connection.
    y = s7_scan_ref(u, A, B, C, bias=bias) + D_tv * u
    y_t = rearrange(y, "b h l -> b l h")
    gate = torch.sigmoid(F.linear(F.gelu(y_t), gate_proj_weight))
    return gate * y_t + hidden_states