lrnnx.models.lti.lru module

Implementation of the Linear Recurrent Unit (LRU) layer. Paper: https://arxiv.org/abs/2303.06349.

class LRU

Bases: LTI_LRNN

Linear Recurrent Unit (LRU) layer.

Paper: https://arxiv.org/abs/2303.06349

Example

>>> import torch
>>> from lrnnx.models.lti.lru import LRU
>>> model = LRU(d_model=64, d_state=64)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])

__init__(d_model: int, d_state: int, r_min: float = 0, r_max: float = 1, max_phase: float = 6.283185307179586) → None

Initialize the LRU layer.

Parameters:
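  • d_model (int) – Model dimension, i.e. the feature size H of inputs and outputs.

  • d_state (int) – State dimension, i.e. the number N of complex diagonal state channels.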

  • r_min (float, optional) – Minimum radius (modulus |λ|) of the eigenvalues Lambda at initialization. Defaults to 0.

  • r_max (float, optional) – Maximum radius (modulus |λ|) of the eigenvalues Lambda at initialization. Defaults to 1.

  • max_phase (float, optional) – Maximum phase (argument, in radians) of the eigenvalues Lambda at initialization. Defaults to 2 * math.pi.
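Together, the radius and phase bounds describe a ring sector in the complex plane. The sketch below follows the initialization described in the LRU paper: it samples |λ|² uniformly in [r_min², r_max²], which gives uniform density over the ring's area, and samples the phase uniformly in [0, max_phase). The helper `init_lambda` is hypothetical and is not taken from the lrnnx source, so the library's exact sampling may differ.

```python
import math

import torch


def init_lambda(d_state: int, r_min: float = 0.0, r_max: float = 1.0,
                max_phase: float = 2 * math.pi) -> torch.Tensor:
    """Hypothetical helper: sample d_state eigenvalues with r_min <= |lambda| <= r_max.

    Follows the LRU paper's scheme; not taken from the lrnnx source.
    """
    u1 = torch.rand(d_state)
    u2 = torch.rand(d_state)
    # |lambda|^2 is uniform in [r_min^2, r_max^2], i.e. uniform density over the ring's area.
    nu = -0.5 * torch.log(u1 * (r_max**2 - r_min**2) + r_min**2)
    # Phase is uniform in [0, max_phase).
    theta = max_phase * u2
    # lambda = exp(-nu + i*theta): modulus exp(-nu), argument theta.
    return torch.polar(torch.exp(-nu), theta)
```

Raising r_min toward 1 places all eigenvalues close to the unit circle, which slows the decay of the recurrence and lengthens its memory.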

discretize() → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

The LRU does not discretize a continuous-time system (it uses no_discretization), so this method only assembles the matrices used by the recurrence.

Returns:

A tuple containing (with N = d_state and H = d_model):
  • A : Diagonal matrix of Lambda values, shape (N, N).

  • B : Complex input projection matrix, shape (N, H).

  • C : Complex output projection matrix, shape (H, N).

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
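Because no discretization step is applied, preparing the matrices amounts to assembling complex tensors from real parameters. The following sketch assumes the exponential parameterization from the LRU paper, λ = exp(-exp(ν_log) + i·exp(θ_log)). The function `prepare_matrices` and the parameter names `nu_log`, `theta_log`, `B_re`, `B_im`, `C_re` and `C_im` are hypothetical and need not match the attributes lrnnx actually stores.

```python
import torch


def prepare_matrices(nu_log, theta_log, B_re, B_im, C_re, C_im):
    """Hypothetical stand-in for LRU.discretize(): assembles matrices, no discretization.

    nu_log, theta_log: real (N,); B_re, B_im: real (N, H); C_re, C_im: real (H, N).
    """
    # Exponential parameterization (LRU paper): |Lambda| = exp(-exp(nu_log)) < 1.
    Lambda = torch.exp(torch.complex(-torch.exp(nu_log), torch.exp(theta_log)))
    A = torch.diag(Lambda)          # (N, N) diagonal state matrix
    B = torch.complex(B_re, B_im)   # (N, H) complex input projection
    C = torch.complex(C_re, C_im)   # (H, N) complex output projection
    return A, B, C
```

Storing log-magnitudes keeps every eigenvalue strictly inside the unit disk for any real parameter value, so the recurrence stays stable throughout training.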

compute_kernel(L: int, Lambda: torch.Tensor, B_complex: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]

Compute Lambda and the normalized input matrix B for the LRU.

Parameters:
  • L (int) – Length of the input sequence.

  • Lambda (torch.Tensor) – Complex eigenvalues/diagonal elements, shape (N,).

  • B_complex (torch.Tensor) – Complex input projection matrix, shape (N, H).

Returns:

A tuple containing:
  • Lambda : Complex eigenvalues/diagonal elements, shape (N,).

  • B_norm : Normalized complex input projection matrix, shape (N, H).

Return type:

tuple[torch.Tensor, torch.Tensor]
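The LRU paper normalizes the input matrix row-wise by γ_i = sqrt(1 - |λ_i|²), so that under unit-variance white-noise input the stationary variance of each state channel does not grow as |λ_i| approaches 1. The sketch below assumes B_norm is obtained this way. The helper `normalize_B` is hypothetical.

```python
import torch


def normalize_B(Lambda: torch.Tensor, B_complex: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: scale row i of B by gamma_i = sqrt(1 - |lambda_i|^2)."""
    gamma = torch.sqrt(1 - Lambda.abs() ** 2)  # real, shape (N,)
    return gamma.unsqueeze(-1) * B_complex     # broadcast over H -> (N, H)
```

For a scalar channel with B = 1, the stationary variance is γ² · Σ_k |λ|^(2k) = (1 - |λ|²) / (1 - |λ|²) = 1, regardless of how close |λ| is to 1.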

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor

Forward pass of the LRU layer.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, L, H), where B is the batch size, L the sequence length and H = d_model.

  • integration_timesteps (torch.Tensor, optional) – Not yet implemented. Defaults to None.

  • lengths (torch.Tensor, optional) – Not yet implemented. Defaults to None.

Returns:

Output tensor of shape (B, L, H).

Return type:

torch.Tensor
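Because A is diagonal, the forward pass reduces to the elementwise linear recurrence x_k = Λ ⊙ x_{k-1} + B_norm u_k with output y_k = Re(C x_k). The sequential reference below is a sketch for checking shapes and semantics, not the kernel lrnnx uses. The LRU paper also adds a skip term D ⊙ u_k; this section does not say whether lrnnx includes it, so the sketch omits it. The function `lru_scan` is hypothetical.

```python
import torch


def lru_scan(u, Lambda, B_norm, C):
    """Hypothetical sequential LRU recurrence (reference sketch only).

    u: real input (batch, L, H); Lambda: complex (N,);
    B_norm: complex (N, H); C: complex (H, N). Returns real (batch, L, H).
    """
    batch, L, H = u.shape
    x = torch.zeros(batch, Lambda.shape[0], dtype=Lambda.dtype)
    ys = []
    for k in range(L):
        # x_k = Lambda * x_{k-1} + B_norm u_k (elementwise, since A is diagonal)
        x = Lambda * x + u[:, k].to(B_norm.dtype) @ B_norm.T
        # y_k = Re(C x_k)
        ys.append((x @ C.T).real)
    return torch.stack(ys, dim=1)
```

Unrolling the recurrence gives y_k = Re(Σ_{j≤k} C diag(Λ^(k-j)) B_norm u_j). A faster forward pass can compute this sum in parallel, but the result should match the loop above.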

step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]]

Single-step inference for the LRU layer.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (B, H) for a single timestep.

  • inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and pre-computed matrices.

Returns:

A tuple containing:
  • y : Output tensor of shape (B, H).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]
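Each call to step advances the same recurrence by one timestep. The sketch below assumes the cache holds the current state under "lrnn_state", as documented, and the precomputed matrices under the hypothetical keys "Lambda", "B_norm" and "C". The real key names in lrnnx may differ, and `lru_step` is a hypothetical helper.

```python
from typing import Any, Dict, Tuple

import torch


def lru_step(u: torch.Tensor, cache: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Hypothetical single LRU step; u is real with shape (B, H)."""
    # Keys other than "lrnn_state" are assumptions for illustration.
    Lambda, B_norm, C = cache["Lambda"], cache["B_norm"], cache["C"]
    # x_k = Lambda * x_{k-1} + B_norm u_k
    x = Lambda * cache["lrnn_state"] + u.to(B_norm.dtype) @ B_norm.T
    cache["lrnn_state"] = x
    # y_k = Re(C x_k), shape (B, H)
    y = (x @ C.T).real
    return y, cache
```

With a scalar system λ = 0.5, B = 1, C = 2 and a zero initial state, feeding u = 1 and then u = 0 yields y = 2 and then y = 1.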

allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) → Dict[str, Any]

Allocate initial state and cached matrices for inference.

Parameters:
  • batch_size (int) – Batch size.

  • max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.

  • dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.

Returns:

Cache dict with “lrnn_state” and pre-computed discrete matrices.

Return type:

Dict[str, Any]
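The LRU state is a fixed-size tensor of shape (batch_size, N), so nothing in the cache depends on the sequence length, which is consistent with max_seqlen being unused. The sketch below allocates such a cache once and then streams a sequence through it. `allocate_cache`, `stream` and every cache key other than "lrnn_state" are hypothetical.

```python
from typing import Any, Dict

import torch


def allocate_cache(batch_size: int, Lambda, B_norm, C) -> Dict[str, Any]:
    """Hypothetical cache: zero initial state plus matrices computed once and reused."""
    return {
        "lrnn_state": torch.zeros(batch_size, Lambda.shape[0], dtype=Lambda.dtype),
        "Lambda": Lambda,
        "B_norm": B_norm,
        "C": C,
    }


def stream(u: torch.Tensor, cache: Dict[str, Any]) -> torch.Tensor:
    """Feed a real (batch, L, H) input one timestep at a time through the cache."""
    ys = []
    for k in range(u.shape[1]):
        x = cache["Lambda"] * cache["lrnn_state"] + u[:, k].to(cache["B_norm"].dtype) @ cache["B_norm"].T
        cache["lrnn_state"] = x  # state size stays (batch, N) at every step
        ys.append((x @ cache["C"].T).real)
    return torch.stack(ys, dim=1)
```

With the documented API, the equivalent loop would call allocate_inference_cache(batch_size) once and then call step(x[:, k], inference_cache) for each timestep k.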