Source code for lrnnx.core.base
"""
Base Model class for LRNNX.
"""
from abc import abstractmethod
from typing import Literal, Optional
from torch import Tensor
from torch.nn import Module
from .discretization import DISCRETIZE_FNS
[docs]
class LRNN(Module):
    """
    Base class for linear recurrent neural network (LRNN) models.

    Subclasses must implement :meth:`forward`. The constructor resolves the
    requested discretization method into ``self.discretize_fn`` by looking it
    up in ``DISCRETIZE_FNS`` (or stores ``None`` when the subclass handles
    discretization itself).

    Note:
        ``torch.nn.Module`` does not use ``abc.ABCMeta`` as its metaclass, so
        the ``@abstractmethod`` marker on :meth:`forward` is not enforced at
        instantiation time (``LRNN("zoh")`` succeeds). Calling the base
        :meth:`forward` raises ``NotImplementedError`` instead.
    """

    def __init__(
        self,
        discretization: Optional[
            Literal["zoh", "bilinear", "dirac", "async", "no_discretization"]
        ],
    ):
        """
        Initialize the base LRNN model.

        Args:
            discretization (Optional[Literal]): Discretization method to use, can be one of:
                - "zoh" for Zero-Order Hold
                - "bilinear" for Bilinear method
                - "dirac" for Dirac method
                - "async" for asynchronous discretization
                - "no_discretization" for no discretization
                - None for models that handle discretization internally

        Raises:
            AssertionError: If ``discretization`` is not None and is not one of
                the keys of ``DISCRETIZE_FNS``.

        Each model must have a usage example in the documentation, like so:

        >>> from lrnnx.core import LRNN
        >>> my_lrnn = LRNN("zoh")
        >>> # create dummy input tensor and perform forward pass
        >>> # in subclass
        """
        super().__init__()
        if discretization is not None:
            if discretization not in DISCRETIZE_FNS:
                # Raised explicitly instead of via an ``assert`` statement so
                # the check still runs under ``python -O``. AssertionError is
                # kept so existing callers that catch it continue to work.
                raise AssertionError(
                    f"Discretization method {discretization} is not supported. Choose from {list(DISCRETIZE_FNS.keys())}."
                )
            self.discretize_fn = DISCRETIZE_FNS[discretization]
        else:
            # The subclass performs (or skips) discretization on its own.
            self.discretize_fn = None  # type: ignore[assignment]

    @abstractmethod
    def forward(
        self,
        x: Tensor,
        integration_timesteps: Optional[Tensor] = None,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Forward pass through the LRNN.

        Args:
            x (torch.Tensor): Input tensor, ideally of shape ``(B, L, H)``.
            integration_timesteps (torch.Tensor, optional): Timesteps for async/event-driven
                discretization (Reference: https://arxiv.org/abs/2404.18508),
                ideally of shape ``(B, L)``. Only applicable for LTV models;
                LTI models ignore this parameter. Defaults to None.
            lengths (torch.Tensor, optional): Lengths of sequences, ideally of shape ``(B,)``,
                this is required for bidirectional models. Defaults to None.

        Returns:
            torch.Tensor: Output tensor, same shape as input (x), ideally ``(B, L, H)``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override this method.
        """
        raise NotImplementedError(
            "forward method must be implemented in the subclass."
        )