"""
RG-LRU (Recurrent Gated Linear Recurrent Unit) block.
https://arxiv.org/abs/2402.19427
"""
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from lrnnx.models.ltv.base import LTV_LRNN
from lrnnx.ops.rglru_scan import rglru_inner_fn, rglru_scan_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None # type: ignore[assignment]
try:
from lrnnx.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None # type: ignore[assignment]
[docs]
class RGLRU(LTV_LRNN):
"""
RG-LRU block following the Griffin architecture.
Example:
>>> model = RGLRU(d_model=64, d_state=1, d_conv=4)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
"""
[docs]
def __init__(
self,
d_model: int,
d_conv: int = 4,
expand: int = 1,
c: float = 8.0,
a_init_range: Tuple[float, float] = (0.9, 0.999),
conv_bias: bool = True,
bias: bool = False,
use_fast_path: bool = True,
layer_idx: Optional[int] = None,
device=None,
dtype=None,
):
"""
Initialize RG-LRU block.
Args:
d_model (int): Model dimension.
d_conv (int, optional): Temporal convolution kernel size. Defaults to 4.
expand (int, optional): Expansion factor for inner dimension. Defaults to 1.
c (float, optional): Fixed scalar for recurrent gate scaling. Defaults to 8.0.
a_init_range (Tuple[float, float], optional): Tuple ``(lo, hi)`` so *a* is
initialised in ``[lo, hi]`` in ``(0, 1)``. Defaults to ``(0.9, 0.999)``.
conv_bias (bool, optional): Whether the Conv1D uses a bias term. Defaults to True.
bias (bool, optional): Whether Linear projections use bias. Defaults to False.
use_fast_path (bool, optional): Use the fused CUDA kernel when available. Defaults to True.
layer_idx (int, optional): Layer index (for multi-layer caching). Defaults to None.
device (torch.device, optional): Device for parameters. Defaults to None.
dtype (torch.dtype, optional): Data type for parameters. Defaults to None.
"""
# RG-LRU handles discretisation internally
super().__init__(discretization=None)
factory_kwargs = {"device": device, "dtype": dtype}
self.d_model = d_model
self.d_conv = d_conv
self.expand = expand
self.dstate = 1
self.d_inner = int(self.expand * self.d_model)
self.c = c
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
# Stream 1: Linear -> GeLU
self.gate_proj = nn.Linear(
self.d_model, self.d_inner, bias=bias, **factory_kwargs
)
# Stream 2: Linear -> Conv1D -> RG-LRU
self.in_proj = nn.Linear(
self.d_model, self.d_inner, bias=bias, **factory_kwargs
)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
bias=conv_bias,
**factory_kwargs,
)
# Recurrent / input gate projections
self.recurrent_gate_proj = nn.Linear(
self.d_inner, self.d_inner, bias=True, **factory_kwargs
)
self.input_gate_proj = nn.Linear(
self.d_inner, self.d_inner, bias=True, **factory_kwargs
)
# Learnable recurrence base a in (0, 1), shape (d_inner, d_state)
a_lo, a_hi = a_init_range
a_init = a_lo + (a_hi - a_lo) * torch.rand(
self.d_inner, self.dstate, **factory_kwargs
)
self.a_log = nn.Parameter(torch.log(a_init))
self.a_log._no_weight_decay = True # type: ignore[attr-defined]
# Output projection
self.out_proj = nn.Linear(
self.d_inner, self.d_model, bias=bias, **factory_kwargs
)
[docs]
def forward(
self,
hidden_states: Tensor,
integration_timesteps: Optional[Tensor] = None,
lengths: Optional[Tensor] = None,
inference_cache: Optional[Dict[str, Any]] = None,
) -> Tensor:
"""
Forward pass through the RG-LRU block.
Args:
hidden_states (torch.Tensor): Input tensor of shape ``(B, L, D)``.
integration_timesteps (torch.Tensor, optional): *Unused* - kept for LTV interface compat. Defaults to None.
lengths (torch.Tensor, optional): *Unused* - kept for interface compatibility. Defaults to None.
inference_cache (Dict[str, Any], optional): Cache dict for autoregressive generation. Defaults to None.
Returns:
torch.Tensor: Output tensor of shape ``(B, L, D)``.
"""
batch, seqlen, dim = hidden_states.shape
if inference_cache is not None:
seqlen_offset = inference_cache.get("seqlen_offset", 0)
if seqlen_offset > 0:
out, inference_cache = self.step(
hidden_states, inference_cache
)
return out
# Stream 1: gate path
gate = F.gelu(self.gate_proj(hidden_states)) # (B, L, D_inner)
# Stream 2: conv -> RG-LRU
x = self.in_proj(hidden_states) # (B, L, D_inner)
x = rearrange(x, "b l d -> b d l")
# Learnable base in (0, 1)
a = torch.sigmoid(self.a_log) # (d_inner, d_state)
if (
self.use_fast_path
and causal_conv1d_fn is not None
and inference_cache is None
):
out = rglru_inner_fn(
x,
self.conv1d.weight,
self.conv1d.bias,
a,
self.recurrent_gate_proj.weight,
self.recurrent_gate_proj.bias,
self.input_gate_proj.weight,
self.input_gate_proj.bias,
self.out_proj.weight,
self.out_proj.bias,
gate,
c=self.c,
)
else:
# Causal temporal convolution
if causal_conv1d_fn is not None:
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=None,
)
else:
x = self.conv1d(x)[..., :seqlen]
# Update conv cache if present
conv_state = (
inference_cache.get("conv_state") if inference_cache else None
)
if conv_state is not None:
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
# Gate projections (B, D_inner, L) -> transpose -> project -> back
x_BLD = rearrange(x, "b d l -> (b l) d")
recurrent_gate = torch.sigmoid(self.recurrent_gate_proj(x_BLD))
input_gate = torch.sigmoid(self.input_gate_proj(x_BLD))
recurrent_gate = rearrange(
recurrent_gate, "(b l) d -> b d l", l=seqlen
).contiguous()
input_gate = rearrange(
input_gate, "(b l) d -> b d l", l=seqlen
).contiguous()
ssm_state = (
inference_cache.get("lrnn_state") if inference_cache else None
)
# Manual path: gating + scan
delta = (self.c * recurrent_gate).float().contiguous()
u_gated = (input_gate * x).float().contiguous()
y = rglru_scan_fn(
u_gated,
delta,
a,
return_last_state=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y # type: ignore[assignment]
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d") # type: ignore[assignment]
# Merge streams and project out
out = self.out_proj(gate * y)
return out
[docs]
def step( # type: ignore[override]
self,
hidden_states: Tensor,
inference_cache: Dict[str, Any],
**kwargs,
) -> Tuple[Tensor, Dict[str, Any]]:
"""
Single recurrent step for autoregressive inference.
Args:
hidden_states (torch.Tensor): Input tensor of shape ``(B, 1, D)``.
inference_cache (Dict[str, Any]): Must contain conv_state, lrnn_state,
and seqlen_offset.
**kwargs: Additional keyword arguments.
Returns:
tuple[torch.Tensor, Dict[str, Any]]: Tuple containing:
- out : Output tensor of shape ``(B, 1, D)``.
- inference_cache : Updated cache dictionary.
"""
conv_state = inference_cache["conv_state"]
ssm_state = inference_cache["lrnn_state"]
dtype = hidden_states.dtype
assert (
hidden_states.shape[1] == 1
), "step() supports single-token decoding only"
x_in = hidden_states.squeeze(1) # (B, D)
# Stream 1
gate = F.gelu(self.gate_proj(x_in)) # (B, D_inner)
# Stream 2
x = self.in_proj(x_in) # (B, D_inner)
# Conv step
if causal_conv1d_update is not None:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
activation=None,
)
else:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
conv_state[:, :, -1] = x
x = torch.sum(
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"),
dim=-1,
)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
# Gate projections
recurrent_gate = torch.sigmoid(
self.recurrent_gate_proj(x)
) # (B, D_inner)
input_gate = torch.sigmoid(self.input_gate_proj(x)) # (B, D_inner)
a = torch.sigmoid(self.a_log) # (d_inner, d_state)
# Pre-compute gate and gated input for the RG-LRU recurrence
gate_val = self.c * recurrent_gate # (B, D_inner)
u_gated = input_gate * x # (B, D_inner)
if selective_state_update is not None:
# Triton fused path: pass pre-computed gate via dt, gated input via x,
# A (base in (0,1)) and identity B/C.
B_ones = torch.ones(
u_gated.shape[0],
self.dstate,
device=u_gated.device,
dtype=u_gated.dtype,
)
C_ones = torch.ones_like(B_ones)
y = selective_state_update(
ssm_state,
u_gated,
gate_val,
a,
B_ones,
C_ones,
dt_bias=None,
dt_softplus=False,
discretization="rglru",
)
else:
# fallback
# a: (D, N), gate_val: (B, D) -> a_bar: (B, D, N)
a_bar = a.unsqueeze(0).pow(gate_val.unsqueeze(-1)) # (B, D, N)
sqrt_term = torch.sqrt(1.0 - a_bar * a_bar)
new_state = a_bar * ssm_state + sqrt_term * u_gated.unsqueeze(-1)
y = new_state.sum(dim=-1) # (B, D_inner) - sum over dstate
ssm_state.copy_(new_state)
# Merge and project out
out = self.out_proj(gate * y)
inference_cache["conv_state"] = conv_state
inference_cache["lrnn_state"] = ssm_state
inference_cache["seqlen_offset"] = (
inference_cache.get("seqlen_offset", 0) + 1
)
return out.unsqueeze(1).to(dtype), inference_cache
[docs]
def allocate_inference_cache(
self,
batch_size: int,
max_seqlen: int,
dtype: Optional[torch.dtype] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Allocate cache for autoregressive inference.
Args:
batch_size (int): Batch size.
max_seqlen (int): Unused, kept for interface consistency.
dtype (torch.dtype, optional): Data type for cache tensors. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
Dict[str, Any]: Cache dictionary containing "conv_state", "ssm_state", and "seqlen_offset".
"""
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size,
self.d_inner,
self.d_conv,
device=device,
dtype=conv_dtype,
)
ssm_state = torch.zeros(
batch_size,
self.d_inner,
self.dstate,
device=device,
dtype=ssm_dtype,
)
return {
"conv_state": conv_state,
"lrnn_state": ssm_state,
"seqlen_offset": 0,
}