Source code for lrnnx.models.lti.base

"""
Base class for LTI models.
"""

from abc import abstractmethod
from typing import Any, Dict, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.backends import opt_einsum

from lrnnx.core.base import LRNN


class LTI_LRNN(LRNN):
    """
    Base class for all LTI (linear time-invariant) LRNN models.

    Subclasses must implement :meth:`discretize`, :meth:`compute_kernel`,
    :meth:`step` and :meth:`allocate_inference_cache`; the base versions
    only raise ``NotImplementedError``.

    Note:
        LTI models do not support async discretization as that requires
        time-varying dynamics. For async/event-driven models, use LTV models.

    Example:
        >>> from lrnnx.models.lti import LTI_LRNN
        >>> my_lrnn = LTI_LRNN("zoh")
        >>> # create dummy input tensor and perform forward pass
        >>> # in subclass
    """

    def __init__(
        self,
        discretization: Literal[
            "zoh", "bilinear", "dirac", "no_discretization"
        ],
    ):
        """
        Initialize the LTI LRNN base class.

        Besides delegating to the parent constructor, this sets
        ``torch.backends.opt_einsum.strategy`` to ``"optimal"``. That is an
        attribute of the torch backend module, so the setting is global to
        the process rather than scoped to this instance.

        Args:
            discretization (Literal["zoh", "bilinear", "dirac", "no_discretization"]):
                Discretization method to use.

        Raises:
            RuntimeError: If PyTorch reports that the ``opt_einsum`` backend
                is not available.
        """
        # Explicit check rather than ``assert``: assert statements are
        # removed under ``python -O``, which would skip this validation
        # entirely and proceed to configure an unavailable backend.
        if not opt_einsum.is_available():
            raise RuntimeError(
                "LTI_LRNN requires the `opt_einsum` package so that "
                "torch.einsum can use optimal contraction paths; "
                "install it (e.g. `pip install opt_einsum`)."
            )
        # for optimal contractions
        opt_einsum.strategy = "optimal"
        super().__init__(discretization=discretization)

    @abstractmethod
    def discretize(self) -> Tuple[Tensor, Union[Tensor, float], Tensor]:
        """
        This function discretizes the A, B and C matrices, with a learned
        step-size delta. This could be done inside the `compute_kernel`
        method itself, but doing this explicitly outside allows for more
        flexibility later.

        Returns:
            tuple[torch.Tensor, torch.Tensor | float, torch.Tensor]: A tuple
            of tensors representing the discretized A, B, C matrices,
            ideally of shapes (B, N), (B, N, H) or float, and (B, H, N)
            respectively.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "discretize method must be implemented in the subclass."
        )

    @abstractmethod
    def compute_kernel(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """
        Computes the convolution kernel for efficient parallel processing.

        This function is only relevant for LTI models; for LTV models this
        will materialize a huge vector in-memory at every timestep, which
        is not efficient.

        Reference:
            https://github.com/kunibald413/aTENNuate/blob/15a27dab00d3bf2c27cbbbc3bd41a3d9196dca1e/attenuate/model.py#L30

        Args:
            *args: Model-specific arguments (e.g., sequence length,
                discretized matrices). See subclass implementations for
                details.
            **kwargs: Additional model-specific keyword arguments.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - K : Powers of A matrix (A^0, A^1, ..., A^{L-1}), shape (N, L)
                - B_norm : Normalized input projection matrix, shape (N, H)

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "compute_kernel method must be implemented in the subclass."
        )

    @abstractmethod
    def step(
        self,
        x: torch.Tensor,
        inference_cache: Dict[str, Any],
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Performs a single recurrent step of the LTI model.

        This method is used for autoregressive inference, where inputs are
        processed one timestep at a time.

        Args:
            x (torch.Tensor): Input at current timestep, shape (B, H).
            inference_cache (Dict[str, Any]): Cache dictionary from
                allocate_inference_cache() containing recurrent state and
                pre-computed matrices. Updated in-place and returned.
            **kwargs: Additional keyword arguments.

        Returns:
            tuple[torch.Tensor, Dict[str, Any]]: A tuple containing:
                - y : Output at current timestep, shape (B, H).
                - inference_cache : Updated cache dictionary.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "step method must be implemented in the subclass."
        )

    @abstractmethod
    def allocate_inference_cache(
        self,
        batch_size: int,
        max_seqlen: int = 1,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Allocates initial state and caches matrices for efficient inference.

        For LTI models, the system matrices (A, B, C) are time-invariant, so
        they can be pre-computed once and reused for all timesteps during
        autoregressive generation.

        Args:
            batch_size (int): The batch size for inference.
            max_seqlen (int, optional): Maximum sequence length (unused for
                LTI, kept for interface consistency with LTV models).
                Defaults to 1.
            dtype (torch.dtype, optional): Data type for allocated tensors.
                Defaults to None.
            **kwargs: Additional model-specific arguments.

        Returns:
            Dict[str, Any]: Cache dictionary containing initial state and
            pre-computed matrices for use in step().

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "allocate_inference_cache method must be implemented in the subclass."
        )