"""
Utility functions for SSM initialization.
Reference: https://github.com/lindermanlab/S5
"""
import math
from typing import Tuple
import numpy as np
import torch
from torch import Tensor
[docs]
def make_HiPPO(N: int) -> np.ndarray:
    """
    Build the (negated) HiPPO-LegS state matrix.

    Args:
        N (int): The dimension of the HiPPO matrix.

    Returns:
        numpy.ndarray: Lower-triangular matrix of shape ``(N, N)``. Below the
        diagonal entry ``(i, j)`` is ``-sqrt(2i + 1) * sqrt(2j + 1)``; the
        diagonal entry ``(i, i)`` is ``-(i + 1)`` (up to floating-point
        rounding).
    """
    idx = np.arange(N)
    scale = np.sqrt(2 * idx + 1)
    # Keep only the lower triangle (diagonal included) of scale * scale^T.
    lower = np.tril(np.outer(scale, scale))
    # The outer product puts 2i + 1 on the diagonal; subtracting i leaves i + 1.
    return -(lower - np.diag(idx))
[docs]
def make_NPLR_HiPPO(N: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Return the pieces of the NPLR (normal plus low-rank) form of HiPPO-LegS.

    Args:
        N (int): The dimension of the HiPPO matrix.

    Returns:
        tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: ``(hippo, P, B)``
        where ``hippo`` is the output of ``make_HiPPO(N)``, ``P`` is the
        length-``N`` vector ``sqrt(n + 0.5)`` and ``B`` is the length-``N``
        vector ``sqrt(2n + 1)`` for ``n = 0, ..., N - 1``.
    """
    idx = np.arange(N)
    low_rank = np.sqrt(idx + 0.5)
    input_vec = np.sqrt(2 * idx + 1.0)
    return make_HiPPO(N), low_rank, input_vec
[docs]
def make_DPLR_HiPPO(
    N: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Return the DPLR (diagonal plus low-rank) form of HiPPO-LegS.

    Args:
        N (int): The dimension of the HiPPO matrix.

    Returns:
        tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]:
        ``(Lambda, P_transformed, B_transformed, V, B_orig)`` where ``Lambda``
        is the complex diagonal, ``V`` holds the eigenvectors returned by
        ``numpy.linalg.eigh``, ``P_transformed`` and ``B_transformed`` are
        ``V^H @ P`` and ``V^H @ B``, and ``B_orig`` is a copy of the
        untransformed ``B`` vector.
    """
    hippo, low_rank, B = make_NPLR_HiPPO(N)

    # Add the rank-one term P P^T to the HiPPO matrix.
    S = hippo + np.outer(low_rank, low_rank)

    # Every real part is set to the mean of S's diagonal.
    diag = np.diagonal(S)
    Lambda_real = np.mean(diag) * np.ones_like(diag)

    # eigh is run on S * -1j; the real eigenvalues it returns are used as the
    # imaginary parts of Lambda, and V has the eigenvectors as columns.
    Lambda_imag, V = np.linalg.eigh(S * -1j)

    V_H = V.conj().T
    return (
        Lambda_real + 1j * Lambda_imag,
        V_H @ low_rank,
        V_H @ B,
        V,
        B.copy(),
    )
[docs]
def init_log_steps(
    P: int,
    dt_min: float = 0.001,
    dt_max: float = 0.1,
) -> Tensor:
    """
    Sample initial log-timescales uniformly in log space.

    Args:
        P (int): State dimension (number of timescales).
        dt_min (float, optional): Minimum timescale. Defaults to 0.001.
        dt_max (float, optional): Maximum timescale. Defaults to 0.1.

    Returns:
        torch.Tensor: Tensor of shape ``(P,)`` whose entries are drawn
        uniformly between ``log(dt_min)`` and ``log(dt_max)``.
    """
    log_lo = math.log(dt_min)
    log_hi = math.log(dt_max)
    log_steps = torch.empty(P)
    # Fill in place from torch's global RNG.
    log_steps.uniform_(log_lo, log_hi)
    return log_steps
[docs]
def init_VinvB(
    Vinv: np.ndarray,
    local_P: int,
    H: int,
) -> Tensor:
    """
    Draw a random B and return ``Vinv @ B`` split into real and imaginary parts.

    ``B`` has shape ``(local_P, H)`` with standard-normal entries (from
    NumPy's global RNG) scaled by ``sqrt(1 / local_P)``.

    Args:
        Vinv (numpy.ndarray): Inverse eigenvectors of shape ``(P, local_P)``.
        local_P (int): Local state dimension (2*P if conj_sym else P).
        H (int): Hidden dimension.

    Returns:
        torch.Tensor: float32 tensor of shape ``(P, H, 2)``; the last axis
        holds ``[real, imag]`` of ``Vinv @ B``.
    """
    noise = np.random.randn(local_P, H).astype(np.float32)
    B = noise * np.sqrt(1.0 / local_P)
    projected = Vinv @ B
    # Complex values are stored as a trailing (real, imag) pair.
    real_imag = np.stack((projected.real, projected.imag), axis=-1)
    return torch.tensor(real_imag.astype(np.float32))
[docs]
def init_CV(
    V: np.ndarray,
    local_P: int,
    H: int,
) -> Tensor:
    """
    Draw a random complex C and return ``C @ V`` split into real and imaginary parts.

    ``C`` has shape ``(H, local_P)``; its real and imaginary parts are
    independent standard-normal draws (``np.random.randn``, i.e. not
    truncated) from NumPy's global RNG, scaled by ``sqrt(1 / local_P)``.

    Args:
        V (numpy.ndarray): Eigenvectors of shape ``(local_P, P)``.
        local_P (int): Local state dimension (2*P if conj_sym else P).
        H (int): Hidden dimension.

    Returns:
        torch.Tensor: float32 tensor of shape ``(H, P, 2)``; the last axis
        holds ``[real, imag]`` of ``C @ V``.
    """
    # The real part is drawn before the imaginary part.
    real_part = np.random.randn(H, local_P).astype(np.float32)
    imag_part = np.random.randn(H, local_P).astype(np.float32)
    C = (real_part + 1j * imag_part) * np.sqrt(1.0 / local_P)
    projected = C @ V
    # Complex values are stored as a trailing (real, imag) pair.
    real_imag = np.stack((projected.real, projected.imag), axis=-1)
    return torch.tensor(real_imag.astype(np.float32))