"""
Implementation of Linear Recurrent Unit (LRU) layer.
Paper: https://arxiv.org/abs/2303.06349.
"""
import math
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
from lrnnx.core.convolution import opt_ssm_forward
from lrnnx.models.lti.base import LTI_LRNN
# [docs]
class LRU(LTI_LRNN):
    """
    Linear Recurrent Unit (LRU) layer.

    Paper: https://arxiv.org/abs/2303.06349

    The layer holds a diagonal complex recurrence ``Lambda`` (parameterized
    through ``nu_log`` and ``theta_log``), complex input/output projections
    (``B_re``/``B_im`` and ``C_re``/``C_im``), a real skip weight ``D`` and a
    per-state input normalization ``gamma_log``.

    Example:
        >>> model = LRU(d_model=64, d_state=64)
        >>> x = torch.randn(2, 128, 64)
        >>> y = model(x)
        >>> y.shape
        torch.Size([2, 128, 64])
    """

    def __init__(
        self,
        d_model: int,
        d_state: int,
        r_min: float = 0,
        r_max: float = 1,
        max_phase: float = 2 * math.pi,
    ) -> None:
        """
        Initialize LRU layer.

        Args:
            d_model (int): Model dimension.
            d_state (int): State dimension.
            r_min (float, optional): Minimum radius for Lambda initialization.
                Defaults to 0.
            r_max (float, optional): Maximum radius for Lambda initialization.
                Defaults to 1.
            max_phase (float, optional): Maximum phase for Lambda
                initialization. Defaults to ``2 * math.pi``.

        Raises:
            ValueError: If ``|r_min|`` or ``|r_max|`` exceeds 1, or if
                ``max_phase`` is negative.
        """
        super().__init__(
            discretization="no_discretization"
        )  # discretization is unused in LRU
        self.d_model = d_model
        self.d_state = d_state
        self._init_parameters(d_model, d_state, r_min, r_max, max_phase)

    def _init_parameters(
        self,
        d_model: int,
        d_state: int,
        r_min: float = 0,
        r_max: float = 1,
        max_phase: float = 2 * math.pi,
    ) -> None:
        """
        Sample initial values and register the learnable parameters.

        Args:
            d_model (int): Model dimension.
            d_state (int): State dimension.
            r_min (float, optional): Minimum radius for Lambda initialization.
                Defaults to 0.
            r_max (float, optional): Maximum radius for Lambda initialization.
                Defaults to 1.
            max_phase (float, optional): Maximum phase for Lambda
                initialization. Defaults to ``2 * math.pi``.

        Raises:
            ValueError: If ``|r_min|`` or ``|r_max|`` exceeds 1, or if
                ``max_phase`` is negative. Those inputs feed a negative number
                into one of the logarithms below and would otherwise produce
                NaN parameters without any error.
        """
        # Only reject inputs that make the log-based parameterization invalid;
        # everything that previously produced finite parameters still does.
        if max(abs(r_min), abs(r_max)) > 1:
            raise ValueError(
                "r_min and r_max must lie in [-1, 1] so that |Lambda| <= 1, "
                f"got r_min={r_min}, r_max={r_max}"
            )
        if max_phase < 0:
            raise ValueError(
                f"max_phase must be non-negative, got {max_phase}"
            )

        u1 = torch.rand(d_state)
        u2 = torch.rand(d_state)

        # nu_log and theta_log are used to init Lambda.
        # |Lambda|^2 = u1 * (r_max^2 - r_min^2) + r_min^2, i.e. values
        # distributed uniformly on the ring between r_min and r_max.
        nu_log = torch.log(
            -0.5 * torch.log(u1 * (r_max**2 - r_min**2) + r_min**2)
        )
        # phase between 0 and max_phase
        theta_log = torch.log(max_phase * u2)

        # glorot-initialized input/output projection matrices
        B_re = torch.randn(d_state, d_model) / (2 * d_model) ** 0.5
        B_im = torch.randn(d_state, d_model) / (2 * d_model) ** 0.5
        C_re = torch.randn(d_model, d_state) / d_state**0.5
        C_im = torch.randn(d_model, d_state) / d_state**0.5
        D = torch.randn(d_model)

        # normalization factor
        diag_lambda = torch.exp(-torch.exp(nu_log) + 1j * torch.exp(theta_log))
        # the original paper reports that setting gamma_log to
        # torch.log(torch.sqrt(1 - torch.abs(diag_lambda) ** 2))
        # after every step also yields similar results to making it learnable
        gamma_log = torch.log(torch.sqrt(1 - torch.abs(diag_lambda) ** 2))

        # register parameters
        self.nu_log = nn.Parameter(nu_log)
        self.theta_log = nn.Parameter(theta_log)
        self.B_re = nn.Parameter(B_re)
        self.B_im = nn.Parameter(B_im)
        self.C_re = nn.Parameter(C_re)
        self.C_im = nn.Parameter(C_im)
        self.D = nn.Parameter(D)
        self.gamma_log = nn.Parameter(gamma_log)

    def _normalize_B(self, B_complex: Tensor) -> Tensor:
        """Scale each row of ``B_complex`` (N, H) by ``exp(gamma_log)`` (N,)."""
        return B_complex * torch.exp(self.gamma_log).unsqueeze(-1)  # (N, H)

    def discretize(self) -> tuple[Tensor, Tensor, Tensor]:
        """
        LRU uses no_discretization, so this acts like a prepare matrices method.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:

            - A : Complex diagonal entries ``Lambda`` of the state matrix,
              returned as a vector of shape ``(N,)`` (not a dense matrix).
            - B : Complex input projection matrix, shape ``(N, H)``.
            - C : Complex output projection matrix, shape ``(H, N)``.
        """
        # Lambda = exp(-exp(nu_log) + i * exp(theta_log)); magnitude is
        # exp(-exp(nu_log)) and phase is exp(theta_log).
        A = torch.exp(
            -torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)
        )  # (N,)
        B = self.B_re + 1j * self.B_im  # (N, H)
        C = self.C_re + 1j * self.C_im  # (H, N)
        return A, B, C

    def compute_kernel(
        self, L: int, Lambda: Tensor, B_complex: Tensor
    ) -> tuple[Tensor, Tensor]:
        """
        Compute the power kernel of Lambda and the normalized B matrix.

        Args:
            L (int): Length of the input sequence.
            Lambda (torch.Tensor): Complex eigenvalues/diagonal elements,
                shape ``(N,)``.
            B_complex (torch.Tensor): Complex input projection matrix,
                shape ``(N, H)``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:

            - K : Powers of Lambda, ``K[n, t] = Lambda[n] ** t`` for
              ``t = 0 .. L - 1``, shape ``(N, L)``.
            - B_norm : ``B_complex`` scaled row-wise by ``exp(gamma_log)``,
              shape ``(N, H)``.
        """
        t_idx = torch.arange(
            L, dtype=torch.float32, device=Lambda.device
        )  # (L,)
        K = Lambda.unsqueeze(-1) ** t_idx.unsqueeze(0)  # (N, L)
        B_norm = self._normalize_B(B_complex)  # (N, H)
        return K, B_norm

    def forward(
        self,
        x: Tensor,
        integration_timesteps: Optional[Tensor] = None,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Forward pass of the LRU layer.

        Args:
            x (torch.Tensor): Input tensor of shape ``(B, L, H)``.
            integration_timesteps (torch.Tensor, optional): <To be implemented>;
                currently ignored. Defaults to None.
            lengths (torch.Tensor, optional): <To be implemented>; currently
                ignored. Defaults to None.

        Returns:
            torch.Tensor: Output tensor of shape ``(B, L, H)``.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Input tensor must be of shape (B, L, H), got {x.dim()}D tensor with shape {x.shape}"
            )
        L = x.shape[1]

        # prepare matrices
        Lambda, B_complex, C_complex = self.discretize()
        # compute kernel
        K, B_norm = self.compute_kernel(L, Lambda, B_complex)
        # convolve over input (opt_ssm_forward is defined in
        # lrnnx.core.convolution; its internals are not visible here)
        y_conv = opt_ssm_forward(x, K, B_norm, C_complex)
        # skip connection
        y = y_conv + x * self.D  # (B, L, H)
        return y

    def step(
        self,
        x: Tensor,
        inference_cache: Dict[str, Any],
        **kwargs,
    ) -> tuple[Tensor, Dict[str, Any]]:
        """
        Single step inference for LRU layer.

        Args:
            x (torch.Tensor): Input tensor of shape ``(B, H)`` - single timestep.
            inference_cache (Dict[str, Any]): Cache from
                ``allocate_inference_cache()`` containing "lrnn_state" and the
                pre-computed "Lambda", "B_norm" and "C_complex" matrices.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            tuple[torch.Tensor, Dict[str, Any]]: A tuple containing:

            - y : Output tensor of shape ``(B, H)``.
            - inference_cache : The same cache dictionary; its "lrnn_state"
              tensor is updated in place.

        Raises:
            ValueError: If ``x`` is not a 2-D tensor.
        """
        if x.dim() != 2:
            raise ValueError(
                f"Input tensor must be of shape (B, H), got {x.dim()}D tensor with shape {x.shape}"
            )
        state = inference_cache["lrnn_state"]

        # Extract cached matrices
        Lambda = inference_cache["Lambda"]
        B_norm = inference_cache["B_norm"]
        C_complex = inference_cache["C_complex"]

        # Recurrent update: x_t -> state_{t+1}
        # state_{t+1} = Lambda * state_t + B_norm @ u_t
        # (x is cast to B_norm's complex dtype so einsum sees matching dtypes)
        input_projection = torch.einsum(
            "nh,bh->bn", B_norm, x.to(B_norm.dtype)
        )  # (B, N)
        new_state = Lambda * state + input_projection  # (B, N)

        # Output computation: y_t = Re(C @ state_{t+1}) + D * x_t
        state_output = torch.einsum(
            "hn,bn->bh", C_complex, new_state
        ).real  # (B, H)
        y = state_output + x * self.D  # (B, H)

        # Write back in place so the caller's cache tensor stays the same object.
        inference_cache["lrnn_state"].copy_(new_state)
        return y, inference_cache

    def allocate_inference_cache(
        self,
        batch_size: int,
        max_seqlen: int = 1,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Allocate initial state and cached matrices for inference.

        Args:
            batch_size (int): Batch size.
            max_seqlen (int, optional): Maximum sequence length (unused, kept
                for interface consistency with LTV models). Defaults to 1.
            dtype (torch.dtype, optional): Data type for allocated tensors
                (unused; the state takes the complex dtype of ``Lambda``).
                Defaults to None.
            **kwargs: Additional model-specific arguments (unused).

        Returns:
            Dict[str, Any]: Cache dict with a zero "lrnn_state" of shape
            ``(batch_size, d_state)`` and the pre-computed "Lambda", "B_norm"
            and "C_complex" matrices.
        """
        # Pre-compute and cache matrices (LTI - compute once)
        Lambda, B_complex, C_complex = self.discretize()
        B_norm = self._normalize_B(B_complex)

        # Initialize state to zeros. The dtype follows Lambda so that the
        # in-place copy in ``step`` never silently downcasts (e.g. complex128
        # to complex64 when the parameters are float64). For float32
        # parameters this is complex64, as before.
        initial_state = torch.zeros(
            batch_size,
            self.d_state,
            dtype=Lambda.dtype,
            device=Lambda.device,
        )
        return {
            "lrnn_state": initial_state,
            "Lambda": Lambda,
            "B_norm": B_norm,
            "C_complex": C_complex,
        }