lrnnx.models.ltv.rglru module

RG-LRU (Real-Gated Linear Recurrent Unit) block from the Griffin paper: https://arxiv.org/abs/2402.19427
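Per channel, the Griffin paper defines the RG-LRU as r_t = σ(W_a x_t + b_a), i_t = σ(W_x x_t + b_x), a_t = a^(c·r_t) with a = σ(Λ), and h_t = a_t ⊙ h_{t-1} + sqrt(1 − a_t²) ⊙ (i_t ⊙ x_t). The following pure-Python sketch runs that recurrence for a single channel, with scalar weights standing in for the gate projections. The function `rglru_channel` and its argument names are hypothetical; the sketch illustrates the paper's equations and is not this module's implementation.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def rglru_channel(xs, lam, w_a, b_a, w_x, b_x, c=8.0, h0=0.0):
    """Run the RG-LRU recurrence over a 1-D sequence for one channel.

    Scalar weights (w_a, b_a, w_x, b_x) stand in for the gate projections;
    lam is the unconstrained decay parameter. All names are hypothetical.
    """
    a = sigmoid(lam)                     # base decay a = sigma(Lambda), in (0, 1)
    h = h0
    ys = []
    for x in xs:
        r = sigmoid(w_a * x + b_a)       # recurrence gate r_t
        i = sigmoid(w_x * x + b_x)       # input gate i_t
        a_t = a ** (c * r)               # gated decay a_t = a^(c * r_t)
        # Blend the previous state with the gated input; sqrt(1 - a_t^2)
        # is the input normalization factor from the Griffin paper.
        h = a_t * h + math.sqrt(1.0 - a_t * a_t) * (i * x)
        ys.append(h)
    return ys
```

For example, `rglru_channel([1.0], lam=0.0, w_a=0.0, b_a=0.0, w_x=0.0, b_x=0.0)` uses a = 0.5 and r_t = i_t = 0.5, so with the default c = 8.0 the gated decay is a_t = 0.5^4 = 0.0625.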

class RGLRU

Bases: LTV_LRNN

RG-LRU block following the Griffin architecture.

Example

>>> import torch
>>> from lrnnx.models.ltv.rglru import RGLRU
>>> model = RGLRU(d_model=64, d_conv=4)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
__init__(d_model: int, d_conv: int = 4, expand: int = 1, c: float = 8.0, a_init_range: Tuple[float, float] = (0.9, 0.999), conv_bias: bool = True, bias: bool = False, use_fast_path: bool = True, layer_idx: int | None = None, device=None, dtype=None)

Initialize RG-LRU block.

Parameters:
  • d_model (int) – Model dimension.

  • d_conv (int, optional) – Temporal convolution kernel size. Defaults to 4.

  • expand (int, optional) – Expansion factor for inner dimension. Defaults to 1.

  • c (float, optional) – Fixed scalar constant that scales the recurrence gate in the decay exponent, a_t = a^(c·r_t). Defaults to 8.0.

  • a_init_range (Tuple[float, float], optional) – Bounds (lo, hi), with 0 < lo < hi < 1, of the interval in which the recurrent decay a is initialized. Defaults to (0.9, 0.999).

  • conv_bias (bool, optional) – Whether the Conv1D uses a bias term. Defaults to True.

  • bias (bool, optional) – Whether Linear projections use bias. Defaults to False.

  • use_fast_path (bool, optional) – Use the fused CUDA kernel when available. Defaults to True.

  • layer_idx (int, optional) – Layer index (for multi-layer caching). Defaults to None.

  • device (torch.device, optional) – Device for parameters. Defaults to None.

  • dtype (torch.dtype, optional) – Data type for parameters. Defaults to None.
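The roles of a_init_range and c can be illustrated with a small sketch. It follows the Griffin paper in storing the decay as an unconstrained parameter Λ with a = σ(Λ) and in evaluating a_t = a^(c·r_t) in log space as exp(−c·r_t·softplus(−Λ)). Drawing a uniformly from [lo, hi] is an assumption, and the helpers `init_lambda` and `effective_decay` are hypothetical, not part of this module's API.

```python
import math
import random


def init_lambda(d, lo=0.9, hi=0.999, seed=0):
    """Hypothetical init: draw a ~ Uniform[lo, hi] per channel, return Lambda = logit(a)."""
    if not 0.0 < lo < hi < 1.0:
        raise ValueError("a_init_range must satisfy 0 < lo < hi < 1")
    rng = random.Random(seed)
    a = [rng.uniform(lo, hi) for _ in range(d)]
    return [math.log(v / (1.0 - v)) for v in a]   # logit inverts the sigmoid


def effective_decay(lam, r, c=8.0):
    """a_t = a^(c * r) with a = sigmoid(lam), computed in log space.

    log a = log sigmoid(lam) = -softplus(-lam), so a_t = exp(-c * r * softplus(-lam)).
    """
    softplus_neg = math.log1p(math.exp(-lam))     # softplus(-lam)
    return math.exp(-c * r * softplus_neg)
```

Setting c·r_t = 1 recovers the base decay a, while r_t = 0 gives a_t = 1 so the state is carried over unchanged; a larger c lets the gate push a_t further below a.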

forward(hidden_states: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None, inference_cache: Dict[str, Any] | None = None) → torch.Tensor

Forward pass through the RG-LRU block.

Parameters:
  • hidden_states (torch.Tensor) – Input tensor of shape (B, L, D).

  • integration_timesteps (torch.Tensor, optional) – Unused; kept for compatibility with the LTV interface. Defaults to None.

  • lengths (torch.Tensor, optional) – Unused; kept for interface compatibility. Defaults to None.

  • inference_cache (Dict[str, Any], optional) – Cache dict for autoregressive generation. Defaults to None.

Returns:

Output tensor of shape (B, L, D).

Return type:

torch.Tensor
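forward maps a (B, L, D) input to a (B, L, D) output, so the d_conv-wide temporal convolution inside the block has to preserve sequence length and stay causal. Assuming it is a depthwise causal convolution with d_conv − 1 steps of left zero padding (an assumption; this is the usual way to get both properties), one channel of it can be sketched in pure Python. `causal_conv1d` is a hypothetical helper, not part of lrnnx.

```python
def causal_conv1d(xs, weights, bias=0.0):
    """Depthwise causal convolution for one channel (hypothetical helper).

    y[t] = bias + sum_k weights[k] * x[t - (K - 1) + k], where K = len(weights)
    and positions before the sequence start count as zeros (left padding).
    The output has the same length as the input and y[t] never reads x[t + 1:].
    """
    K = len(weights)
    ys = []
    for t in range(len(xs)):
        acc = bias
        for k in range(K):
            src = t - (K - 1) + k        # oldest tap first, current step last
            if src >= 0:                 # implicit zero padding on the left
                acc += weights[k] * xs[src]
        ys.append(acc)
    return ys
```

With weights [0, 0, 0, 1] the kernel only reads the current step and returns the input unchanged, which is a quick way to check the tap ordering. Because step t only reads the last K inputs, a decoder needs to keep just those K values per channel rather than the whole history.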

step(hidden_states: torch.Tensor, inference_cache: Dict[str, Any]) → Tuple[torch.Tensor, Dict[str, Any]]

Single recurrent step for autoregressive inference.

Parameters:
  • hidden_states (torch.Tensor) – Input tensor of shape (B, 1, D).

  • inference_cache (Dict[str, Any]) – Must contain conv_state, lrnn_state, and seqlen_offset.

Returns:

Tuple containing:
  • out : Output tensor of shape (B, 1, D).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]
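step() is meant to reproduce, one token at a time, what a full-sequence pass computes. The sketch below checks that property for a single channel under stated assumptions: conv_state is a rolling window of the last d_conv convolution inputs, lrnn_state holds h_{t-1}, and seqlen_offset counts consumed tokens. Only the three cache keys come from the description above; the parameters, window layout, and helper names (`full_sequence`, `make_cache_sketch`, `step_sketch`) are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical single-channel parameters; scalars stand in for projections.
PARAMS = {
    "conv_w": [0.1, -0.2, 0.3, 0.5],  # d_conv = 4 taps, oldest first
    "lam": 2.0, "w_a": 0.7, "b_a": 0.0, "w_x": -0.4, "b_x": 0.1, "c": 8.0,
}


def rglru_update(h, x, p):
    """One RG-LRU update: h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t)."""
    a_t = sigmoid(p["lam"]) ** (p["c"] * sigmoid(p["w_a"] * x + p["b_a"]))
    i_t = sigmoid(p["w_x"] * x + p["b_x"])
    return a_t * h + math.sqrt(1.0 - a_t * a_t) * (i_t * x)


def full_sequence(xs, p):
    """Whole-sequence reference: causal conv over all steps, then the recurrence."""
    k = len(p["conv_w"])
    padded = [0.0] * (k - 1) + list(xs)   # left zero padding keeps it causal
    conv = [sum(w * v for w, v in zip(p["conv_w"], padded[t:t + k]))
            for t in range(len(xs))]
    h, out = 0.0, []
    for u in conv:
        h = rglru_update(h, u, p)
        out.append(h)
    return out


def make_cache_sketch(p):
    """Hypothetical cache using the keys step() is documented to require."""
    return {"conv_state": [0.0] * len(p["conv_w"]), "lrnn_state": 0.0,
            "seqlen_offset": 0}


def step_sketch(x, cache, p):
    """Consume one token: roll the conv window, update h, advance the offset."""
    window = cache["conv_state"][1:] + [x]   # drop the oldest input, append x
    u = sum(w * v for w, v in zip(p["conv_w"], window))
    h = rglru_update(cache["lrnn_state"], u, p)
    new_cache = {"conv_state": window, "lrnn_state": h,
                 "seqlen_offset": cache["seqlen_offset"] + 1}
    return h, new_cache


# Usage: feed tokens one at a time, threading the cache through each call.
xs = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5]
cache = make_cache_sketch(PARAMS)
stepped = []
for x in xs:
    y, cache = step_sketch(x, cache, PARAMS)
    stepped.append(y)
```

Keeping a full window of d_conv inputs lets the rolled buffer line up directly with the kernel taps; an actual implementation may lay out conv_state differently.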

allocate_inference_cache(batch_size: int, max_seqlen: int, dtype: torch.dtype | None = None) → Dict[str, Any]

Allocate cache for autoregressive inference.

Parameters:
  • batch_size (int) – Batch size.

  • max_seqlen (int) – Unused, kept for interface consistency.

  • dtype (torch.dtype, optional) – Data type for cache tensors. Defaults to None.

Returns:

Cache dictionary containing “conv_state”, “ssm_state”, and “seqlen_offset”.

Return type:

Dict[str, Any]