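lrnnx.ops.rglru_scan module
RG-LRU (Real-Gated Linear Recurrent Unit) scan operations. Similar to Mamba, the module exposes the scan at two levels: a thin wrapper around the scan kernel (RGLRUScanFn / rglru_scan_fn) and a fused inner function that also performs the convolution, gate projections, gating, and output projection (RGLRUInnerFn / rglru_inner_fn). Each level has a pure-PyTorch reference (rglru_scan_ref, rglru_inner_ref).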
- class RGLRUScanFn
Bases: Function
Thin autograd wrapper around the RG-LRU CUDA kernel.
All gating pre-computations must be done before calling it.
- static forward(ctx, u, delta, A, return_last_state=False)
Forward pass for the RG-LRU Scan CUDA kernel.
- Parameters:
ctx (Any) – Autograd context.
u (torch.Tensor) – Pre-gated input of shape (batch, dim, seqlen) in float32.
delta (torch.Tensor) – Pre-computed exponent of shape (batch, dim, seqlen) in float32.
A (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim, dstate).
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
- Returns:
The output tensor, and optionally the last state.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
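Because RGLRUScanFn expects every gate to be applied beforehand, the caller has to build u, delta and A itself. The sketch below shows one way to do that under Griffin-style assumptions: both gates are sigmoids of some pre-activations, and A is parametrized as the sigmoid of an unconstrained tensor so that it stays in (0, 1). The helper name precompute_scan_inputs and its arguments are hypothetical and not part of lrnnx.

```python
import torch

def precompute_scan_inputs(x_conv, r_preact, i_preact, a_param, c=8.0):
    """Hypothetical helper: build the (u, delta, A) triple for the scan.

    x_conv, r_preact, i_preact: (batch, dim, seqlen); a_param: (dim, dstate).
    Assumes sigmoid gates and A = sigmoid(a_param), as in Griffin.
    """
    recurrent_gate = torch.sigmoid(r_preact)           # in (0, 1) up to float rounding
    input_gate = torch.sigmoid(i_preact)
    # float32 and contiguous memory, as the kernel documentation asks for float32
    u = (input_gate * x_conv).float().contiguous()     # pre-gated input
    delta = (c * recurrent_gate).float().contiguous()  # pre-computed exponent
    A = torch.sigmoid(a_param).float().contiguous()    # recurrence base, (dim, dstate)
    return u, delta, A

torch.manual_seed(0)
batch, dim, seqlen, dstate = 2, 4, 8, 1
u, delta, A = precompute_scan_inputs(
    torch.randn(batch, dim, seqlen),
    torch.randn(batch, dim, seqlen),
    torch.randn(batch, dim, seqlen),
    torch.randn(dim, dstate),
)
```

Keeping the gating outside the autograd Function means the gate parameters get their gradients from ordinary PyTorch autograd, while the kernel only has to differentiate the recurrence.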
- rglru_scan_fn(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, return_last_state: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor]
RG-LRU scan: a thin wrapper around the CUDA kernel.
All inputs must already be in float32.
- Parameters:
u (torch.Tensor) – Pre-gated input of shape (batch, dim, seqlen).
delta (torch.Tensor) – Pre-computed exponent of shape (batch, dim, seqlen).
A (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim, dstate).
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
- Returns:
Output tensor of shape (batch, dim, seqlen).
last_state: Only if return_last_state is True; shape (batch, dim, dstate).
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
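rglru_scan_fn documents its input constraints (float32 tensors, u and delta sharing the shape (batch, dim, seqlen), A of shape (dim, dstate) inside (0, 1)) but not how violations are reported. A hypothetical pre-flight check like the one below can surface such mistakes as Python errors before the CUDA call; whether the library performs any of these checks itself is not stated here.

```python
import torch

def check_scan_inputs(u, delta, A):
    """Hypothetical check mirroring the documented constraints of rglru_scan_fn."""
    for name, t in (("u", u), ("delta", delta), ("A", A)):
        if t.dtype != torch.float32:
            raise TypeError(f"{name} must be float32, got {t.dtype}")
    if u.dim() != 3 or u.shape != delta.shape:
        raise ValueError(
            f"u and delta must share shape (batch, dim, seqlen), "
            f"got {tuple(u.shape)} and {tuple(delta.shape)}"
        )
    if A.dim() != 2 or A.shape[0] != u.shape[1]:
        raise ValueError(f"A must have shape (dim={u.shape[1]}, dstate), got {tuple(A.shape)}")
    if not bool(((A > 0) & (A < 1)).all()):
        raise ValueError("A must lie strictly inside (0, 1)")

u = torch.randn(2, 4, 16)
delta = torch.rand(2, 4, 16)
A = torch.full((4, 1), 0.9)
check_scan_inputs(u, delta, A)  # valid inputs: returns None without raising
```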
- rglru_scan_ref(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, return_last_state: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor]
Reference RG-LRU scan (pure PyTorch, sequential loop).
- Parameters:
u (torch.Tensor) – Pre-gated input of shape (batch, dim, seqlen) in float32.
delta (torch.Tensor) – Pre-computed exponent of shape (batch, dim, seqlen) in float32.
A (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim, dstate) in float32.
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
- Returns:
Output tensor of shape (batch, dim, seqlen).
last_state: Only if return_last_state is True; shape (batch, dim, dstate).
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
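The sequential recurrence behind a reference scan like rglru_scan_ref can be sketched in a few lines of PyTorch. This version assumes that the scan itself applies the sqrt(1 - Ā_t²) input normalization and sums the state over the dstate axis to form y_t, following the formula given for rglru_inner_ref, with a zero initial state. rglru_scan_sketch is a hypothetical name and the code is not the library's implementation.

```python
import torch

def rglru_scan_sketch(u, delta, A, return_last_state=False):
    """Sequential RG-LRU scan sketch.

    u, delta: (batch, dim, seqlen); A: (dim, dstate) with values in (0, 1).
    Assumes h_0 = 0 and the sqrt(1 - A_bar^2) normalization inside the scan.
    """
    batch, dim, seqlen = u.shape
    h = u.new_zeros(batch, dim, A.shape[1])          # hidden state (batch, dim, dstate)
    ys = []
    for t in range(seqlen):
        # A_bar[b, d, n] = A[d, n] ** delta[b, d, t]
        A_bar = A.unsqueeze(0) ** delta[:, :, t].unsqueeze(-1)
        h = A_bar * h + torch.sqrt(1 - A_bar ** 2) * u[:, :, t].unsqueeze(-1)
        ys.append(h.sum(dim=-1))                      # y_t: sum over dstate
    y = torch.stack(ys, dim=-1)                       # (batch, dim, seqlen)
    return (y, h) if return_last_state else y

torch.manual_seed(0)
u = torch.randn(2, 4, 16)
delta = 8.0 * torch.rand(2, 4, 16)
A = torch.full((4, 2), 0.9)
y, last = rglru_scan_sketch(u, delta, A, return_last_state=True)
print(y.shape, last.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 4, 2])
```

Because the loop runs once per time step, a reference of this kind is best suited as a readable oracle on small shapes, for example to compare against the CUDA path with torch.allclose.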
- class RGLRUInnerFn
Bases: Function
RG-LRU inner function: conv1d + gate projections + gating + scan + output projection.
- Performs:
x_conv = causal_conv1d(x)
recurrent_gate = sigmoid(x_conv @ W_r^T + b_r)
input_gate = sigmoid(x_conv @ W_i^T + b_i)
delta = c * recurrent_gate
u_gated = input_gate * x_conv
y = rglru_scan(u_gated, delta, a)
out = (gate * y) @ W_out^T + b_out
The matrix products and the product with gate act on the channel (dim) axis.
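The causal_conv1d step can be reproduced with torch.nn.functional.conv1d, assuming it is a depthwise convolution (groups=dim, matching the (dim, 1, kernel_size) weight shape) padded by kernel_size - 1 and truncated to the first seqlen outputs, the same slice that rglru_inner_ref writes as [..., :L]. This is a sketch under that assumption; the fused kernel may implement the convolution differently, and causal_conv1d_sketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def causal_conv1d_sketch(x, weight, bias=None):
    """Depthwise causal conv sketch.

    x: (batch, dim, seqlen); weight: (dim, 1, kernel_size); bias: (dim,) or None.
    """
    dim, _, kernel_size = weight.shape
    seqlen = x.shape[-1]
    # conv1d pads both sides; keeping only the first seqlen outputs means
    # output t sees inputs t - kernel_size + 1 .. t and nothing from the future.
    out = F.conv1d(x, weight, bias, padding=kernel_size - 1, groups=dim)
    return out[..., :seqlen]

torch.manual_seed(0)
x = torch.randn(1, 3, 10)
w = torch.randn(3, 1, 4)
y = causal_conv1d_sketch(x, w)  # same shape as x: (1, 3, 10)
```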
- static forward(ctx, x, conv1d_weight, conv1d_bias, a, recurrent_gate_weight, recurrent_gate_bias, input_gate_weight, input_gate_bias, out_proj_weight, out_proj_bias, gate, c=8.0)
Forward pass for the RG-LRU inner function.
- Parameters:
ctx (Any) – Autograd context.
x (torch.Tensor) – Input before the convolution, of shape (batch, dim, seqlen).
conv1d_weight (torch.Tensor) – Conv1d weight of shape (dim, 1, kernel_size).
conv1d_bias (torch.Tensor | None) – Conv1d bias of shape (dim,) or None.
a (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim,) or (dim, dstate).
recurrent_gate_weight (torch.Tensor) – Recurrent gate weight of shape (dim, dim).
recurrent_gate_bias (torch.Tensor) – Recurrent gate bias of shape (dim,).
input_gate_weight (torch.Tensor) – Input gate weight of shape (dim, dim).
input_gate_bias (torch.Tensor) – Input gate bias of shape (dim,).
out_proj_weight (torch.Tensor) – Output projection weight of shape (d_model, dim).
out_proj_bias (torch.Tensor | None) – Output projection bias of shape (d_model,) or None.
gate (torch.Tensor) – Stream-1 gate of shape (batch, seqlen, dim).
c (float, optional) – Fixed scalar constant. Defaults to 8.0.
- Returns:
The projected output tensor, of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
- static backward(ctx, dout)
- rglru_inner_fn(x: torch.Tensor, conv1d_weight: torch.Tensor, conv1d_bias: torch.Tensor | None, a: torch.Tensor, recurrent_gate_weight: torch.Tensor, recurrent_gate_bias: torch.Tensor, input_gate_weight: torch.Tensor, input_gate_bias: torch.Tensor, out_proj_weight: torch.Tensor, out_proj_bias: torch.Tensor | None, gate: torch.Tensor, c: float = 8.0) → torch.Tensor
RG-LRU inner function (CUDA).
Computes conv1d, gate projections, gating, scan, and output projection:
x_conv = causal_conv1d(x)
recurrent_gate = sigmoid(x_conv @ W_r^T + b_r)
input_gate = sigmoid(x_conv @ W_i^T + b_i)
delta = c * recurrent_gate
u_gated = input_gate * x_conv
y = rglru_scan(u_gated, delta, a)
out = (gate * y) @ W_out^T + b_out
- Parameters:
x (torch.Tensor) – Input before the convolution, shape (batch, dim, seqlen).
conv1d_weight (torch.Tensor) – Conv1d weight, shape (dim, 1, kernel_size).
conv1d_bias (torch.Tensor | None) – Conv1d bias, shape (dim,) or None.
a (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim,) or (dim, dstate).
recurrent_gate_weight (torch.Tensor) – Recurrent gate weight, shape (dim, dim).
recurrent_gate_bias (torch.Tensor) – Recurrent gate bias, shape (dim,).
input_gate_weight (torch.Tensor) – Input gate weight, shape (dim, dim).
input_gate_bias (torch.Tensor) – Input gate bias, shape (dim,).
out_proj_weight (torch.Tensor) – Output projection weight, shape (d_model, dim).
out_proj_bias (torch.Tensor | None) – Output projection bias, shape (d_model,) or None.
gate (torch.Tensor) – Stream-1 gate, shape (batch, seqlen, dim).
c (float, optional) – Fixed scalar constant. Defaults to 8.0.
- Returns:
Output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
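The arguments of rglru_inner_fn mix two layouts: x is channel-first (batch, dim, seqlen), while gate and the output are channel-last. A hypothetical helper that builds random float32 arguments with the documented shapes, in the positional order of the signature, can serve as a smoke test. The weight scaling and the sigmoid used to keep a in (0, 1) are arbitrary illustrative choices, and the commented call at the end is an untested expectation, not verified behaviour.

```python
import torch

def make_inner_args(batch=2, dim=8, seqlen=16, d_model=12, kernel_size=4, dstate=1):
    """Hypothetical helper: random float32 arguments in the positional order
    of rglru_inner_fn (x ... gate), using the documented shapes."""
    def rand(*shape, scale=1.0):
        return torch.randn(*shape, dtype=torch.float32) * scale

    return (
        rand(batch, dim, seqlen),                               # x: channel-first
        rand(dim, 1, kernel_size, scale=kernel_size ** -0.5),   # conv1d_weight
        torch.zeros(dim),                                       # conv1d_bias
        torch.sigmoid(rand(dim, dstate)),                       # a: in (0, 1)
        rand(dim, dim, scale=dim ** -0.5),                      # recurrent_gate_weight
        torch.zeros(dim),                                       # recurrent_gate_bias
        rand(dim, dim, scale=dim ** -0.5),                      # input_gate_weight
        torch.zeros(dim),                                       # input_gate_bias
        rand(d_model, dim, scale=dim ** -0.5),                  # out_proj_weight
        torch.zeros(d_model),                                   # out_proj_bias
        rand(batch, seqlen, dim),                               # gate: channel-last
    )

torch.manual_seed(0)
args = make_inner_args()
# With lrnnx installed and a CUDA device, a call would presumably look like:
#   from lrnnx.ops.rglru_scan import rglru_inner_fn
#   out = rglru_inner_fn(*(t.cuda() for t in args))  # expected (batch, seqlen, d_model)
```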
- rglru_inner_ref(x: torch.Tensor, conv1d_weight: torch.Tensor, conv1d_bias: torch.Tensor | None, a: torch.Tensor, recurrent_gate_weight: torch.Tensor, recurrent_gate_bias: torch.Tensor, input_gate_weight: torch.Tensor, input_gate_bias: torch.Tensor, out_proj_weight: torch.Tensor, out_proj_bias: torch.Tensor | None, gate: torch.Tensor, c: float = 8.0) → torch.Tensor
Reference RG-LRU inner function (pure PyTorch).
Computes:
x_conv = conv1d(x)[..., :L]
recurrent_gate = sigmoid(x_conv @ W_r^T + b_r)
input_gate = sigmoid(x_conv @ W_i^T + b_i)
where conv1d is the depthwise convolution and the slice keeps the first L = seqlen steps, which makes it causal.
Then applies the RG-LRU scan per time step, where u_t is x_conv at step t and n indexes the dstate axis:
\[\begin{split}g_t &= c \cdot \operatorname{recurrent\_gate}_t \\ \bar{A}_t &= a^{\,g_t} \\ h_t &= \bar{A}_t \odot h_{t-1} + \sqrt{1 - \bar{A}_t^2} \odot (\operatorname{input\_gate}_t \odot u_t) \\ y_t &= \textstyle\sum_n h_{n,t}\end{split}\]
Finally:
out = (gate * y) @ W_out^T + b_out
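A single step of this recurrence with scalar values (dim = dstate = 1) makes the formula concrete. The numbers are chosen purely for illustration: a = 0.9, c = 8, both gates equal to 0.5, u_t = 2 and h_{t-1} = 1, so g_t = 4 and Ā_t = 0.9^4 = 0.6561.

```python
import math

a, c = 0.9, 8.0                      # recurrence base in (0, 1) and the fixed constant
recurrent_gate, input_gate = 0.5, 0.5
u_t, h_prev = 2.0, 1.0

g_t = c * recurrent_gate             # 4.0
A_bar = a ** g_t                     # 0.9 ** 4 = 0.6561
h_t = A_bar * h_prev + math.sqrt(1 - A_bar ** 2) * (input_gate * u_t)
y_t = h_t                            # dstate = 1, so the sum over n has a single term
print(round(A_bar, 4), round(h_t, 4))  # 0.6561 1.4108
```

The sqrt(1 - Ā_t²) factor keeps the state's scale bounded: with a constant Ā and i.i.d. unit-variance gated inputs, the stationary variance v of h_t satisfies v = Ā²v + (1 - Ā²), so v = 1 regardless of Ā.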
- Parameters:
x (torch.Tensor) – Input before the convolution, shape (batch, dim, seqlen).
conv1d_weight (torch.Tensor) – Conv1d weight, shape (dim, 1, kernel_size).
conv1d_bias (torch.Tensor | None) – Conv1d bias, shape (dim,) or None.
a (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim,) or (dim, dstate).
recurrent_gate_weight (torch.Tensor) – Recurrent gate weight, shape (dim, dim).
recurrent_gate_bias (torch.Tensor) – Recurrent gate bias, shape (dim,).
input_gate_weight (torch.Tensor) – Input gate weight, shape (dim, dim).
input_gate_bias (torch.Tensor) – Input gate bias, shape (dim,).
out_proj_weight (torch.Tensor) – Output projection weight, shape (d_model, dim).
out_proj_bias (torch.Tensor | None) – Output projection bias, shape (d_model,) or None.
gate (torch.Tensor) – Stream-1 gate, shape (batch, seqlen, dim).
c (float, optional) – Fixed scalar constant. Defaults to 8.0.
- Returns:
Output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
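Putting the pieces together, the inner pipeline can be sketched end to end in pure PyTorch. It rests on assumptions the text above does not spell out: the projections act on the channel axis (so the code transposes to channel-last after the convolution), an a of shape (dim,) behaves like (dim, 1), the scan starts from a zero state, and it applies the sqrt(1 - Ā_t²) normalization. rglru_inner_sketch is a hypothetical name, not the library's rglru_inner_ref.

```python
import torch
import torch.nn.functional as F

def rglru_inner_sketch(x, conv1d_weight, conv1d_bias, a, recurrent_gate_weight,
                       recurrent_gate_bias, input_gate_weight, input_gate_bias,
                       out_proj_weight, out_proj_bias, gate, c=8.0):
    """Pure-PyTorch sketch of conv1d -> gates -> scan -> gated output projection."""
    dim, _, kernel_size = conv1d_weight.shape
    seqlen = x.shape[-1]
    # 1) Depthwise causal conv, channel-first: (batch, dim, seqlen)
    x_conv = F.conv1d(x, conv1d_weight, conv1d_bias,
                      padding=kernel_size - 1, groups=dim)[..., :seqlen]
    # 2) Gate projections over the channel axis, channel-last: (batch, seqlen, dim)
    xc = x_conv.transpose(1, 2)
    r = torch.sigmoid(F.linear(xc, recurrent_gate_weight, recurrent_gate_bias))
    i = torch.sigmoid(F.linear(xc, input_gate_weight, input_gate_bias))
    delta = c * r                                       # exponent g_t
    u = i * xc                                          # gated input
    # 3) Sequential scan from a zero state; treat a of shape (dim,) as (dim, 1)
    A = a if a.dim() == 2 else a.unsqueeze(-1)          # (dim, dstate)
    h = x.new_zeros(x.shape[0], dim, A.shape[1])        # (batch, dim, dstate)
    ys = []
    for t in range(seqlen):
        A_bar = A ** delta[:, t, :].unsqueeze(-1)       # (batch, dim, dstate)
        h = A_bar * h + torch.sqrt(1 - A_bar ** 2) * u[:, t, :].unsqueeze(-1)
        ys.append(h.sum(-1))                            # (batch, dim)
    y = torch.stack(ys, dim=1)                          # (batch, seqlen, dim)
    # 4) Gate with stream 1, then project: (batch, seqlen, d_model)
    return F.linear(gate * y, out_proj_weight, out_proj_bias)

torch.manual_seed(0)
B, D, L, M, K = 2, 4, 6, 3, 3
args = (torch.randn(B, D, L), torch.randn(D, 1, K), None, torch.full((D,), 0.9),
        torch.randn(D, D), torch.zeros(D), torch.randn(D, D), torch.zeros(D),
        torch.randn(M, D), None, torch.randn(B, L, D))
out = rglru_inner_sketch(*args)  # (B, L, M)
```

On a GPU with float32 inputs, comparing a sketch like this against rglru_inner_ref or rglru_inner_fn with torch.allclose is one way to check these assumptions before relying on them.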