lrnnx.ops.s4_utils module

cauchy_naive(v, z, w)

Naive PyTorch fallback for Cauchy matrix multiplication.

Parameters:
Returns:

The sum over n of v_n / (z_l - w_n), with shape (..., L).

Return type:

torch.Tensor
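
The Cauchy sum described above can be illustrated with a plain-Python sketch. It is not the library's tensor implementation: the helper name cauchy_sum is hypothetical, and the shapes (v and w of length N, z of length L) are assumed from the documented output shape.

```python
def cauchy_sum(v, z, w):
    """Return [sum_n v[n] / (z[l] - w[n]) for each evaluation point z[l]]."""
    return [
        sum(v_n / (z_l - w_n) for v_n, w_n in zip(v, w))
        for z_l in z
    ]


# Hypothetical example data: N = 2 numerators/poles, L = 3 evaluation points.
v = [1.0 + 0.0j, 2.0 + 0.0j]
w = [-1.0 + 0.0j, -2.0 + 0.0j]
z = [0.0 + 0.0j, 1.0 + 0.0j, 1.0j]

out = cauchy_sum(v, z, w)  # one value per entry of z
```

The result has one entry per evaluation point, which matches the trailing dimension L of the documented return shape.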

log_vandermonde_naive(v, x, L, conj=True)

Naive PyTorch fallback for log Vandermonde multiplication.

Parameters:
  • v (torch.Tensor) – Input tensor of shape (..., N).

  • x (torch.Tensor) – Input tensor of shape (..., N).

  • L (int) – Sequence length.

  • conj (bool, optional) – Whether to use conjugate symmetry. Defaults to True.

Returns:

The sum over n of v_n * x_n^l, with shape (..., L).

Return type:

torch.Tensor
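
A plain-Python sketch of the log Vandermonde sum follows. Two points are assumptions borrowed from the reference S4 code rather than stated on this page: that "log" means x stores the logarithms of the Vandermonde nodes (so the l-th power is exp(x * l)), and that conj=True returns twice the real part to account for conjugate pairs. The helper name log_vandermonde_sum is hypothetical.

```python
import cmath


def log_vandermonde_sum(v, x, L, conj=True):
    """Return [sum_n v[n] * exp(x[n] * l) for l in range(L)].

    Assumption: with conj=True the result is 2 * real part, standing in
    for the contribution of the omitted conjugate half of the spectrum.
    """
    out = []
    for l in range(L):
        total = sum(v_n * cmath.exp(x_n * l) for v_n, x_n in zip(v, x))
        out.append(2 * total.real if conj else total)
    return out


# Hypothetical example: a single node at 0.5, stored as its logarithm.
v = [1.0 + 0.0j]
x = [cmath.log(0.5)]
out = log_vandermonde_sum(v, x, L=3)
```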

log_vandermonde_transpose_naive(u, v, x, L)

Naive PyTorch fallback for transposed log Vandermonde multiplication.

Parameters:
Returns:

Output tensor of shape (..., N).

Return type:

torch.Tensor

get_cauchy_kernel()

Returns the best available Cauchy multiplication function.

Returns:

The Cauchy kernel function (CUDA, KeOps, or naive fallback).

Return type:

callable

get_vandermonde_kernel()

Returns the best available Vandermonde multiplication function.

Returns:

The Vandermonde kernel function (CUDA, KeOps, or naive fallback).

Return type:

callable

get_vandermonde_transpose_kernel()

Returns the best available transpose Vandermonde multiplication function.

Returns:

The transposed Vandermonde kernel function (KeOps or naive fallback).

Return type:

callable
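
The three get_*_kernel helpers above share a "best available backend" pattern. The sketch below shows one way such a selection could work. It is not the library's code: pick_kernel, naive_cauchy, and the backend module name are all hypothetical.

```python
import importlib.util


def naive_cauchy(v, z, w):
    """Pure-Python fallback: sum_n v[n] / (z[l] - w[n]) for each l."""
    return [sum(a / (z_l - b) for a, b in zip(v, w)) for z_l in z]


def pick_kernel(candidates, fallback):
    """Return the first kernel whose backend module is importable.

    candidates: list of (top_level_module_name, loader) pairs in order of
    preference; loader is only called when the module can be found.
    """
    for module_name, loader in candidates:
        if importlib.util.find_spec(module_name) is not None:
            return loader()
    return fallback


# "hypothetical_fast_backend_xyz" is a placeholder name that is not installed,
# so the selection falls through to the naive implementation.
kernel = pick_kernel(
    [("hypothetical_fast_backend_xyz", lambda: None)],
    naive_cauchy,
)
```

Deferring the import to the loader keeps the optional dependency from being required at module import time.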

LinearActivation(d_input, d_output, bias=True, transposed=False, activate=False)

Returns a linear nn.Module with control over axes order, initialization, and activation.

Parameters:
  • d_input (int) – Input dimension.

  • d_output (int) – Output dimension.

  • bias (bool, optional) – Whether to use bias. Defaults to True.

  • transposed (bool, optional) – If True, uses Conv1d instead of Linear. Defaults to False.

  • activate (bool, optional) – Whether to append a GELU activation. Defaults to False.

Returns:

The configured linear/activation module.

Return type:

torch.nn.Module

class DropoutNd

Bases: Module

N-dimensional dropout module.

Parameters:
  • p (float, optional) – Dropout probability. Defaults to 0.5.

  • tie (bool, optional) – Whether to tie the dropout mask across sequence positions, as in Dropout1d/2d/3d. Defaults to True.

  • transposed (bool, optional) – Whether the sequence dimension is transposed. Defaults to True.

forward(X)

Forward pass for DropoutNd.

Parameters:

X (torch.Tensor) – Input tensor of shape (batch, dim, lengths...).

Returns:

Tensor with dropout applied.

Return type:

torch.Tensor
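
To illustrate what a tied mask means, the sketch below applies dropout to a nested list of shape (batch, dim, length) and draws one keep/drop decision per (batch, dim) pair, shared across all sequence positions. This is a plain-Python illustration under those assumptions, not the module's tensor code. The function name tied_dropout and the seed argument are hypothetical.

```python
import random


def tied_dropout(X, p=0.5, seed=0):
    """Drop whole channels: one Bernoulli draw per (batch, dim) entry."""
    rng = random.Random(seed)
    scale = 1.0 / (1.0 - p)  # rescale kept values so the expectation is unchanged
    out = []
    for sample in X:
        rows = []
        for channel in sample:
            keep = rng.random() < 1.0 - p
            rows.append([x * scale if keep else 0.0 for x in channel])
        out.append(rows)
    return out


# Hypothetical input: batch = 2, dim = 3, length = 4, all ones.
X = [[[1.0] * 4 for _ in range(3)] for _ in range(2)]
Y = tied_dropout(X, p=0.5)
```

Every channel in Y is either entirely zero or entirely rescaled, which is the behavior the tie flag refers to.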

power(L, A, v=None)

Compute A^L and the scan sum_i A^i v_i.

Parameters:
  • L (int) – Power to raise A to.

  • A (torch.Tensor) – Input matrix of shape (..., N, N).

  • v (torch.Tensor, optional) – Vector for scan sum, shape (..., N, L). Defaults to None.

Returns:

A^L, or a tuple of (A^L, scan sum).

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
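
The two quantities named above can be sketched in plain Python. Repeated squaring is used here for A^L as one common approach, and the scan sum follows the documented formula sum_i A^i v_i via Horner's rule. Whether the library uses the same algorithm is an assumption, and all helper names are hypothetical.

```python
def matmul(X, Y):
    """Multiply two square matrices given as nested lists."""
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def matvec(X, u):
    """Multiply a matrix by a vector."""
    return [sum(a * b for a, b in zip(row, u)) for row in X]


def matrix_power(A, L):
    """Compute A^L with O(log L) multiplications (repeated squaring)."""
    n = len(A)
    result = [[float(i == j) for j in range(n)] for i in range(n)]
    base = A
    while L > 0:
        if L & 1:
            result = matmul(result, base)
        base = matmul(base, base)
        L >>= 1
    return result


def scan_sum(A, v_cols):
    """Compute sum_i A^i v_i, where v_cols[i] is the i-th column of v."""
    s = [0.0] * len(A)
    for v_i in reversed(v_cols):  # Horner: s <- A s + v_i
        s = [a + b for a, b in zip(matvec(A, s), v_i)]
    return s


A = [[1, 1], [0, 1]]
A5 = matrix_power(A, 5)
s = scan_sum(A, [[1, 0], [0, 1]])  # v_0 + A v_1
```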

transition(measure, N)

Constructs A, B transition matrices for different measures.

Parameters:
  • measure (str) – Measure type (e.g., “legt”, “legs”, “fourier”).

  • N (int) – State dimension.

Returns:

Transition matrices A and B.

Return type:

tuple[np.ndarray, np.ndarray]
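
As an illustration of one measure, the sketch below builds the HiPPO-LegS matrices in plain Python using the closed form from the HiPPO/S4 literature. That the library's "legs" branch matches this form entry for entry is an assumption, and legs_transition is a hypothetical name.

```python
import math


def legs_transition(N):
    """HiPPO-LegS (assumed form):

    A[n][k] = -sqrt((2n+1)(2k+1))  if n > k
              -(n + 1)             if n == k
              0                    if n < k
    B[n]    = sqrt(2n + 1)
    """
    A = [[0.0] * N for _ in range(N)]
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n][k] = -math.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n][k] = -(n + 1.0)
    B = [math.sqrt(2 * n + 1) for n in range(N)]
    return A, B


A, B = legs_transition(4)
```

The resulting A is lower triangular with a negative diagonal.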

rank_correction(measure, N, rank=1, dtype=torch.float)

Returns a low-rank matrix P such that A + PP^T is normal.

Parameters:
  • measure (str) – Measure type.

  • N (int) – State dimension.

  • rank (int, optional) – Rank of the correction. Defaults to 1.

  • dtype (torch.dtype, optional) – Data type. Defaults to torch.float.

Returns:

Rank correction matrix P.

Return type:

torch.Tensor
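
The normality property can be checked numerically for the LegS measure. The sketch below assumes the LegS form of A from the HiPPO/S4 literature and the rank-1 correction P_n = sqrt(n + 1/2); neither formula is stated on this page. Under those assumptions, A + PP^T equals -1/2 I plus a skew-symmetric matrix, and any such matrix commutes with its transpose and is therefore normal.

```python
import math

N = 4

# LegS transition matrix (assumed closed form).
A = [
    [
        -math.sqrt((2 * n + 1) * (2 * k + 1)) if n > k
        else (-(n + 1.0) if n == k else 0.0)
        for k in range(N)
    ]
    for n in range(N)
]

# Assumed rank-1 correction vector for LegS.
P = [math.sqrt(n + 0.5) for n in range(N)]

# M = A + P P^T
M = [[A[n][k] + P[n] * P[k] for k in range(N)] for n in range(N)]

# If M = -1/2 I + S with S skew-symmetric, then M + M^T = -I.
sym = [[M[n][k] + M[k][n] for k in range(N)] for n in range(N)]
```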

nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True, B_clip=2.0)

Constructs the NPLR (normal plus low-rank) form of the HiPPO matrices.

Returns W, P, B, V such that (W - P P^*, B) is unitarily equivalent to the original HiPPO (A, B) under the matrix V, i.e. A = V[W - P P^*]V^* and the original B equals V B.

Parameters:
  • measure (str) – Measure type.

  • N (int) – State dimension.

  • rank (int, optional) – Rank of the correction. Defaults to 1.

  • dtype (torch.dtype, optional) – Target data type. Defaults to torch.float.

  • diagonalize_precision (bool, optional) – Whether to diagonalize in double precision. Defaults to True.

  • B_clip (float, optional) – Clipping value for B. Defaults to 2.0.

Returns:

The W, P, B, and V matrices.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

dplr(init='hippo', N=64, rank=1, H=1, dtype=torch.float, real_random=False, real_scale=1.0, imag_random=False, imag_scale=1.0, B_random=False, B_init='constant', B_scale=1.0, P_scale=1.0, normalize=False)

Directly constructs a DPLR (diagonal plus low-rank) matrix.

Parameters:
  • init (str, optional) – Initialization method. Defaults to “hippo”.

  • N (int, optional) – State size. Defaults to 64.

  • rank (int, optional) – Rank for DPLR parameterization. Defaults to 1.

  • H (int, optional) – Number of independent SSM copies. Defaults to 1.

  • dtype (torch.dtype, optional) – Data type. Defaults to torch.float.

  • real_random (bool, optional) – Whether to randomize real part. Defaults to False.

  • real_scale (float, optional) – Scaling factor for real part. Defaults to 1.0.

  • imag_random (bool, optional) – Whether to randomize imaginary part. Defaults to False.

  • imag_scale (float, optional) – Scaling factor for imaginary part. Defaults to 1.0.

  • B_random (bool, optional) – Deprecated. Defaults to False.

  • B_init (str, optional) – Initialization method for B. Defaults to “constant”.

  • B_scale (float, optional) – Scaling factor for B. Defaults to 1.0.

  • P_scale (float, optional) – Scaling factor for P. Defaults to 1.0.

  • normalize (bool, optional) – Whether to normalize B. Defaults to False.

Returns:

Matrices A, P, B, V.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

ssm(init, N, R, H)

Dispatcher that creates a single SSM initialization.

Parameters:
  • init (str) – Initialization method.

  • N (int) – State size.

  • R (int) – Rank (for DPLR parameterization).

  • H (int) – Number of independent SSM copies.

Returns:

Matrices A, P, B, V.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

combination(inits, N, R, S)

Creates a combination of SSM initializations.

Parameters:
  • inits (str | list[str]) – Initialization methods.

  • N (int) – State size.

  • R (int) – Rank.

  • S (int) – Number of SSM copies.

Returns:

Combined matrices A, P, B, V.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

inv_transform(param, transform='none')

Initialize a (positive) parameter under a transform.

Parameters:
  • param (torch.Tensor) – Parameter tensor.

  • transform (str, optional) – Transform type (“none”, “exp”, “relu”, “sigmoid”, “softplus”). Defaults to “none”.

Returns:

Transformed parameter.

Return type:

torch.Tensor

param_transform(param, transform='none')

Get a (positive) parameter under a transform.

Parameters:
  • param (torch.Tensor) – Parameter tensor.

  • transform (str, optional) – Transform type. Defaults to “none”.

Returns:

Transformed parameter.

Return type:

torch.Tensor
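
The relationship between inv_transform and param_transform can be sketched on plain floats: the stored value is the inverse transform of the desired parameter, and applying the forward transform recovers it. The formulas below are the standard inverses for each named transform. That the library uses exactly these is an assumption, and the *_scalar helper names are hypothetical.

```python
import math


def param_transform_scalar(x, transform="none"):
    """Map a stored (unconstrained) value to the parameter actually used."""
    if transform == "none":
        return x
    if transform == "exp":
        return math.exp(x)
    if transform == "relu":
        return max(x, 0.0)
    if transform == "sigmoid":
        return 1.0 / (1.0 + math.exp(-x))
    if transform == "softplus":
        return math.log1p(math.exp(x))
    raise ValueError(f"unknown transform: {transform}")


def inv_transform_scalar(y, transform="none"):
    """Map a desired parameter value to the value that should be stored."""
    if transform in ("none", "relu"):
        return y  # relu has no true inverse; a positive y is stored as-is
    if transform == "exp":
        return math.log(y)
    if transform == "sigmoid":
        return math.log(y / (1.0 - y))  # logit
    if transform == "softplus":
        return math.log(math.expm1(y))
    raise ValueError(f"unknown transform: {transform}")


# Round trip: store 0.3 under each transform, then recover it.
recovered = {
    t: param_transform_scalar(inv_transform_scalar(0.3, t), t)
    for t in ("none", "exp", "relu", "sigmoid", "softplus")
}
```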

init_dt(H, N, dt_min=0.001, dt_max=0.1, dt_tie=True, dt_transform='exp', deterministic=False, dtype=torch.float)

Initialize dt parameter.

Parameters:
  • H (int) – Model dimension (number of independent SSMs).

  • N (int) – State size.

  • dt_min (float, optional) – Minimum dt value. Defaults to 0.001.

  • dt_max (float, optional) – Maximum dt value. Defaults to 0.1.

  • dt_tie (bool, optional) – Whether to tie dt across dimensions. Defaults to True.

  • dt_transform (str, optional) – Transform type for dt. Defaults to “exp”.

  • deterministic (bool, optional) – Whether to use deterministic initialization. Defaults to False.

  • dtype (torch.dtype, optional) – Data type. Defaults to torch.float.

Returns:

Initialized (inverse transformed) dt parameter.

Return type:

torch.Tensor
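
A sketch of how a dt initialization of this kind could look in plain Python. It assumes dt is drawn log-uniformly from [dt_min, dt_max] (as in the reference S4 code), that the deterministic variant spaces values evenly in log space, and that with the "exp" transform the stored parameter is log(dt). None of these details are confirmed on this page, and sample_log_dt and its seed argument are hypothetical.

```python
import math
import random


def sample_log_dt(H, dt_min=0.001, dt_max=0.1, deterministic=False, seed=0):
    """Return H values of log(dt), i.e. dt under the inverse 'exp' transform."""
    lo, hi = math.log(dt_min), math.log(dt_max)
    if deterministic:
        # Evenly spaced in log space (assumed behavior).
        return [lo + (hi - lo) * i / max(H - 1, 1) for i in range(H)]
    rng = random.Random(seed)
    # Uniform in log space means log-uniform in dt.
    return [lo + (hi - lo) * rng.random() for _ in range(H)]


det = sample_log_dt(3, deterministic=True)
rnd = sample_log_dt(8)
```

Sampling in log space spreads the initial step sizes evenly across orders of magnitude rather than concentrating them near dt_max.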

init_ssm_dplr(N, H, n_ssm, channels, rank, init, deterministic=False, cdtype=torch.cfloat)

Initialize the DPLR (A, P, B, C) parameters.

Parameters:
  • N (int) – State size.

  • H (int) – Model dimension.

  • n_ssm (int) – Number of independent SSM copies.

  • channels (int) – Number of channels.

  • rank (int) – Rank.

  • init (str) – Initialization method.

  • deterministic (bool, optional) – Whether to use deterministic initialization. Defaults to False.

  • cdtype (torch.dtype, optional) – Complex data type. Defaults to torch.cfloat.

Returns:

Tensors A, P, B, C.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

register_ssm_params(module, A, B, C, inv_dt, P, H, n_ssm, N, channels, rank, dt_fast, real_transform, imag_transform, is_real, verbose, l_max, diag=False)

Register SSM parameters on a module.

Parameters:
  • module (torch.nn.Module) – The module to register parameters on.

  • A (torch.Tensor) – Tensor A.

  • B (torch.Tensor) – Tensor B.

  • C (torch.Tensor) – Tensor C.

  • inv_dt (torch.Tensor) – Inverse dt tensor.

  • P (torch.Tensor) – Tensor P.

  • H (int) – Model dimension.

  • n_ssm (int) – Number of independent SSMs.

  • N (int) – State size.

  • channels (int) – Number of channels.

  • rank (int) – Rank.

  • dt_fast (bool) – Whether dt is fast.

  • real_transform (str) – Transform for the real part.

  • imag_transform (str) – Transform for the imaginary part.

  • is_real (bool) – Whether parameters are real.

  • verbose (bool) – Whether to print construction details.

  • l_max (int) – Maximum sequence length.

  • diag (bool, optional) – If True, skip P registration and rank checks (for diagonal S4D). Defaults to False.

Returns:

The repeat broadcast factor.

Return type:

int
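
The repeat broadcast factor can be illustrated with a small sketch. It assumes, following the reference S4 code, that the factor is H // n_ssm and that H must be divisible by n_ssm, so each of the n_ssm parameter sets is shared by that many features. The function name repeat_factor is hypothetical.

```python
def repeat_factor(H, n_ssm):
    """Number of model features that share each independent SSM (assumed H // n_ssm)."""
    if H % n_ssm != 0:
        raise ValueError("H must be divisible by n_ssm")
    return H // n_ssm


factor = repeat_factor(256, 4)  # 4 independent SSMs, each broadcast to 64 features
```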

process_ssm_params(A_real, A_imag, B, C, inv_dt, real_transform='exp', imag_transform='none', dt_transform='exp', dt_fast=False, is_real=False, repeat_factor=1, rate=1.0)

Process SSM parameters from stored form to usable form.

Parameters:
  • A_real (torch.Tensor) – Real part of A.

  • A_imag (torch.Tensor) – Imaginary part of A.

  • B (torch.Tensor) – Tensor B.

  • C (torch.Tensor) – Tensor C.

  • inv_dt (torch.Tensor) – Inverse dt tensor.

  • real_transform (str, optional) – Transform for A_real. Defaults to “exp”.

  • imag_transform (str, optional) – Transform for A_imag. Defaults to “none”.

  • dt_transform (str, optional) – Transform for dt. Defaults to “exp”.

  • dt_fast (bool, optional) – Whether dt is fast. Defaults to False.

  • is_real (bool, optional) – Whether parameters are real. Defaults to False.

  • repeat_factor (int, optional) – Broadcast repeat factor. Defaults to 1.

  • rate (float, optional) – Sampling rate. Defaults to 1.0.

Returns:

Processed dt, A, B, C, dtA.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
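
A scalar sketch of what moving from stored form to usable form could involve with the default transforms. The sign convention A = -exp(A_real) - i * A_imag and the product dtA = dt * A follow the reference S4 code and are assumptions here. B, C, dt_fast, is_real, and repeat_factor are omitted, and process_scalar is a hypothetical name.

```python
import math


def process_scalar(A_real, A_imag, inv_dt, rate=1.0):
    """Return (dt, A, dtA) for one state coordinate.

    Assumes real_transform='exp', imag_transform='none', dt_transform='exp'.
    """
    dt = math.exp(inv_dt) * rate               # stored log(dt) -> dt
    A = complex(-math.exp(A_real), -A_imag)    # negative real part keeps the system stable
    return dt, A, dt * A


dt, A, dtA = process_scalar(A_real=0.0, A_imag=2.0, inv_dt=math.log(0.01))
```

Storing A_real under an exponential transform guarantees that the real part of A stays strictly negative throughout training.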

process_dplr_params(A_real, A_imag, B, C, P, inv_dt, real_transform='exp', imag_transform='none', dt_transform='exp', dt_fast=False, is_real=False, repeat_factor=1, rate=1.0)

Process DPLR SSM parameters including P and Q matrices.

Parameters:
  • A_real (torch.Tensor) – Real part of A.

  • A_imag (torch.Tensor) – Imaginary part of A.

  • B (torch.Tensor) – Tensor B.

  • C (torch.Tensor) – Tensor C.

  • P (torch.Tensor) – Tensor P.

  • inv_dt (torch.Tensor) – Inverse dt tensor.

  • real_transform (str, optional) – Transform for A_real. Defaults to “exp”.

  • imag_transform (str, optional) – Transform for A_imag. Defaults to “none”.

  • dt_transform (str, optional) – Transform for dt. Defaults to “exp”.

  • dt_fast (bool, optional) – Whether dt is fast. Defaults to False.

  • is_real (bool, optional) – Whether parameters are real. Defaults to False.

  • repeat_factor (int, optional) – Broadcast repeat factor. Defaults to 1.

  • rate (float, optional) – Sampling rate. Defaults to 1.0.

Returns:

Processed dt, A, B, C, P, Q.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

setup_default_state(C, N, H, batch_shape, step_mode='dense')

Create default SSM state.

Parameters:
  • C (torch.Tensor) – Tensor C to derive dtype and device from.

  • N (int) – State size.

  • H (int) – Model dimension.

  • batch_shape (tuple) – Batch shape dimensions.

  • step_mode (str, optional) – Step mode (“dense”, “linear”, etc.). Defaults to “dense”.

Returns:

Zero-initialized state tensor.

Return type:

torch.Tensor