lrnnx.ops.s4_utils module
- cauchy_naive(v, z, w)
Naive PyTorch fallback for Cauchy matrix multiplication.
- Parameters:
v (torch.Tensor) – Input tensor of shape (..., N).
z (torch.Tensor) – Input tensor of shape (..., L).
w (torch.Tensor) – Input tensor of shape (..., N).
- Returns:
The Cauchy product, i.e. the sum over n of v_n / (z_l - w_n), of shape (..., L).
- Return type:
torch.Tensor
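The sketch below shows, in pure Python, what the naive Cauchy fallback computes for a single batch element: out[l] = sum_n v[n] / (z[l] - w[n]). It works on plain 1-D lists instead of batched torch.Tensor inputs, so it is a reference for the math rather than a drop-in replacement. The name cauchy_reference is hypothetical.

```python
def cauchy_reference(v, z, w):
    """Return [sum_n v[n] / (z[l] - w[n]) for each l]; all inputs are 1-D lists."""
    if len(v) != len(w):
        raise ValueError("v and w must both have length N")
    # Each of the L outputs sums N terms, so the naive computation costs O(N * L).
    return [sum(vn / (zl - wn) for vn, wn in zip(v, w)) for zl in z]


print(cauchy_reference([1.0, 2.0], [2.0, 3.0], [0.0, 1.0]))  # [2.5, 1.3333333333333333]
```

Complex inputs work the same way, because Python's arithmetic operators handle complex numbers.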
- log_vandermonde_naive(v, x, L, conj=True)
Naive PyTorch fallback for log Vandermonde multiplication.
- Parameters:
v (torch.Tensor) – Input tensor of shape (..., N).
x (torch.Tensor) – Input tensor of shape (..., N).
L (int) – Sequence length.
conj (bool, optional) – Whether to use conjugate symmetry. Defaults to True.
- Returns:
The sum over n of v_n * exp(l * x_n) for l = 0, ..., L-1, i.e. v * x^l with the nodes given in log space, of shape (..., L).
- Return type:
torch.Tensor
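A pure-Python sketch of the log Vandermonde product for one batch element follows. It assumes, as the function name suggests, that x holds the logarithms of the Vandermonde nodes. Under that assumption the l-th output is sum_n v[n] * exp(l * x[n]), which equals sum_n v[n] * (e^{x[n]})^l. The sketch models conj=True as returning twice the real part. That is the usual trick when only one mode of each complex-conjugate pair is stored. Both this behavior and the name log_vandermonde_reference are assumptions for illustration.

```python
import cmath


def log_vandermonde_reference(v, x, L, conj=True):
    """Return out[l] = sum_n v[n] * exp(l * x[n]) for l = 0..L-1 (1-D lists)."""
    if len(v) != len(x):
        raise ValueError("v and x must both have length N")
    out = [sum(vn * cmath.exp(l * xn) for vn, xn in zip(v, x)) for l in range(L)]
    if conj:
        # Assumption: each stored mode stands for a conjugate pair, so adding the
        # conjugate terms doubles the real part and cancels the imaginary part.
        return [2 * o.real for o in out]
    return out


print(log_vandermonde_reference([1.0], [cmath.log(0.5)], 4, conj=False))
```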
- log_vandermonde_transpose_naive(u, v, x, L)
Naive PyTorch fallback for transposed log Vandermonde multiplication.
- Parameters:
u (torch.Tensor) – Input tensor of shape (..., L).
v (torch.Tensor) – Input tensor of shape (..., N).
x (torch.Tensor) – Input tensor of shape (..., N).
L (int) – Sequence length.
- Returns:
Output tensor of shape (..., N).
- Return type:
torch.Tensor
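The entry above gives only the output shape. A natural reading, which is an assumption here, is that the function applies the transpose of the same log Vandermonde matrix: out[n] = v[n] * sum_l u[l] * exp(l * x[n]). This maps a length-L input back onto the N state coordinates. The sketch below implements that reading on 1-D lists, under the hypothetical name log_vandermonde_transpose_reference.

```python
import cmath


def log_vandermonde_transpose_reference(u, v, x, L):
    """Return out[n] = v[n] * sum_l u[l] * exp(l * x[n]) (1-D lists)."""
    if len(u) != L:
        raise ValueError("u must have length L")
    if len(v) != len(x):
        raise ValueError("v and x must both have length N")
    return [vn * sum(ul * cmath.exp(l * xn) for l, ul in enumerate(u))
            for vn, xn in zip(v, x)]


print(log_vandermonde_transpose_reference([1.0, 1.0, 1.0], [1.0], [cmath.log(2)], 3))
```

Under this definition, summing the output over n gives the same value as sum_l u[l] * f[l], where f is the forward product sum_n v[n] * exp(l * x[n]) without conjugation. That identity offers a cheap consistency check when testing a faster kernel against a reference.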
- get_cauchy_kernel()
Returns the best available Cauchy multiplication function.
- Returns:
The Cauchy kernel function (CUDA, KeOps, or naive fallback).
- Return type:
callable
- get_vandermonde_kernel()
Returns the best available Vandermonde multiplication function.
- Returns:
The Vandermonde kernel function (CUDA, KeOps, or naive fallback).
- Return type:
callable
- get_vandermonde_transpose_kernel()
Returns the best available transpose Vandermonde multiplication function.
- Returns:
The transposed Vandermonde kernel function (KeOps or naive fallback).
- Return type:
callable
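The three get_*_kernel helpers choose an implementation at runtime. A common way to write such a dispatcher is to try the accelerated backends in order of preference and fall back to the pure-PyTorch version when an import fails. The sketch below shows that pattern only. The module names hypothetical_cauchy_cuda and hypothetical_cauchy_keops, the function name cauchy_mult, and the helper pick_kernel are all placeholders. They are not the modules or functions lrnnx actually uses.

```python
import importlib


def pick_kernel(backends, fallback):
    """Return the first importable backend function, else `fallback`.

    `backends` is a list of (module_name, function_name) pairs in order of
    preference, for example a CUDA extension first and a KeOps module second.
    """
    for module_name, fn_name in backends:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # Backend not installed; try the next one.
        fn = getattr(module, fn_name, None)
        if callable(fn):
            return fn
    return fallback


def naive_cauchy(v, z, w):
    """Pure-Python stand-in for the naive fallback."""
    return [sum(vn / (zl - wn) for vn, wn in zip(v, w)) for zl in z]


# Hypothetical module names: on a machine without them, the fallback is chosen.
kernel = pick_kernel(
    [("hypothetical_cauchy_cuda", "cauchy_mult"),
     ("hypothetical_cauchy_keops", "cauchy_mult")],
    naive_cauchy,
)
```

Resolving the kernel once and reusing the returned callable avoids repeating the import attempts on every forward pass.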
- LinearActivation(d_input, d_output, bias=True, transposed=False, activate=False)
Returns a linear nn.Module with control over axis order and an optional activation.
- Parameters:
d_input (int) – Input dimension.
d_output (int) – Output dimension.
bias (bool, optional) – Whether to use bias. Defaults to True.
transposed (bool, optional) – If True, uses Conv1d instead of Linear. Defaults to False.
activate (bool, optional) – Whether to append a GELU activation. Defaults to False.
- Returns:
The configured linear/activation module.
- Return type:
torch.nn.Module
- class DropoutNd
Bases: torch.nn.Module
N-dimensional dropout module.
- Parameters:
- forward(X)
Forward pass for DropoutNd.
- Parameters:
X (torch.Tensor) – Input tensor of shape (batch, dim, lengths...).
- Returns:
Tensor with dropout applied.
- Return type:
torch.Tensor
- power(L, A, v=None)
Compute A^L and the scan sum_i A^i v_i.
- Parameters:
L (int) – Power to raise A to.
A (torch.Tensor) – Input matrix of shape (..., N, N).
v (torch.Tensor, optional) – Vectors for the scan sum, of shape (..., N, L). Defaults to None.
- Returns:
A^L if v is None, otherwise a tuple (A^L, sum_i A^i v_i).
- Return type:
torch.Tensor or tuple[torch.Tensor, torch.Tensor]
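Below is a pure-Python sketch of the two quantities power returns, for a single unbatched matrix. It computes A^L by exponentiation by squaring, which needs only O(log L) matrix products. It computes the scan sum sum_i A^i v_i naively, carrying one accumulated power per step. Matrices are nested lists and the helper names are hypothetical. The library version works on batched torch.Tensor inputs and may organize the scan differently, for example as a divide-and-conquer reduction.

```python
def matmul(A, B):
    """Multiply two matrices given as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def matpow(A, L):
    """Compute A^L (L >= 0) by repeated squaring."""
    result, base = identity(len(A)), A
    while L > 0:
        if L % 2 == 1:
            result = matmul(base, result)  # Powers of A commute, so order is free.
        base = matmul(base, base)
        L //= 2
    return result


def scan_sum(A, v):
    """Compute sum_i A^i v[:, i], where v is an N x L nested list."""
    n, steps = len(v), len(v[0])
    total = [0.0] * n
    P = identity(n)  # Holds A^i at step i.
    for i in range(steps):
        col = [v[r][i] for r in range(n)]
        total = [t + sum(P[r][k] * col[k] for k in range(n)) for r, t in enumerate(total)]
        P = matmul(A, P)  # Advance to A^(i+1).
    return total


print(matpow([[1, 1], [1, 0]], 10))  # [[89.0, 55.0], [55.0, 34.0]]
```

The Fibonacci matrix makes a convenient check, because its n-th power holds consecutive Fibonacci numbers.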
- rank_correction(measure, N, rank=1, dtype=torch.float)
Return a low-rank matrix P such that A + P P^T is normal, where A is the HiPPO matrix for the given measure.
- Parameters:
measure (str) – Measure type.
N (int) – State dimension.
rank (int, optional) – Rank of the correction. Defaults to 1.
dtype (torch.dtype, optional) – Data type. Defaults to torch.float.
- Returns:
Rank correction matrix P.
- Return type:
torch.Tensor
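Take the "legs" measure as a worked example. The standard HiPPO-LegS matrix has A[n][k] = -sqrt((2n+1)(2k+1)) for n > k, -(n+1) on the diagonal, and 0 above it. The rank-1 correction P[n] = sqrt(n + 1/2) turns A + P P^T into -I/2 plus a skew-symmetric matrix, and any such matrix is normal. The sketch below checks that identity numerically. It assumes rank_correction follows this LegS convention, as the original S4 code does, and the helper names are hypothetical.

```python
import math


def hippo_legs(N):
    """Standard HiPPO-LegS state matrix as nested lists."""
    A = [[0.0] * N for _ in range(N)]
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n][k] = -math.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n][k] = -(n + 1.0)
    return A


def legs_rank_correction(N):
    """Rank-1 correction with P[n] = sqrt(n + 1/2)."""
    return [math.sqrt(n + 0.5) for n in range(N)]


def corrected(N):
    """Return A + P P^T for the LegS matrix."""
    A, P = hippo_legs(N), legs_rank_correction(N)
    return [[A[n][k] + P[n] * P[k] for k in range(N)] for n in range(N)]


S = corrected(4)
# S + S^T should equal -I: the diagonal of S is -1/2 and the rest is skew-symmetric.
```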
- nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True, B_clip=2.0)
Constructs NPLR form of HiPPO matrices.
Computes W, P, Q, V, B such that (W - P Q^*, B) is unitarily equivalent to the original HiPPO pair (A, B) via the matrix V, i.e. A = V (W - P Q^*) V^* and the original B equals V B. Only W, P, B, and V are returned, in that order.
- Parameters:
measure (str) – Measure type.
N (int) – State dimension.
rank (int, optional) – Rank of the correction. Defaults to 1.
dtype (torch.dtype, optional) – Target data type. Defaults to torch.float.
diagonalize_precision (bool, optional) – Whether to diagonalize in double precision. Defaults to True.
B_clip (float, optional) – Clipping value for B. Defaults to 2.0.
- Returns:
The W, P, B, and V matrices.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- dplr(init='hippo', N=64, rank=1, H=1, dtype=torch.float, real_random=False, real_scale=1.0, imag_random=False, imag_scale=1.0, B_random=False, B_init='constant', B_scale=1.0, P_scale=1.0, normalize=False)
Directly construct a DPLR matrix.
- Parameters:
init (str, optional) – Initialization method. Defaults to “hippo”.
N (int, optional) – State size. Defaults to 64.
rank (int, optional) – Rank for DPLR parameterization. Defaults to 1.
H (int, optional) – Number of independent SSM copies. Defaults to 1.
dtype (torch.dtype, optional) – Data type. Defaults to torch.float.
real_random (bool, optional) – Whether to randomize real part. Defaults to False.
real_scale (float, optional) – Scaling factor for real part. Defaults to 1.0.
imag_random (bool, optional) – Whether to randomize imaginary part. Defaults to False.
imag_scale (float, optional) – Scaling factor for imaginary part. Defaults to 1.0.
B_random (bool, optional) – Deprecated. Defaults to False.
B_init (str, optional) – Initialization method for B. Defaults to “constant”.
B_scale (float, optional) – Scaling factor for B. Defaults to 1.0.
P_scale (float, optional) – Scaling factor for P. Defaults to 1.0.
normalize (bool, optional) – Whether to normalize B. Defaults to False.
- Returns:
Matrices A, P, B, V.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- ssm(init, N, R, H)
Dispatcher to create a single SSM initialization.
- Parameters:
- Returns:
Matrices A, P, B, V.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- combination(inits, N, R, S)
Create a combination of SSM initializations.
- Parameters:
- Returns:
Combined matrices A, P, B, V.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- inv_transform(param, transform='none')
Initialize a (positive) parameter under a transform.
- Parameters:
param (torch.Tensor) – Parameter tensor.
transform (str, optional) – Transform type (“none”, “exp”, “relu”, “sigmoid”, “softplus”). Defaults to “none”.
- Returns:
Inverse-transformed parameter, i.e. the raw value that param_transform maps back to param.
- Return type:
torch.Tensor
- param_transform(param, transform='none')
Get a (positive) parameter under a transform.
- Parameters:
param (torch.Tensor) – Parameter tensor.
transform (str, optional) – Transform type (“none”, “exp”, “relu”, “sigmoid”, “softplus”). Defaults to “none”.
- Returns:
Transformed parameter.
- Return type:
torch.Tensor
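inv_transform and param_transform are intended as inverses on positive values. A parameter is stored as inv_transform(value) and read back as param_transform(stored). The scalar sketch below pairs each listed transform with its natural inverse: log for exp, logit for sigmoid, log(e^x - 1) for softplus, and identity for none and relu. These pairings and the helper names are assumptions for illustration, not a copy of the library code. The library may also clamp values or add small offsets for numerical safety.

```python
import math

# Forward maps: stored (unconstrained) value -> positive parameter.
FORWARD = {
    "none": lambda p: p,
    "exp": math.exp,
    "relu": lambda p: max(p, 0.0),
    "sigmoid": lambda p: 1.0 / (1.0 + math.exp(-p)),
    "softplus": lambda p: math.log1p(math.exp(p)),
}

# Inverse maps: positive parameter -> stored value.
INVERSE = {
    "none": lambda x: x,
    "exp": math.log,
    "relu": lambda x: x,
    "sigmoid": lambda x: math.log(x / (1.0 - x)),  # logit, needs 0 < x < 1
    "softplus": lambda x: math.log(math.expm1(x)),  # log(e^x - 1), needs x > 0
}


def inv_transform_sketch(x, transform="none"):
    return INVERSE[transform](x)


def param_transform_sketch(p, transform="none"):
    return FORWARD[transform](p)


stored = inv_transform_sketch(0.3, "softplus")
recovered = param_transform_sketch(stored, "softplus")  # approximately 0.3
```

Storing the unconstrained value lets the optimizer update it freely while the transformed parameter stays positive.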
- init_dt(H, N, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', deterministic=False, dtype=torch.float)
Initialize the dt (step size) parameter.
- Parameters:
H (int) – Model dimension (number of independent SSMs).
N (int) – State size.
dt_min (float, optional) – Minimum dt value. Defaults to 0.001.
dt_max (float, optional) – Maximum dt value. Defaults to 0.1.
dt_tie (bool, optional) – Whether to tie dt across dimensions. Defaults to True.
dt_transform (str, optional) – Transform type for dt. Defaults to “exp”.
deterministic (bool, optional) – Whether to use deterministic initialization. Defaults to False.
dtype (torch.dtype, optional) – Data type. Defaults to torch.float.
- Returns:
Initialized (inverse transformed) dt parameter.
- Return type:
torch.Tensor
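A common scheme for initializing the step size, used across the S4 family of models, is to sample log(dt) uniformly between log(dt_min) and log(dt_max). This makes dt itself log-uniform on [dt_min, dt_max]. With the default "exp" transform, the stored (inverse-transformed) value is then simply log(dt). The sketch below assumes that scheme and one tied value per SSM (dt_tie=True). It models deterministic=True as even spacing in log space. It uses the standard random module instead of torch, and its name is hypothetical.

```python
import math
import random


def init_log_dt_sketch(H, dt_min=0.001, dt_max=0.1, deterministic=False, seed=None):
    """Return H stored values log(dt), one per independent SSM."""
    lo, hi = math.log(dt_min), math.log(dt_max)
    if deterministic:
        # Evenly spaced in log space, so repeated runs give identical values.
        if H == 1:
            return [lo]
        return [lo + (hi - lo) * h / (H - 1) for h in range(H)]
    rng = random.Random(seed)
    # Uniform in log space, so dt = exp(value) is log-uniform on [dt_min, dt_max).
    return [lo + (hi - lo) * rng.random() for _ in range(H)]


log_dt = init_log_dt_sketch(8, seed=0)
dt = [math.exp(x) for x in log_dt]  # Applying the "exp" transform recovers dt.
```

Sampling in log space spreads the initial timescales evenly across orders of magnitude instead of clustering them near dt_max.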
- init_ssm_dplr(N, H, n_ssm, channels, rank, init, deterministic=False, cdtype=torch.cfloat)
Initialize DPLR (A, P, B, C) parameters and return broadcast repeat factor.
- Parameters:
N (int) – State size.
H (int) – Model dimension.
n_ssm (int) – Number of independent SSM copies.
channels (int) – Number of channels.
rank (int) – Rank.
init (str) – Initialization method.
deterministic (bool, optional) – Whether to use deterministic initialization. Defaults to False.
cdtype (torch.dtype, optional) – Complex data type. Defaults to torch.cfloat.
- Returns:
Tensors A, P, B, C.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- register_ssm_params(module, A, B, C, inv_dt, P, H, n_ssm, N, channels, rank, dt_fast, real_transform, imag_transform, is_real, verbose, l_max, diag=False)
Register SSM parameters on a module.
- Parameters:
module (torch.nn.Module) – The module to register parameters on.
A (torch.Tensor) – Tensor A.
B (torch.Tensor) – Tensor B.
C (torch.Tensor) – Tensor C.
inv_dt (torch.Tensor) – Inverse dt tensor.
P (torch.Tensor) – Tensor P.
H (int) – Model dimension.
n_ssm (int) – Number of independent SSMs.
N (int) – State size.
channels (int) – Number of channels.
rank (int) – Rank.
dt_fast (bool) – Whether dt is fast.
real_transform (str) – Transform for the real part.
imag_transform (str) – Transform for the imaginary part.
is_real (bool) – Whether parameters are real.
verbose (bool) – Whether to print construction details.
l_max (int) – Maximum sequence length.
diag (bool, optional) – If True, skip P registration and rank checks (for diagonal S4D). Defaults to False.
- Returns:
The broadcast repeat factor.
- Return type:
int
- process_ssm_params(A_real, A_imag, B, C, inv_dt, real_transform='exp', imag_transform='none', dt_transform='exp', dt_fast=False, is_real=False, repeat_factor=1, rate=1.0)
Process SSM parameters from stored form to usable form.
- Parameters:
A_real (torch.Tensor) – Real part of A.
A_imag (torch.Tensor) – Imaginary part of A.
B (torch.Tensor) – Tensor B.
C (torch.Tensor) – Tensor C.
inv_dt (torch.Tensor) – Inverse dt tensor.
real_transform (str, optional) – Transform for A_real. Defaults to “exp”.
imag_transform (str, optional) – Transform for A_imag. Defaults to “none”.
dt_transform (str, optional) – Transform for dt. Defaults to “exp”.
dt_fast (bool, optional) – Whether dt is fast. Defaults to False.
is_real (bool, optional) – Whether parameters are real. Defaults to False.
repeat_factor (int, optional) – Broadcast repeat factor. Defaults to 1.
rate (float, optional) – Sampling rate. Defaults to 1.0.
- Returns:
Processed dt, A, B, C, dtA.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
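process_ssm_params returns the product dtA alongside dt, A, B, and C. In diagonal S4-style models, one standard use of that product is zero-order-hold (ZOH) discretization of each diagonal mode: A_bar = exp(dt*A) and B_bar = (exp(dt*A) - 1) / A * B. The scalar sketch below assumes that ZOH rule for a single complex mode. It illustrates how dtA may be consumed downstream. It does not describe this function's internals, and the helper name is hypothetical.

```python
import cmath


def zoh_discretize(dt, A, B):
    """ZOH discretization of one diagonal mode x' = A x + B u (requires A != 0)."""
    dtA = dt * A
    A_bar = cmath.exp(dtA)
    B_bar = (A_bar - 1) / A * B
    return A_bar, B_bar


A_bar, B_bar = zoh_discretize(0.01, -0.5 + 3.0j, 1.0 + 0j)
```

Whenever the real part of A is negative, |A_bar| = exp(dt * Re(A)) < 1, so the discrete recurrence stays stable for any positive dt.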
- process_dplr_params(A_real, A_imag, B, C, P, inv_dt, real_transform='exp', imag_transform='none', dt_transform='exp', dt_fast=False, is_real=False, repeat_factor=1, rate=1.0)
Process DPLR SSM parameters including P and Q matrices.
- Parameters:
A_real (torch.Tensor) – Real part of A.
A_imag (torch.Tensor) – Imaginary part of A.
B (torch.Tensor) – Tensor B.
C (torch.Tensor) – Tensor C.
P (torch.Tensor) – Tensor P.
inv_dt (torch.Tensor) – Inverse dt tensor.
real_transform (str, optional) – Transform for A_real. Defaults to “exp”.
imag_transform (str, optional) – Transform for A_imag. Defaults to “none”.
dt_transform (str, optional) – Transform for dt. Defaults to “exp”.
dt_fast (bool, optional) – Whether dt is fast. Defaults to False.
is_real (bool, optional) – Whether parameters are real. Defaults to False.
repeat_factor (int, optional) – Broadcast repeat factor. Defaults to 1.
rate (float, optional) – Sampling rate. Defaults to 1.0.
- Returns:
Processed dt, A, B, C, P, Q.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- setup_default_state(C, N, H, batch_shape, step_mode='dense')
Create default SSM state.
- Parameters:
C (torch.Tensor) – Tensor C to derive dtype and device from.
N (int) – State size.
H (int) – Model dimension.
batch_shape (tuple) – Batch shape dimensions.
step_mode (str, optional) – Step mode (“dense”, “linear”, etc.). Defaults to “dense”.
- Returns:
Zero-initialized state tensor.
- Return type:
torch.Tensor