"""
RG-LRU (Recurrent Gated Linear Recurrent Unit) Scan Operation.
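This module exposes two levels of the scan, mirroring Mamba: a bare scan (``rglru_scan_fn`` / ``rglru_scan_ref``) and a fused inner function that also applies the causal conv1d, gate projections and output projection (``rglru_inner_fn`` / ``rglru_inner_ref``). Each level has a CUDA-backed version and a pure-PyTorch reference.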
"""
from __future__ import annotations

import selective_scan_cuda
import torch
import torch.nn.functional as F
from einops import rearrange

try:
    from causal_conv1d import causal_conv1d_fn
    from causal_conv1d.cpp_functions import (
        causal_conv1d_bwd_function,
        causal_conv1d_fwd_function,
    )
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_fwd_function = None
    causal_conv1d_bwd_function = None

from lrnnx.ops.torch import custom_bwd, custom_fwd
class RGLRUScanFn(torch.autograd.Function):
    """
    Thin autograd wrapper around the RGLRU CUDA kernel.

    All gating pre-computations must be done *before* calling this. The kernel
    is driven through ``selective_scan_cuda`` in ``"rglru"`` mode with all-ones
    ``B`` / ``C`` placeholders, so only ``u``, ``delta`` and ``A`` receive
    gradients.
    """

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        return_last_state=False,
    ):
        """
        Forward pass for the RG-LRU Scan CUDA kernel.

        Args:
            ctx (Any): Autograd context.
            u (torch.Tensor): Pre-gated input of shape ``(batch, dim, seqlen)`` in float32.
            delta (torch.Tensor): Pre-computed exponent of shape ``(batch, dim, seqlen)`` in float32.
            A (torch.Tensor): Learnable recurrence base in (0, 1), shape ``(dim, dstate)``.
            return_last_state (bool, optional): Whether to return the last hidden state. Defaults to False.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The output tensor, and
            optionally the last state of shape ``(batch, dim, dstate)``. The last
            state is marked non-differentiable because ``backward`` has no path
            for its gradient.
        """
        if not u.is_contiguous():
            u = u.contiguous()
        if not delta.is_contiguous():
            delta = delta.contiguous()
        if not A.is_contiguous():
            A = A.contiguous()
        dim, dstate = A.shape
        # Identity B / C (time-invariant, not learnable); torch.ones is already
        # contiguous.
        B = torch.ones(dim, dstate, dtype=u.dtype, device=u.device)
        C = torch.ones(dim, dstate, dtype=u.dtype, device=u.device)
        out, x, *_ = selective_scan_cuda.fwd(
            u,
            delta,
            A,
            B,
            C,
            None,  # D
            None,  # z
            None,  # delta_bias
            None,  # deltaA
            False,  # delta_softplus
            "rglru",
        )
        ctx.save_for_backward(u, delta, A, B, C, out, x)
        if not return_last_state:
            return out
        # Same slice as Mamba's selective_scan_fn (odd entries of the last
        # chunk) -- presumably the final hidden state; verify against kernel.
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        # backward() drops the gradient for this output, so flag it instead of
        # letting gradients routed through it vanish silently.
        ctx.mark_non_differentiable(last_state)
        return out, last_state

    @staticmethod
    def backward(ctx, dout, *args):
        """
        Backward pass for the RG-LRU Scan CUDA kernel.

        Args:
            ctx (Any): Autograd context holding the tensors saved in ``forward``.
            dout (torch.Tensor): Gradient w.r.t. the scan output.
            *args: Gradient w.r.t. ``last_state`` when it was returned; ignored.

        Returns:
            tuple: Gradients for ``(u, delta, A, return_last_state)``; the last
            entry is always None.
        """
        u, delta, A, B, C, out, x = ctx.saved_tensors
        if not dout.is_contiguous():
            dout = dout.contiguous()
        du, ddelta, dA, *_ = selective_scan_cuda.bwd(
            u,
            delta,
            A,
            B,
            C,
            None,  # D
            None,  # z
            None,  # delta_bias
            None,  # deltaA
            dout,
            x,
            out,
            None,  # dz
            False,  # delta_softplus
            False,  # recompute_out
            "rglru",
        )
        return (
            du,
            ddelta,
            dA,
            None,  # return_last_state
        )
def rglru_scan_fn(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    return_last_state: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Run the RG-LRU scan through the CUDA kernel (autograd-aware).

    Functional front-end for ``RGLRUScanFn``. No gating happens here, so every
    input must already be float32.

    Args:
        u (torch.Tensor): Pre-gated input, shape ``(batch, dim, seqlen)``.
        delta (torch.Tensor): Pre-computed exponent, shape ``(batch, dim, seqlen)``.
        A (torch.Tensor): Learnable recurrence base in (0, 1), shape ``(dim, dstate)``.
        return_last_state (bool, optional): If True, also return the final hidden
            state. Defaults to False.

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            - Output tensor of shape ``(batch, dim, seqlen)``.
            - last_state : only when ``return_last_state`` is True, shape
              ``(batch, dim, dstate)``.
    """
    scan_args = (u, delta, A, return_last_state)
    return RGLRUScanFn.apply(*scan_args)
[docs]
def rglru_scan_ref(
u: torch.Tensor,
delta: torch.Tensor,
A: torch.Tensor,
return_last_state: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Reference RG-LRU scan (pure PyTorch, sequential loop).
Args:
u (torch.Tensor): Pre-gated input of shape ``(batch, dim, seqlen)`` in float32.
delta (torch.Tensor): Pre-computed exponent of shape ``(batch, dim, seqlen)`` in float32.
A (torch.Tensor): Learnable recurrence base in (0, 1), shape ``(dim, dstate)`` in float32.
return_last_state (bool, optional): Whether to return last hidden state. Defaults to False.
Returns:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- Output tensor of shape ``(batch, dim, seqlen)``.
- last_state : If ``return_last_state`` is True, shape ``(batch, dim, dstate)``.
"""
dtype_in = u.dtype
batch, dim, seqlen = u.shape
dstate = A.shape[1]
u = u.float()
delta = delta.float()
A = A.float()
# A_bar = A ^ delta via log-space for stability
log_A = torch.log(A) # (dim, dstate)
# delta: (B, D, L) -> (B, D, 1, L); log_A: (D, N) -> (1, D, N, 1)
A_bar = torch.exp(
delta.unsqueeze(2) * log_A.unsqueeze(0).unsqueeze(-1)
) # (B, D, N, L)
sqrt_term = torch.sqrt(1.0 - A_bar * A_bar) # (B, D, N, L)
# B_u = sqrt(1 - A_bar^2) * u
B_u = sqrt_term * u.unsqueeze(2) # (B, D, N, L)
# Sequential scan
h = torch.zeros(batch, dim, dstate, dtype=torch.float32, device=u.device)
hs = []
for t in range(seqlen):
h = A_bar[:, :, :, t] * h + B_u[:, :, :, t]
hs.append(h)
last_state = h
h_seq = torch.stack(hs, dim=-1) # (B, D, N, L)
# C = identity ones → y = sum over dstate
y = h_seq.sum(dim=2) # (B, D, L)
out = y.to(dtype_in)
if return_last_state:
return out, last_state
return out
class RGLRUInnerFn(torch.autograd.Function):
    """
    RG-LRU inner function: conv1d + gate projections + gating + scan + output.

    Performs::

        x = causal_conv1d(x_pre_conv)
        recurrent_gate = sigmoid(x @ W_r^T + b_r)
        input_gate = sigmoid(x @ W_i^T + b_i)
        delta = c * recurrent_gate
        u_gated = input_gate * x
        y = rglru_scan(u_gated, delta, a)
        out = (gate * y) @ W_out^T + b_out

    Gate, scan and projection math runs in float32 and the result is cast back
    to ``x``'s dtype. Each gradient is returned in the dtype of the input it
    belongs to, so float32 parameters keep float32 gradients even when ``x``
    is half precision.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        conv1d_weight,
        conv1d_bias,
        a,
        recurrent_gate_weight,
        recurrent_gate_bias,
        input_gate_weight,
        input_gate_bias,
        out_proj_weight,
        out_proj_bias,
        gate,
        c=8.0,
    ):
        """
        Forward pass for the RG-LRU inner function.

        Args:
            ctx (Any): Autograd context.
            x (torch.Tensor): Input before conv of shape ``(batch, dim, seqlen)``.
            conv1d_weight (torch.Tensor): Conv1d weight of shape ``(dim, 1, kernel_size)``.
            conv1d_bias (torch.Tensor | None): Conv1d bias of shape ``(dim,)`` or None.
            a (torch.Tensor): Learnable recurrence base in (0, 1), shape ``(dim,)`` or ``(dim, dstate)``.
            recurrent_gate_weight (torch.Tensor): Recurrent gate weight of shape ``(dim, dim)``.
            recurrent_gate_bias (torch.Tensor): Recurrent gate bias of shape ``(dim,)``.
            input_gate_weight (torch.Tensor): Input gate weight of shape ``(dim, dim)``.
            input_gate_bias (torch.Tensor): Input gate bias of shape ``(dim,)``.
            out_proj_weight (torch.Tensor): Output projection weight of shape ``(d_model, dim)``.
            out_proj_bias (torch.Tensor | None): Output projection bias of shape ``(d_model,)`` or None.
            gate (torch.Tensor): Stream-1 gate of shape ``(batch, seqlen, dim)``.
            c (float, optional): Fixed scalar constant. Defaults to 8.0.

        Returns:
            torch.Tensor: The projected output tensor, in ``x``'s dtype.
        """
        assert (
            causal_conv1d_fn is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        dtype_in = x.dtype
        # Record each differentiable input's dtype so backward can return
        # gradients in that dtype rather than uniformly in x's dtype.
        ctx.grad_dtypes = (
            x.dtype,
            conv1d_weight.dtype,
            conv1d_bias.dtype if conv1d_bias is not None else None,
            a.dtype,
            recurrent_gate_weight.dtype,
            recurrent_gate_bias.dtype,
            input_gate_weight.dtype,
            input_gate_bias.dtype,
            out_proj_weight.dtype,
            out_proj_bias.dtype if out_proj_bias is not None else None,
            gate.dtype,
        )
        batch = x.shape[0]
        a_was_1d = a.dim() == 1
        if a_was_1d:
            a = a.unsqueeze(-1)  # (dim,) -> (dim, 1)
        if x.stride(-1) != 1:
            x = x.contiguous()
        a_f = a.float().contiguous()

        # Causal conv1d (no activation), computed in x's dtype.
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        conv1d_bias = (
            conv1d_bias.contiguous() if conv1d_bias is not None else None
        )
        conv1d_out = causal_conv1d_fwd_function(
            x, conv1d_weight, conv1d_bias, None, None, None, False
        )
        x_f = conv1d_out.float()

        # Gate projections: conv1d_out (B, D, L) -> (BL, D) -> linear -> sigmoid
        x_flat = rearrange(x_f, "b d l -> (b l) d")  # (BL, D)
        recurrent_gate = torch.sigmoid(
            F.linear(
                x_flat,
                recurrent_gate_weight.float(),
                recurrent_gate_bias.float(),
            )
        )  # (BL, D)
        input_gate = torch.sigmoid(
            F.linear(
                x_flat,
                input_gate_weight.float(),
                input_gate_bias.float(),
            )
        )  # (BL, D)
        recurrent_gate_bdl = rearrange(
            recurrent_gate, "(b l) d -> b d l", b=batch
        ).contiguous()
        input_gate_bdl = rearrange(
            input_gate, "(b l) d -> b d l", b=batch
        ).contiguous()
        delta = (c * recurrent_gate_bdl).contiguous()  # (B, D, L)
        u_gated = (input_gate_bdl * x_f).contiguous()  # (B, D, L)

        dim, dstate = a_f.shape
        # Identity B / C placeholders
        B = torch.ones(dim, dstate, dtype=torch.float32, device=x.device)
        C = torch.ones(dim, dstate, dtype=torch.float32, device=x.device)
        out, x_states, *_ = selective_scan_cuda.fwd(
            u_gated,
            delta,
            a_f,
            B,
            C,
            None,  # D
            None,  # z
            None,  # delta_bias
            None,  # deltaA
            False,  # delta_softplus
            "rglru",
        )

        # Merge with stream-1 gate and project out
        y = rearrange(out, "b d l -> b l d")  # (B, L, D)
        gate_f = gate.float()
        result = F.linear(
            gate_f * y,
            out_proj_weight.float(),
            out_proj_bias.float() if out_proj_bias is not None else None,
        )
        ctx.save_for_backward(
            x,  # pre-conv input for conv1d backward
            conv1d_weight,
            conv1d_bias,
            conv1d_out,
            a_f,
            recurrent_gate,
            input_gate,  # (BL, D) each
            u_gated,
            delta,
            out,
            x_states,
            B,
            C,
            out_proj_weight.float(),
            gate_f,
            recurrent_gate_weight.float(),
            input_gate_weight.float(),
        )
        ctx.c = c
        ctx.a_was_1d = a_was_1d
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.batch = batch
        return result.to(dtype_in)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        """
        Backward pass for the RG-LRU inner function.

        Args:
            ctx (Any): Autograd context holding the tensors saved in ``forward``.
            dout (torch.Tensor): Gradient w.r.t. the projected output.
            *args: Unused.

        Returns:
            tuple: Gradients for every ``forward`` input, each in that input's
            dtype (None for absent biases and for ``c``).
        """
        (
            x_pre_conv,
            conv1d_weight,
            conv1d_bias,
            conv1d_out,
            a_f,
            recurrent_gate,
            input_gate,
            u_gated,
            delta,
            scan_out,
            x_states,
            B,
            C,
            out_proj_weight,
            gate_f,
            rg_weight,
            ig_weight,
        ) = ctx.saved_tensors
        c = ctx.c
        batch = ctx.batch
        dout = dout.float()
        if not dout.is_contiguous():
            dout = dout.contiguous()
        x_f = conv1d_out.float()

        # Backward through the output projection: out = gate_y @ W_out^T + b_out
        y = rearrange(scan_out, "b d l -> b l d")  # (B, L, D)
        gate_y = gate_f * y  # (B, L, D)
        dout_2d = rearrange(dout, "b l e -> (b l) e")  # (BL, D_model)
        gate_y_2d = rearrange(gate_y, "b l d -> (b l) d")  # (BL, D)
        d_gate_y_2d = dout_2d @ out_proj_weight  # (BL, D)
        d_out_proj_weight = dout_2d.t() @ gate_y_2d  # (D_model, D)
        d_out_proj_bias = (
            dout_2d.sum(0) if not ctx.out_proj_bias_is_None else None
        )

        # Backward through gate * y
        d_gate_y = rearrange(d_gate_y_2d, "(b l) d -> b l d", b=batch)
        d_gate = d_gate_y * y  # (B, L, D)
        dy = d_gate_y * gate_f  # (B, L, D)

        # Backward through rearrange -> CUDA scan backward
        dout_scan = rearrange(dy, "b l d -> b d l").contiguous()
        du_gated, ddelta, dA, *_ = selective_scan_cuda.bwd(
            u_gated,
            delta,
            a_f,
            B,
            C,
            None,  # D
            None,  # z
            None,  # delta_bias
            None,  # deltaA
            dout_scan,
            x_states,
            scan_out,
            None,  # dz
            False,  # delta_softplus
            False,  # recompute_out
            "rglru",
        )

        # Chain rule: u_gated = input_gate_bdl * x_f
        input_gate_bdl = rearrange(input_gate, "(b l) d -> b d l", b=batch)
        dconv1d_out = du_gated * input_gate_bdl  # (B, D, L)
        d_input_gate_bdl = du_gated * x_f  # (B, D, L)
        # Chain rule: delta = c * recurrent_gate_bdl
        d_recurrent_gate_bdl = ddelta * c  # (B, D, L)

        # Reshape to (BL, D) for sigmoid + linear backward
        d_input_gate_2d = rearrange(d_input_gate_bdl, "b d l -> (b l) d")
        d_recurrent_gate_2d = rearrange(
            d_recurrent_gate_bdl, "b d l -> (b l) d"
        )
        # Backward through sigmoid: s' = s * (1 - s)
        d_ig_pre = d_input_gate_2d * input_gate * (1 - input_gate)
        d_rg_pre = d_recurrent_gate_2d * recurrent_gate * (1 - recurrent_gate)

        # Backward through the gate linear projections
        x_flat = rearrange(x_f, "b d l -> (b l) d")
        d_ig_weight = d_ig_pre.t() @ x_flat  # (D, D)
        d_ig_bias = d_ig_pre.sum(0)  # (D,)
        dconv1d_out_from_ig = d_ig_pre @ ig_weight  # (BL, D)
        d_rg_weight = d_rg_pre.t() @ x_flat  # (D, D)
        d_rg_bias = d_rg_pre.sum(0)  # (D,)
        dconv1d_out_from_rg = d_rg_pre @ rg_weight  # (BL, D)

        # conv1d_out feeds both the gating product and the projections.
        dconv1d_out_from_proj = rearrange(
            dconv1d_out_from_ig + dconv1d_out_from_rg,
            "(b l) d -> b d l",
            b=batch,
        )
        dconv1d_out = dconv1d_out + dconv1d_out_from_proj  # (B, D, L)

        # Backward through causal conv1d (dx is the output buffer)
        dx = torch.empty_like(x_pre_conv)
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function(
            x_pre_conv,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out.to(x_pre_conv.dtype),
            None,
            None,
            None,
            dx,
            False,
            False,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        da = dA
        if ctx.a_was_1d:
            da = da.squeeze(-1)

        (
            x_dtype,
            conv_w_dtype,
            conv_b_dtype,
            a_dtype,
            rg_w_dtype,
            rg_b_dtype,
            ig_w_dtype,
            ig_b_dtype,
            out_w_dtype,
            out_b_dtype,
            gate_dtype,
        ) = ctx.grad_dtypes
        return (
            dx.to(x_dtype),  # x
            dconv1d_weight.to(conv_w_dtype),  # conv1d_weight
            (
                dconv1d_bias.to(conv_b_dtype)
                if dconv1d_bias is not None
                else None
            ),  # conv1d_bias
            da.to(a_dtype),  # a
            d_rg_weight.to(rg_w_dtype),  # recurrent_gate_weight
            d_rg_bias.to(rg_b_dtype),  # recurrent_gate_bias
            d_ig_weight.to(ig_w_dtype),  # input_gate_weight
            d_ig_bias.to(ig_b_dtype),  # input_gate_bias
            d_out_proj_weight.to(out_w_dtype),  # out_proj_weight
            (
                d_out_proj_bias.to(out_b_dtype)
                if d_out_proj_bias is not None
                else None
            ),  # out_proj_bias
            d_gate.to(gate_dtype),  # gate
            None,  # c
        )
def rglru_inner_fn(
    x: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: torch.Tensor | None,
    a: torch.Tensor,
    recurrent_gate_weight: torch.Tensor,
    recurrent_gate_bias: torch.Tensor,
    input_gate_weight: torch.Tensor,
    input_gate_bias: torch.Tensor,
    out_proj_weight: torch.Tensor,
    out_proj_bias: torch.Tensor | None,
    gate: torch.Tensor,
    c: float = 8.0,
) -> torch.Tensor:
    r"""
    Fused RG-LRU inner function backed by CUDA kernels (autograd-aware).

    Functional front-end for ``RGLRUInnerFn``, which computes::

        x_conv = causal_conv1d(x)
        recurrent_gate = sigmoid(x_conv @ W_r^T + b_r)
        input_gate = sigmoid(x_conv @ W_i^T + b_i)
        delta = c * recurrent_gate
        u_gated = input_gate * x_conv
        y = rglru_scan(u_gated, delta, a)
        out = (gate * y) @ W_out^T + b_out

    Args:
        x (torch.Tensor): Input before conv, shape ``(batch, dim, seqlen)``.
        conv1d_weight (torch.Tensor): Conv1d weight, shape ``(dim, 1, kernel_size)``.
        conv1d_bias (torch.Tensor | None): Conv1d bias, shape ``(dim,)`` or None.
        a (torch.Tensor): Learnable recurrence base in (0, 1), shape ``(dim,)`` or ``(dim, dstate)``.
        recurrent_gate_weight (torch.Tensor): Recurrent gate weight, shape ``(dim, dim)``.
        recurrent_gate_bias (torch.Tensor): Recurrent gate bias, shape ``(dim,)``.
        input_gate_weight (torch.Tensor): Input gate weight, shape ``(dim, dim)``.
        input_gate_bias (torch.Tensor): Input gate bias, shape ``(dim,)``.
        out_proj_weight (torch.Tensor): Output projection weight, shape ``(d_model, dim)``.
        out_proj_bias (torch.Tensor | None): Output projection bias, shape ``(d_model,)`` or None.
        gate (torch.Tensor): Stream-1 gate, shape ``(batch, seqlen, dim)``.
        c (float, optional): Fixed scalar constant. Defaults to 8.0.

    Returns:
        torch.Tensor: Output tensor of shape ``(batch, seqlen, d_model)``.
    """
    fn_inputs = (
        x,
        conv1d_weight,
        conv1d_bias,
        a,
        recurrent_gate_weight,
        recurrent_gate_bias,
        input_gate_weight,
        input_gate_bias,
        out_proj_weight,
        out_proj_bias,
        gate,
        c,
    )
    return RGLRUInnerFn.apply(*fn_inputs)
def rglru_inner_ref(
    x: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: torch.Tensor | None,
    a: torch.Tensor,
    recurrent_gate_weight: torch.Tensor,
    recurrent_gate_bias: torch.Tensor,
    input_gate_weight: torch.Tensor,
    input_gate_bias: torch.Tensor,
    out_proj_weight: torch.Tensor,
    out_proj_bias: torch.Tensor | None,
    gate: torch.Tensor,
    c: float = 8.0,
) -> torch.Tensor:
    r"""
    Reference RG-LRU inner function (pure PyTorch, float32 throughout).

    Computes::

        x_conv = conv1d(pad_left(x, kernel_size - 1))   # causal, length seqlen
        recurrent_gate = sigmoid(x_conv @ W_r^T + b_r)
        input_gate = sigmoid(x_conv @ W_i^T + b_i)

    Then applies the RG-LRU scan per time-step:

    .. math::
        g_t &= c \cdot \operatorname{recurrent\_gate}_t \\
        \bar{A}_t &= a^{\,g_t} \\
        h_t &= \bar{A}_t \odot h_{t-1}
        + \sqrt{1 - \bar{A}_t^2} \odot
        (\operatorname{input\_gate}_t \odot x^{conv}_t) \\
        y_t &= \textstyle\sum_n h_{n,t}

    Finally::

        out = (gate * y) @ W_out^T + b_out

    Args:
        x (torch.Tensor): Input before conv, shape ``(batch, dim, seqlen)``.
        conv1d_weight (torch.Tensor): Conv1d weight, shape ``(dim, 1, kernel_size)``.
        conv1d_bias (torch.Tensor | None): Conv1d bias, shape ``(dim,)`` or None.
        a (torch.Tensor): Learnable recurrence base in (0, 1), shape ``(dim,)`` or ``(dim, dstate)``.
        recurrent_gate_weight (torch.Tensor): Recurrent gate weight, shape ``(dim, dim)``.
        recurrent_gate_bias (torch.Tensor): Recurrent gate bias, shape ``(dim,)``.
        input_gate_weight (torch.Tensor): Input gate weight, shape ``(dim, dim)``.
        input_gate_bias (torch.Tensor): Input gate bias, shape ``(dim,)``.
        out_proj_weight (torch.Tensor): Output projection weight, shape ``(d_model, dim)``.
        out_proj_bias (torch.Tensor | None): Output projection bias, shape ``(d_model,)`` or None.
        gate (torch.Tensor): Stream-1 gate, shape ``(batch, seqlen, dim)``.
        c (float, optional): Fixed scalar constant. Defaults to 8.0.

    Returns:
        torch.Tensor: Output tensor of shape ``(batch, seqlen, d_model)`` in ``x``'s dtype.
    """
    dtype_in = x.dtype
    x = x.float()
    a = a.float()
    gate = gate.float()
    if a.dim() == 1:
        a = a.unsqueeze(-1)  # (dim,) -> (dim, 1)

    # Depthwise causal conv1d: left-padding by kernel_size - 1 keeps the
    # output length equal to seqlen without any truncation.
    d_conv = conv1d_weight.shape[-1]
    x_conv = F.conv1d(
        F.pad(x, (d_conv - 1, 0)),
        conv1d_weight.float(),
        conv1d_bias.float() if conv1d_bias is not None else None,
        groups=x.shape[1],
    )

    # Gate projections on (BL, D), then back to (B, D, L)
    batch = x.shape[0]
    x_flat = rearrange(x_conv, "b d l -> (b l) d")
    recurrent_gate = torch.sigmoid(
        F.linear(
            x_flat,
            recurrent_gate_weight.float(),
            recurrent_gate_bias.float(),
        )
    )
    input_gate = torch.sigmoid(
        F.linear(
            x_flat,
            input_gate_weight.float(),
            input_gate_bias.float(),
        )
    )
    recurrent_gate = rearrange(
        recurrent_gate, "(b l) d -> b d l", b=batch
    ).contiguous()
    input_gate = rearrange(
        input_gate, "(b l) d -> b d l", b=batch
    ).contiguous()

    delta = c * recurrent_gate  # (B, D, L)
    u_gated = input_gate * x_conv  # (B, D, L)
    result = rglru_scan_ref(u_gated, delta, a)

    # Merge with stream-1 gate and project out
    y = rearrange(result, "b d l -> b l d")
    out = F.linear(
        gate * y,
        out_proj_weight.float(),
        out_proj_bias.float() if out_proj_bias is not None else None,
    )
    return out.to(dtype_in)