Source code for lrnnx.core.discretization

"""
Discretization methods for continuous-time systems.
"""

from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor


def zoh(
    A: Tensor, delta: Tensor, integration_timesteps: Optional[Tensor] = None
) -> tuple[Tensor, Tensor]:
    """
    Zero-Order Hold (ZOH) discretization method, used across most models.

    .. math::
        \\bar{A} &= \\exp(\\Delta A) \\\\
        \\bar{\\gamma} &= A^{-1} (\\bar{A} - I)

    Reference: https://hazyresearch.stanford.edu/blog/2022-01-14-s4-3

    ``A`` holds only the diagonal entries of the state matrix, so every
    operation is element-wise and ``A`` / ``delta`` may have any mutually
    broadcastable shapes (e.g. ``A`` of shape ``(N,)`` or ``(H, N)``).

    Args:
        A (torch.Tensor): The continuous-time state matrix (diagonal entries,
            real or complex).
        delta (torch.Tensor): The discretization step size.
        integration_timesteps (torch.Tensor, optional): Not used in ZOH
            discretization. Defaults to None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - A_bar : The discretized system matrix.
            - gamma_bar : The input normalizer.
    """
    dA = delta * A
    A_bar = torch.exp(dA)
    # The identity is the scalar 1: it broadcasts against any shape of A and
    # keeps the input dtype (a torch.ones(A.shape[0]) vector aligns with the
    # last axis only and is fixed to the default float dtype).
    if dA.is_complex():
        gamma_bar = (A_bar - 1) / A
    else:
        # expm1 avoids the catastrophic cancellation of exp(x) - 1 for small
        # |x|, which is exactly the small-step regime where gamma_bar ~ delta.
        gamma_bar = torch.expm1(dA) / A
    return A_bar, gamma_bar
def bilinear(
    A: Tensor, delta: Tensor, integration_timesteps: Optional[Tensor] = None
) -> tuple[Tensor, Tensor]:
    """
    Bilinear method first used in S4.

    .. math::
        \\bar{A} &= (I - 0.5 \\Delta A)^{-1} (I + 0.5 \\Delta A) \\\\
        \\bar{\\gamma} &= (I - 0.5 \\Delta A)^{-1} \\Delta

    Reference: https://hazyresearch.stanford.edu/blog/2022-01-14-s4-3

    Args:
        A (torch.Tensor): Continuous-time system matrix given by its diagonal
            elements only, e.g. shape ``(N,)``; any shape broadcastable with
            ``delta`` works since all operations are element-wise.
        delta (torch.Tensor): Time step for discretization.
        integration_timesteps (torch.Tensor, optional): Not used in bilinear
            discretization. Defaults to None.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - A_bar : The discretized system matrix.
            - gamma_bar : The input normalizer.
    """
    half_dA = 0.5 * delta * A
    # (I - dA/2)^{-1}: the minus sign belongs in the inverted factor. With
    # Re(A) < 0 this keeps |A_bar| < 1 and matches exp(dA) to second order.
    # The scalar 1 broadcasts against any shape of A and keeps its dtype.
    inv = 1 / (1 - half_dA)
    A_bar = inv * (1 + half_dA)
    gamma_bar = inv * delta
    return A_bar, gamma_bar
def dirac(
    A: Tensor, delta: Tensor, integration_timesteps: Optional[Tensor] = None
) -> tuple[Tensor, float]:
    """
    Dirac discretization method.

    .. math::
        \\bar{A} &= \\exp(\\Delta A) \\\\
        \\bar{\\gamma} &= 1.0

    Reference: https://github.com/Efficient-Scalable-Machine-Learning/event-ssm/blob/main/event_ssm/ssm.py

    Args:
        A (torch.Tensor): Continuous-time system matrix.
        delta (torch.Tensor): Time step for discretization.
        integration_timesteps (torch.Tensor, optional): Not used in dirac
            discretization. Defaults to None.

    Returns:
        tuple[torch.Tensor, float]: A tuple containing:
            - A_bar : The discretized system matrix, ``exp(delta * A)``.
            - gamma_bar : The input normalizer, always the float ``1.0``.
    """
    # Only the state transition is discretized; the input normalizer is the
    # constant 1.0, so the input matrix is used unscaled.
    return torch.exp(delta * A), 1.0
def async_(
    A: Tensor, delta: Tensor, integration_timesteps: Optional[Tensor] = None
) -> tuple[Tensor, Tensor]:
    """
    Asynchronous discretization method, introduced in
    https://arxiv.org/abs/2404.18508. This helps provide a strong inductive
    bias for asynchronous event-streams.

    .. math::
        \\bar{A} &= \\exp(\\Delta \\cdot \\text{integration\\_timesteps} \\cdot A) \\\\
        \\bar{\\gamma} &= A^{-1} (\\exp(\\Delta A) - I)

    The state transition spans ``delta * integration_timesteps``, while the
    input normalizer depends on ``delta`` only.

    This method is only for legacy reasons; it is not possible to use this
    method (or any other discretization with async timesteps) with LTI models.

    Args:
        A (torch.Tensor): Continuous-time system matrix (diagonal entries).
        delta (torch.Tensor): Time step for discretization.
        integration_timesteps (torch.Tensor): Timesteps for async
            discretization, ideally of shape (B, L), i.e., difference in
            timesteps of events. Must broadcast with ``delta * A``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - A_bar : The discretized system matrix.
            - gamma_bar : The input normalizer.

    Raises:
        AssertionError: If ``integration_timesteps`` is None.
    """
    if integration_timesteps is None:
        # Explicit raise (not ``assert``) so the check survives ``python -O``;
        # the exception type is unchanged for existing callers.
        raise AssertionError(
            "Integration timesteps must be provided for async discretization."
        )
    A_bar = torch.exp(delta * integration_timesteps * A)
    # gamma_bar uses exp(delta * A), NOT A_bar: per the formula above, the
    # input normalizer must not depend on the time elapsed between events.
    # The scalar 1 broadcasts against any shape of A and keeps its dtype.
    dA = delta * A
    if dA.is_complex():
        gamma_bar = (torch.exp(dA) - 1) / A
    else:
        # expm1 avoids cancellation in exp(x) - 1 for small |x|.
        gamma_bar = torch.expm1(dA) / A
    return A_bar, gamma_bar
def no_discretization(
    A: Tensor, delta: Tensor, integration_timesteps: Optional[Tensor] = None
) -> tuple[Tensor, float]:
    """
    No discretization method, identity operation.

    .. math::
        \\bar{A} &= A \\\\
        \\bar{\\gamma} &= 1.0

    Args:
        A (torch.Tensor): Continuous-time system matrix.
        delta (torch.Tensor): Time step for discretization (unused).
        integration_timesteps (torch.Tensor, optional): Not used in
            no_discretization. Defaults to None.

    Returns:
        tuple[torch.Tensor, float]: A tuple containing:
            - A_bar : The very same tensor object as ``A`` (no copy).
            - gamma_bar : 1.0, as B_bar = B.
    """
    # ``delta`` and ``integration_timesteps`` exist only so this function has
    # the same call signature as the other discretization functions.
    A_bar = A
    gamma_bar = 1.0
    return A_bar, gamma_bar
# Common call signature of every discretization function:
# (A, delta, integration_timesteps) -> (A_bar, gamma_bar).
_DiscretizeFn = Callable[
    [Tensor, Tensor, Optional[Tensor]], Tuple[Tensor, Union[Tensor, float]]
]

# Registry mapping a discretization name to its implementation.
DISCRETIZE_FNS: dict[str, _DiscretizeFn] = {
    "zoh": zoh,
    "bilinear": bilinear,
    "dirac": dirac,
    "async": async_,
    "no_discretization": no_discretization,
}