lrnnx.utils.init module

Utility functions for initializing state space model (SSM) parameters. Reference implementation: https://github.com/lindermanlab/S5

make_HiPPO(N: int) → numpy.ndarray

Create a HiPPO-LegS matrix.

Parameters:

N (int) – The dimension of the HiPPO matrix.

Returns:

The generated HiPPO-LegS matrix of shape (N, N).

Return type:

numpy.ndarray
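For reference, the HiPPO-LegS matrix (in the sign convention of the S5 reference) has entries A[n, k] = -sqrt(2n+1)·sqrt(2k+1) for n > k, -(n+1) on the diagonal, and 0 above the diagonal. The sketch below rebuilds it in NumPy, assuming make_HiPPO follows the S5 reference; hippo_legs_sketch is a hypothetical name, not part of lrnnx.

```python
import numpy as np

def hippo_legs_sketch(N: int) -> np.ndarray:
    """Hypothetical re-derivation of the HiPPO-LegS matrix (S5 sign convention)."""
    # p_n = sqrt(2n + 1) for n = 0..N-1
    p = np.sqrt(1 + 2 * np.arange(N))
    # Outer product: entry (n, k) is sqrt(2n+1) * sqrt(2k+1)
    A = p[:, np.newaxis] * p[np.newaxis, :]
    # Keep the lower triangle; the diagonal becomes (2n+1) - n = n + 1
    A = np.tril(A) - np.diag(np.arange(N))
    # Negate so the eigenvalues have negative real part (stable dynamics)
    return -A

A = hippo_legs_sketch(4)
print(A.shape)     # (4, 4)
print(np.diag(A))  # [-1. -2. -3. -4.]
```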

make_NPLR_HiPPO(N: int) → Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]

Make the normal-plus-low-rank (NPLR) representation of the HiPPO-LegS matrix.

Parameters:

N (int) – The dimension of the HiPPO matrix.

Returns:

A tuple containing the HiPPO matrix, the P vector, and the B vector.

Return type:

tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
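The point of the NPLR form is that adding the rank-1 term P Pᵀ to the HiPPO-LegS matrix yields a normal matrix, namely -I/2 plus a skew-symmetric matrix. A minimal NumPy sketch, assuming the S5 reference formulas P[n] = sqrt(n + 1/2) and B[n] = sqrt(2n + 1); both function names are hypothetical.

```python
import numpy as np

def hippo_legs_sketch(N):
    # HiPPO-LegS matrix, S5 sign convention (see make_HiPPO)
    p = np.sqrt(1 + 2 * np.arange(N))
    A = p[:, None] * p[None, :]
    return -(np.tril(A) - np.diag(np.arange(N)))

def nplr_hippo_sketch(N):
    hippo = hippo_legs_sketch(N)
    # Rank-1 correction vector: P_n = sqrt(n + 1/2)
    P = np.sqrt(np.arange(N) + 0.5)
    # HiPPO-LegS input vector: B_n = sqrt(2n + 1)
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return hippo, P, B

hippo, P, B = nplr_hippo_sketch(8)
S = hippo + np.outer(P, P)
# S = -I/2 + (skew-symmetric part), hence S is normal
skew = S + 0.5 * np.eye(8)
print(np.allclose(skew, -skew.T))  # True
```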

make_DPLR_HiPPO(N: int) → Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]

Make the diagonal-plus-low-rank (DPLR) representation of the HiPPO-LegS matrix.

Parameters:

N (int) – The dimension of the HiPPO matrix.

Returns:

A tuple containing Lambda (eigenvalues), P_transformed and B_transformed (P and B expressed in the eigenbasis), V (eigenvectors), and B_orig (the untransformed B vector).

Return type:

tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]
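The DPLR form diagonalizes that normal matrix. Because S = A + P Pᵀ equals -I/2 plus a skew-symmetric part, the S5 reference sets the real part of every eigenvalue to the mean of diag(S) and gets the imaginary parts and eigenvectors from a Hermitian eigendecomposition of -iS. The sketch mirrors that procedure in NumPy; dplr_hippo_sketch is hypothetical and assumes lrnnx ports the S5 code faithfully.

```python
import numpy as np

def dplr_hippo_sketch(N):
    # NPLR pieces (S5 reference formulas): HiPPO-LegS A, rank-1 factor P, input B
    p = np.sqrt(1 + 2 * np.arange(N))
    A = -(np.tril(p[:, None] * p[None, :]) - np.diag(np.arange(N)))
    P = np.sqrt(np.arange(N) + 0.5)
    B = np.sqrt(2 * np.arange(N) + 1.0)

    # Normal part: S = A + P P^T = -I/2 + (skew-symmetric)
    S = A + np.outer(P, P)
    Lambda_real = np.mean(np.diagonal(S)) * np.ones(N)  # numerically -0.5

    # -1j * S is Hermitian off the diagonal; eigh treats the imaginary part of
    # the diagonal as zero, so this diagonalizes the skew-symmetric part of S.
    Lambda_imag, V = np.linalg.eigh(S * -1j)

    Lambda = Lambda_real + 1j * Lambda_imag
    P_transformed = V.conj().T @ P  # P in the eigenbasis
    B_transformed = V.conj().T @ B  # B in the eigenbasis
    return Lambda, P_transformed, B_transformed, V, B

Lambda, P_t, B_t, V, B_orig = dplr_hippo_sketch(8)
print(np.allclose(Lambda.real, -0.5))  # True
```

Since V is unitary, S is recovered as V diag(Lambda) Vᴴ, and V @ B_transformed gives back B_orig.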

init_log_steps(P: int, dt_min: float = 0.001, dt_max: float = 0.1) → torch.Tensor

Initialize learnable timescale parameters.

Parameters:
  • P (int) – State dimension (number of timescales).

  • dt_min (float, optional) – Minimum timescale. Defaults to 0.001.

  • dt_max (float, optional) – Maximum timescale. Defaults to 0.1.

Returns:

Log-timescales of shape (P,).

Return type:

torch.Tensor
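Assuming init_log_steps follows the S5 reference, each log-timescale is drawn uniformly between log(dt_min) and log(dt_max), so the timescales themselves are log-uniform on [dt_min, dt_max). The NumPy sketch below illustrates this; the real function returns a torch.Tensor, and init_log_steps_sketch and its rng argument are hypothetical.

```python
import numpy as np

def init_log_steps_sketch(P, dt_min=0.001, dt_max=0.1, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(size=P)  # u ~ U[0, 1), one draw per timescale
    # Map [0, 1) linearly onto [log(dt_min), log(dt_max))
    return np.log(dt_min) + u * (np.log(dt_max) - np.log(dt_min))

log_dt = init_log_steps_sketch(64, rng=np.random.default_rng(0))
dt = np.exp(log_dt)  # timescales between dt_min and dt_max
print(log_dt.shape)  # (64,)
```

Sampling in log space spreads the P timescales evenly across orders of magnitude instead of crowding them near dt_max.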

init_VinvB(Vinv: numpy.ndarray, local_P: int, H: int) → torch.Tensor

Initialize B_tilde = V^{-1} @ B, with B drawn using LeCun-style scaling.

Parameters:
  • Vinv (numpy.ndarray) – Inverse eigenvector matrix V^{-1} of shape (P, local_P).

  • local_P (int) – Local state dimension (2*P if conj_sym else P).

  • H (int) – Hidden dimension.

Returns:

B_tilde of shape (P, H, 2); the last axis holds the real and imaginary parts.

Return type:

torch.Tensor
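A sketch of the shapes involved, with the assumptions spelled out: B is drawn as a real Gaussian of shape (local_P, H) with standard deviation 1/sqrt(local_P) (the fan-in convention and the use of an untruncated normal are assumptions), and a random unitary matrix stands in for the HiPPO eigenvectors. init_VinvB_sketch is hypothetical and uses NumPy rather than torch.

```python
import numpy as np

def init_VinvB_sketch(Vinv, local_P, H, rng):
    # Assumed LeCun-style scaling: std = 1/sqrt(fan_in), with fan_in = local_P
    B = rng.normal(0.0, 1.0 / np.sqrt(local_P), size=(local_P, H))
    VinvB = Vinv @ B  # (P, local_P) @ (local_P, H) -> (P, H), complex
    # Split into real/imag on a trailing axis -> (P, H, 2)
    return np.stack([VinvB.real, VinvB.imag], axis=-1)

rng = np.random.default_rng(0)
local_P, H = 8, 3
P = local_P // 2  # conj_sym: only half of the eigenvectors are kept
M = rng.normal(size=(local_P, local_P)) + 1j * rng.normal(size=(local_P, local_P))
V, _ = np.linalg.qr(M)      # stand-in unitary eigenvector matrix
Vinv = V[:, :P].conj().T    # (P, local_P)
B_tilde = init_VinvB_sketch(Vinv, local_P, H, rng)
print(B_tilde.shape)  # (4, 3, 2)
```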

init_CV(V: numpy.ndarray, local_P: int, H: int) → torch.Tensor

Initialize C_tilde = C @ V, with C drawn from a truncated normal distribution.

Parameters:
  • V (numpy.ndarray) – Eigenvector matrix V of shape (local_P, P).

  • local_P (int) – Local state dimension (2*P if conj_sym else P).

  • H (int) – Hidden dimension.

Returns:

C_tilde of shape (H, P, 2); the last axis holds the real and imaginary parts.

Return type:

torch.Tensor
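The C side works the same way in the opposite order: a complex C of shape (H, local_P) is built from two real truncated-normal draws, multiplied on the right by V, and split back into real and imaginary parts. In this NumPy sketch the truncation at ±2 standard deviations and the 1/sqrt(local_P) scale are assumptions, a random unitary matrix stands in for the HiPPO eigenvectors, and truncated_normal and init_CV_sketch are hypothetical helpers.

```python
import numpy as np

def truncated_normal(rng, size, bound=2.0):
    # Rejection sampling: redraw any entry that falls outside [-bound, bound]
    x = rng.standard_normal(size)
    mask = np.abs(x) > bound
    while mask.any():
        x[mask] = rng.standard_normal(mask.sum())
        mask = np.abs(x) > bound
    return x

def init_CV_sketch(V, local_P, H, rng):
    # Assumed scale: 1/sqrt(local_P); real and imag parts on the last axis
    C_ = truncated_normal(rng, (H, local_P, 2)) / np.sqrt(local_P)
    C = C_[..., 0] + 1j * C_[..., 1]  # (H, local_P), complex
    CV = C @ V                        # (H, local_P) @ (local_P, P) -> (H, P)
    return np.stack([CV.real, CV.imag], axis=-1)  # (H, P, 2)

rng = np.random.default_rng(0)
local_P, H = 8, 3
P = local_P // 2  # conj_sym: only half of the eigenvectors are kept
M = rng.normal(size=(local_P, local_P)) + 1j * rng.normal(size=(local_P, local_P))
V_full, _ = np.linalg.qr(M)  # stand-in unitary eigenvector matrix
V = V_full[:, :P]            # (local_P, P)
C_tilde = init_CV_sketch(V, local_P, H, rng)
print(C_tilde.shape)  # (3, 4, 2)
```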