lrnnx.models.lti.lru module
Implementation of Linear Recurrent Unit (LRU) layer. Paper: https://arxiv.org/abs/2303.06349.
- class LRU
Bases: LTI_LRNN
Linear Recurrent Unit (LRU) layer.
Paper: https://arxiv.org/abs/2303.06349
Example
>>> import torch
>>> from lrnnx.models.lti.lru import LRU
>>> model = LRU(d_model=64, d_state=64)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- __init__(d_model: int, d_state: int, r_min: float = 0, r_max: float = 1, max_phase: float = 6.283185307179586) -> None
Initialize LRU layer.
- Parameters:
d_model (int) – Model dimension.
d_state (int) – State dimension.
r_min (float, optional) – Minimum radius for Lambda initialization. Defaults to 0.
r_max (float, optional) – Maximum radius for Lambda initialization. Defaults to 1.
max_phase (float, optional) – Maximum phase for Lambda initialization. Defaults to 2 * math.pi.
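The way r_min, r_max and max_phase shape the initial Lambda can be sketched as follows. This is a minimal NumPy sketch that assumes the ring initialization described in the LRU paper (magnitudes drawn so that eigenvalues are uniform on the ring between r_min and r_max, phases uniform in [0, max_phase)). The helper name ring_init and its seed argument are hypothetical, and the library may parameterize Lambda differently internally.

```python
import numpy as np

def ring_init(d_state, r_min=0.0, r_max=1.0, max_phase=2 * np.pi, seed=0):
    """Sample d_state complex eigenvalues on a ring (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    u1 = rng.uniform(size=d_state)  # controls the magnitude
    u2 = rng.uniform(size=d_state)  # controls the phase
    # Uniform in r^2 gives eigenvalues uniform in area between the two radii.
    radius = np.sqrt(u1 * (r_max**2 - r_min**2) + r_min**2)
    phase = max_phase * u2
    return radius * np.exp(1j * phase)

# Eigenvalues close to the unit circle favour long-range memory.
lam = ring_init(d_state=64, r_min=0.5, r_max=0.99)
```

Narrowing the ring toward the unit circle (larger r_min) or reducing max_phase are the kinds of adjustments these arguments are meant to expose.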
- discretize() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
LRU uses no_discretization, so this method only prepares the A, B and C matrices.
- Returns:
A tuple containing:
A : Diagonal matrix of Lambda values, shape (N, N).
B : Complex input projection matrix, shape (N, H).
C : Complex output projection matrix, shape (H, N).
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
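To make the shapes above concrete, here is a small NumPy sketch of one state update with matrices of those shapes. The values are random stand-ins, not what discretize() returns, and taking the real part of the output follows the LRU paper rather than anything stated on this page. It also shows that, because A is diagonal, the dense product is equivalent to an elementwise product with Lambda.

```python
import numpy as np

N, H = 4, 3  # state and model dimensions (small, for illustration)
rng = np.random.default_rng(0)

lam = 0.9 * np.exp(1j * rng.uniform(0, 2 * np.pi, size=N))  # stand-in Lambda
A = np.diag(lam)                                            # (N, N)
B = rng.normal(size=(N, H)) + 1j * rng.normal(size=(N, H))  # (N, H)
C = rng.normal(size=(H, N)) + 1j * rng.normal(size=(H, N))  # (H, N)

h = rng.normal(size=N) + 1j * rng.normal(size=N)  # previous complex state
u = rng.normal(size=H)                            # one real input vector

h_dense = A @ h + B @ u   # update using the dense diagonal matrix
h_diag = lam * h + B @ u  # same update using only the diagonal entries
y = (C @ h_dense).real    # real-valued output of shape (H,) (assumption)
```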
- compute_kernel(L: int, Lambda: torch.Tensor, B_complex: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
Compute Lambda and normalized B matrix for LRU.
- Parameters:
L (int) – Length of the input sequence.
Lambda (torch.Tensor) – Complex eigenvalues/diagonal elements, shape (N,).
B_complex (torch.Tensor) – Complex input projection matrix, shape (N, H).
- Returns:
A tuple containing:
Lambda : Complex eigenvalues/diagonal elements, shape (N,).
B_norm : Normalized complex input projection matrix, shape (N, H).
- Return type:
tuple[torch.Tensor, torch.Tensor]
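This page does not say how B_norm is computed. One plausible reading, taken from the LRU paper and not confirmed for this library, is that each row of B is scaled by gamma = sqrt(1 - |lambda|^2) so that the state magnitude stays bounded when eigenvalues sit near the unit circle. A NumPy sketch under that assumption (normalize_B is a hypothetical helper):

```python
import numpy as np

def normalize_B(lam, B_complex):
    """Scale row i of B by sqrt(1 - |lambda_i|^2) (assumed normalization)."""
    gamma = np.sqrt(1.0 - np.abs(lam) ** 2)  # shape (N,)
    return gamma[:, None] * B_complex        # shape (N, H)

N, H = 4, 3
rng = np.random.default_rng(0)
lam = rng.uniform(0.5, 0.99, size=N) * np.exp(1j * rng.uniform(0, 2 * np.pi, size=N))
B_complex = (rng.normal(size=(N, H)) + 1j * rng.normal(size=(N, H))) / np.sqrt(2 * H)

B_norm = normalize_B(lam, B_complex)
```

Rows tied to eigenvalues with magnitude close to 1 are shrunk the most, which offsets the larger accumulated state those eigenvalues produce.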
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) -> torch.Tensor
Forward pass of the LRU layer.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, L, H).
integration_timesteps (torch.Tensor, optional) – <To be implemented>. Defaults to None.
lengths (torch.Tensor, optional) – <To be implemented>. Defaults to None.
- Returns:
Output tensor of shape (B, L, H).
- Return type:
torch.Tensor
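For intuition about the mapping from (B, L, H) to (B, L, H), the following NumPy sketch runs the underlying linear recurrence one timestep at a time. It is a reference loop under stated assumptions: a zero initial state, a real-part readout, and no skip connection. The function name lru_forward is hypothetical, and the library's forward() may well use a parallel scan or a convolution kernel instead.

```python
import numpy as np

def lru_forward(x, lam, B_norm, C):
    """Sequential reference recurrence (hypothetical helper).

    x: (B, L, H) real, lam: (N,), B_norm: (N, H), C: (H, N).
    """
    batch, length, _ = x.shape
    h = np.zeros((batch, lam.shape[0]), dtype=np.complex128)  # zero initial state
    ys = []
    for k in range(length):
        h = lam * h + x[:, k, :] @ B_norm.T  # state update, shape (B, N)
        ys.append((h @ C.T).real)            # readout, shape (B, H)
    return np.stack(ys, axis=1)              # shape (B, L, H)

N, H = 4, 3
rng = np.random.default_rng(0)
lam = 0.9 * np.exp(1j * rng.uniform(0, 2 * np.pi, size=N))
B_norm = rng.normal(size=(N, H)) + 1j * rng.normal(size=(N, H))
C = rng.normal(size=(H, N)) + 1j * rng.normal(size=(H, N))

x = rng.normal(size=(2, 16, H))
y = lru_forward(x, lam, B_norm, C)  # same leading shape as x
```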
- step(x: torch.Tensor, inference_cache: Dict[str, Any]) -> tuple[torch.Tensor, Dict[str, Any]]
Single step inference for LRU layer.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H), a single timestep.
inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache() containing “lrnn_state” and pre-computed matrices.
- Returns:
A tuple containing:
y : Output tensor of shape (B, H).
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1, dtype=None) -> Dict[str, Any]
Allocate initial state and cached matrices for inference.
- Parameters:
batch_size (int) – Batch size.
max_seqlen (int, optional) – Maximum sequence length (unused, kept for interface consistency with LTV models). Defaults to 1.
dtype (torch.dtype, optional) – Data type for allocated tensors (unused). Defaults to None.
- Returns:
Cache dict with “lrnn_state” and pre-computed discrete matrices.
- Return type:
Dict[str, Any]
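The cache-then-step pattern described for allocate_inference_cache() and step() can be sketched in NumPy as below. Only the "lrnn_state" key comes from this page. The other key names ("Lambda", "B_norm", "C"), the state shape (batch_size, N), and both helper functions are assumptions made for illustration and may not match the library's actual cache layout.

```python
import numpy as np

def allocate_cache(batch_size, lam, B_norm, C):
    """Build a cache with a zero state and the fixed matrices (hypothetical)."""
    return {
        "lrnn_state": np.zeros((batch_size, lam.shape[0]), dtype=np.complex128),
        "Lambda": lam,     # hypothetical key
        "B_norm": B_norm,  # hypothetical key
        "C": C,            # hypothetical key
    }

def step(x, cache):
    """Advance the recurrence by one timestep; x has shape (B, H)."""
    h = cache["Lambda"] * cache["lrnn_state"] + x @ cache["B_norm"].T
    cache["lrnn_state"] = h                  # carry the state to the next call
    return (h @ cache["C"].T).real, cache    # y has shape (B, H)

N, H = 4, 3
rng = np.random.default_rng(0)
lam = 0.9 * np.exp(1j * rng.uniform(0, 2 * np.pi, size=N))
B_norm = rng.normal(size=(N, H)) + 1j * rng.normal(size=(N, H))
C = rng.normal(size=(H, N)) + 1j * rng.normal(size=(H, N))

cache = allocate_cache(2, lam, B_norm, C)
x0 = rng.normal(size=(2, H))
y, cache = step(x0, cache)
```

Because the matrices are stored once in the cache, each call only performs the state update and readout, which is the usual reason for separating allocation from stepping in autoregressive inference.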