lrnnx.ops.rglru_scan module

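RG-LRU (Real-Gated Linear Recurrent Unit) scan operation. This module exposes two levels of the scan, similar to Mamba: a thin scan wrapper (rglru_scan_fn) and an inner function (rglru_inner_fn) that also performs the convolution, gating, and output projection. Each level has a pure PyTorch reference (rglru_scan_ref, rglru_inner_ref).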

class RGLRUScanFn[source]

Bases: Function

Thin autograd wrapper around the RGLRU CUDA kernel.

All gating pre-computations must be done before calling this.

static forward(ctx, u, delta, A, return_last_state=False)[source]

Forward pass for the RG-LRU Scan CUDA kernel.

Parameters:
  • ctx (Any) – Autograd context.

  • u (torch.Tensor) – Pre-gated input of shape (batch, dim, seqlen) in float32.

  • delta (torch.Tensor) – Pre-computed exponent of shape (batch, dim, seqlen) in float32.

  • A (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim, dstate).

  • return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.

Returns:

The output tensor, and optionally the last state.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]

static backward(ctx, dout)[source]

rglru_scan_fn(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, return_last_state: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor][source]

RG-LRU scan - thin CUDA kernel wrapper.

All inputs must already be in float32.

Parameters:
  • u (torch.Tensor) – Pre-gated input of shape (batch, dim, seqlen).

  • delta (torch.Tensor) – Pre-computed exponent of shape (batch, dim, seqlen).

  • A (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim, dstate).

  • return_last_state (bool, optional) – Whether to return last hidden state. Defaults to False.

Returns:

  • Output tensor of shape (batch, dim, seqlen).

  • last_state : If return_last_state is True, shape (batch, dim, dstate).

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
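Because the scan wrapper expects pre-gated, float32 inputs, the gating has to happen in caller code. The block below is a minimal sketch of that preparation step, assuming the gate formulas documented for the inner function (delta = c * recurrent_gate, u = input_gate * x_conv). The helper name prepare_scan_inputs and its argument names are hypothetical and are not part of lrnnx.

```python
import torch


def prepare_scan_inputs(x_conv, w_r, b_r, w_i, b_i, c=8.0):
    """Hypothetical helper (not part of lrnnx): build (u, delta) for the scan.

    x_conv: (batch, dim, seqlen); w_r, w_i: (dim, dim); b_r, b_i: (dim,).
    Returns float32, contiguous tensors of shape (batch, dim, seqlen).
    """
    # Linear projections act on the channel axis, so move it last.
    x_t = x_conv.transpose(1, 2)                       # (batch, seqlen, dim)
    recurrent_gate = torch.sigmoid(x_t @ w_r.T + b_r)  # values in (0, 1)
    input_gate = torch.sigmoid(x_t @ w_i.T + b_i)      # values in (0, 1)

    delta = c * recurrent_gate                         # exponent for the scan
    u = input_gate * x_t                               # pre-gated input

    # Back to the (batch, dim, seqlen) layout, cast to float32.
    delta = delta.transpose(1, 2).contiguous().to(torch.float32)
    u = u.transpose(1, 2).contiguous().to(torch.float32)
    return u, delta
```

The resulting u and delta, together with a float32 A of shape (dim, dstate), match the argument shapes documented for rglru_scan_fn(u, delta, A).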

rglru_scan_ref(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, return_last_state: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor][source]

Reference RG-LRU scan (pure PyTorch, sequential loop).

Parameters:
  • u (torch.Tensor) – Pre-gated input of shape (batch, dim, seqlen) in float32.

  • delta (torch.Tensor) – Pre-computed exponent of shape (batch, dim, seqlen) in float32.

  • A (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim, dstate) in float32.

  • return_last_state (bool, optional) – Whether to return last hidden state. Defaults to False.

Returns:

  • Output tensor of shape (batch, dim, seqlen).

  • last_state : If return_last_state is True, shape (batch, dim, dstate).

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
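To make the sequential loop concrete, here is a minimal sketch of what such a reference scan might look like. It follows the recurrence given for rglru_inner_ref (A_bar_t = A ** delta_t, h_t = A_bar_t * h_{t-1} + sqrt(1 - A_bar_t^2) * u_t, y_t = sum over the state dimension). The function name rglru_scan_sketch, the zero initial state, and applying the square-root scaling inside the loop are assumptions of this sketch, not confirmed details of the library's implementation.

```python
import torch


def rglru_scan_sketch(u, delta, A, return_last_state=False):
    """Illustrative sequential scan (hypothetical, not the lrnnx code).

    u, delta: (batch, dim, seqlen); A: (dim, dstate) with entries in (0, 1).
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    # Assumption: the hidden state starts at zero.
    h = torch.zeros(batch, dim, dstate, dtype=u.dtype, device=u.device)
    ys = []
    for t in range(seqlen):
        # Per-step decay: broadcast delta_t over the state dimension.
        A_bar = A.unsqueeze(0) ** delta[:, :, t].unsqueeze(-1)  # (batch, dim, dstate)
        # Input scaling that complements the decay.
        scale = torch.sqrt(1.0 - A_bar ** 2)
        h = A_bar * h + scale * u[:, :, t].unsqueeze(-1)
        ys.append(h.sum(dim=-1))                                # (batch, dim)
    y = torch.stack(ys, dim=-1)                                 # (batch, dim, seqlen)
    return (y, h) if return_last_state else y
```

A loop like this is slow but easy to inspect, which makes it a convenient baseline when comparing against a kernel's output with torch.allclose.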

class RGLRUInnerFn[source]

Bases: Function

RG-LRU inner function: conv1d + gate projections + gating + scan + output.

Performs:

x              = causal_conv1d(x_pre_conv)
recurrent_gate = sigmoid(x @ W_r^T + b_r)
input_gate     = sigmoid(x @ W_i^T + b_i)
delta          = c * recurrent_gate
u_gated        = input_gate * x
y              = rglru_scan(u_gated, delta, a)
out            = (gate * y) @ W_out^T + b_out

static forward(ctx, x, conv1d_weight, conv1d_bias, a, recurrent_gate_weight, recurrent_gate_bias, input_gate_weight, input_gate_bias, out_proj_weight, out_proj_bias, gate, c=8.0)

Forward pass for the RG-LRU inner function.

Parameters:
  • ctx (Any) – Autograd context.

  • x (torch.Tensor) – Input before conv of shape (batch, dim, seqlen).

  • conv1d_weight (torch.Tensor) – Conv1d weight of shape (dim, 1, kernel_size).

  • conv1d_bias (torch.Tensor | None) – Conv1d bias of shape (dim,) or None.

  • a (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim,) or (dim, dstate).

  • recurrent_gate_weight (torch.Tensor) – Recurrent gate weight of shape (dim, dim).

  • recurrent_gate_bias (torch.Tensor) – Recurrent gate bias of shape (dim,).

  • input_gate_weight (torch.Tensor) – Input gate weight of shape (dim, dim).

  • input_gate_bias (torch.Tensor) – Input gate bias of shape (dim,).

  • out_proj_weight (torch.Tensor) – Output projection weight of shape (d_model, dim).

  • out_proj_bias (torch.Tensor | None) – Output projection bias of shape (d_model,) or None.

  • gate (torch.Tensor) – Stream-1 gate of shape (batch, seqlen, dim).

  • c (float, optional) – Fixed scalar multiplier applied to the recurrent gate (delta = c * recurrent_gate). Defaults to 8.0.

Returns:

The projected output tensor of shape (batch, seqlen, d_model).

Return type:

torch.Tensor

static backward(ctx, dout)

rglru_inner_fn(x: torch.Tensor, conv1d_weight: torch.Tensor, conv1d_bias: torch.Tensor | None, a: torch.Tensor, recurrent_gate_weight: torch.Tensor, recurrent_gate_bias: torch.Tensor, input_gate_weight: torch.Tensor, input_gate_bias: torch.Tensor, out_proj_weight: torch.Tensor, out_proj_bias: torch.Tensor | None, gate: torch.Tensor, c: float = 8.0) → torch.Tensor[source]

RG-LRU inner function (CUDA).

Computes conv1d, gate projections, gating, scan, and output projection:

x_conv         = causal_conv1d(x)
recurrent_gate = sigmoid(x_conv @ W_r^T + b_r)
input_gate     = sigmoid(x_conv @ W_i^T + b_i)
delta          = c * recurrent_gate
u_gated        = input_gate * x_conv
y              = rglru_scan(u_gated, delta, a)
out            = (gate * y) @ W_out^T + b_out

Parameters:
  • x (torch.Tensor) – Input before conv, shape (batch, dim, seqlen).

  • conv1d_weight (torch.Tensor) – Conv1d weight, shape (dim, 1, kernel_size).

  • conv1d_bias (torch.Tensor | None) – Conv1d bias, shape (dim,) or None.

  • a (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim,) or (dim, dstate).

  • recurrent_gate_weight (torch.Tensor) – Recurrent gate weight, shape (dim, dim).

  • recurrent_gate_bias (torch.Tensor) – Recurrent gate bias, shape (dim,).

  • input_gate_weight (torch.Tensor) – Input gate weight, shape (dim, dim).

  • input_gate_bias (torch.Tensor) – Input gate bias, shape (dim,).

  • out_proj_weight (torch.Tensor) – Output projection weight, shape (d_model, dim).

  • out_proj_bias (torch.Tensor | None) – Output projection bias, shape (d_model,) or None.

  • gate (torch.Tensor) – Stream-1 gate, shape (batch, seqlen, dim).

  • c (float, optional) – Fixed scalar multiplier applied to the recurrent gate (delta = c * recurrent_gate). Defaults to 8.0.

Returns:

Output tensor of shape (batch, seqlen, d_model).

Return type:

torch.Tensor
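The role of c is easiest to see on a single channel. The sketch below evaluates the documented relations delta = c * recurrent_gate and A_bar = a ** delta for a few gate values, along with the matching input scale sqrt(1 - A_bar^2). The helper name effective_decay and the sample value a = 0.9 are illustrative assumptions.

```python
import math


def effective_decay(a, recurrent_gate, c=8.0):
    """Return (A_bar, input_scale) for one channel at one time step.

    Hypothetical helper for illustration; a must lie in (0, 1).
    """
    delta = c * recurrent_gate              # exponent fed to the scan
    a_bar = a ** delta                      # effective per-step decay
    input_scale = math.sqrt(1.0 - a_bar ** 2)
    return a_bar, input_scale


# A closed gate (0.0) keeps the state untouched; an open gate (1.0)
# decays it by a ** c and lets more of the input through.
for r in (0.0, 0.5, 1.0):
    a_bar, scale = effective_decay(0.9, r)
    print(f"recurrent_gate={r:.1f}  A_bar={a_bar:.4f}  input_scale={scale:.4f}")
```

By construction A_bar^2 + input_scale^2 = 1 at every step, so a larger c widens the range of decay rates the gate can select without changing that balance.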

rglru_inner_ref(x: torch.Tensor, conv1d_weight: torch.Tensor, conv1d_bias: torch.Tensor | None, a: torch.Tensor, recurrent_gate_weight: torch.Tensor, recurrent_gate_bias: torch.Tensor, input_gate_weight: torch.Tensor, input_gate_bias: torch.Tensor, out_proj_weight: torch.Tensor, out_proj_bias: torch.Tensor | None, gate: torch.Tensor, c: float = 8.0) → torch.Tensor[source]

Reference RG-LRU inner function (pure PyTorch).

Computes:

x_conv         = conv1d(x)[..., :L]
recurrent_gate = sigmoid(x_conv @ W_r^T + b_r)
input_gate     = sigmoid(x_conv @ W_i^T + b_i)

Then applies the RG-LRU scan per time step, where u_t denotes x_conv at step t and n indexes the state dimension:

\[\begin{split}g_t &= c \cdot \operatorname{recurrent\_gate}_t \\ \bar{A}_t &= a^{\,g_t} \\ h_t &= \bar{A}_t \odot h_{t-1} + \sqrt{1 - \bar{A}_t^2} \odot (\operatorname{input\_gate}_t \odot u_t) \\ y_t &= \textstyle\sum_n h_{n,t}\end{split}\]

Finally:

out = (gate * y) @ W_out^T + b_out

Parameters:
  • x (torch.Tensor) – Input before conv, shape (batch, dim, seqlen).

  • conv1d_weight (torch.Tensor) – Conv1d weight, shape (dim, 1, kernel_size).

  • conv1d_bias (torch.Tensor | None) – Conv1d bias, shape (dim,) or None.

  • a (torch.Tensor) – Learnable recurrence base in (0, 1), shape (dim,) or (dim, dstate).

  • recurrent_gate_weight (torch.Tensor) – Recurrent gate weight, shape (dim, dim).

  • recurrent_gate_bias (torch.Tensor) – Recurrent gate bias, shape (dim,).

  • input_gate_weight (torch.Tensor) – Input gate weight, shape (dim, dim).

  • input_gate_bias (torch.Tensor) – Input gate bias, shape (dim,).

  • out_proj_weight (torch.Tensor) – Output projection weight, shape (d_model, dim).

  • out_proj_bias (torch.Tensor | None) – Output projection bias, shape (d_model,) or None.

  • gate (torch.Tensor) – Stream-1 gate, shape (batch, seqlen, dim).

  • c (float, optional) – Fixed scalar multiplier applied to the recurrent gate (delta = c * recurrent_gate). Defaults to 8.0.

Returns:

Output tensor of shape (batch, seqlen, d_model).

Return type:

torch.Tensor
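The steps listed for the reference function can be strung together in a few lines of plain PyTorch. The following is a minimal sketch under stated assumptions, not the library's code: the name rglru_inner_sketch and the shortened argument names are hypothetical, the convolution is assumed to be depthwise (groups=dim, which matches the documented weight shape (dim, 1, kernel_size)), and the hidden state is assumed to start at zero.

```python
import torch
import torch.nn.functional as F


def rglru_inner_sketch(x, conv_w, conv_b, a, w_r, b_r, w_i, b_i,
                       w_out, b_out, gate, c=8.0):
    """Illustrative end-to-end pipeline (hypothetical, not the lrnnx code).

    x: (batch, dim, seqlen); gate: (batch, seqlen, dim).
    Returns a tensor of shape (batch, seqlen, d_model).
    """
    batch, dim, seqlen = x.shape
    k = conv_w.shape[-1]
    # Causal depthwise conv: pad by k - 1 and keep the first seqlen outputs.
    x_conv = F.conv1d(x, conv_w, conv_b, padding=k - 1, groups=dim)[..., :seqlen]
    x_t = x_conv.transpose(1, 2)                    # (batch, seqlen, dim)

    recurrent_gate = torch.sigmoid(x_t @ w_r.T + b_r)
    input_gate = torch.sigmoid(x_t @ w_i.T + b_i)
    delta = c * recurrent_gate                      # (batch, seqlen, dim)
    u = input_gate * x_t                            # (batch, seqlen, dim)

    # Accept a of shape (dim,) or (dim, dstate).
    A = a if a.dim() == 2 else a.unsqueeze(-1)
    h = x.new_zeros(batch, dim, A.shape[1])         # assumed zero initial state
    ys = []
    for t in range(seqlen):
        A_bar = A ** delta[:, t, :].unsqueeze(-1)   # (batch, dim, dstate)
        h = A_bar * h + torch.sqrt(1.0 - A_bar ** 2) * u[:, t, :].unsqueeze(-1)
        ys.append(h.sum(dim=-1))                    # sum over the state dimension
    y = torch.stack(ys, dim=1)                      # (batch, seqlen, dim)

    # Gate the scan output, then project to d_model.
    return F.linear(gate * y, w_out, b_out)
```

Since every step here is causal, perturbing the last input position should leave all earlier output positions unchanged, which is a quick sanity check for any alternative implementation.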