lrnnx.core.convolution module

FFT convolution with optimized einsum contractions. Ref.: https://arxiv.org/abs/2409.03377

fft_conv(equation: str, input: torch.Tensor, *args) → torch.Tensor

FFT-based convolution operation, with the kernel contraction specified by an einsum equation.

Parameters:
  • equation (str) – Einsum equation for the convolution.

  • input (torch.Tensor) – Input tensor, shape (B, L, H) or (B, L, N).

  • *args – Either a single kernel tensor of shape (L, H, H), or the three tensors (K, B_norm / B_bar, C).

Returns:

Convolved output tensor, shape (B, L, H) or (B, L, N).

Return type:

torch.Tensor
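The FFT route rests on a standard identity: zero-padding both signals to length 2L turns the FFT's circular convolution into a linear (causal) one, after which a dense (L, H, H) kernel becomes an independent H × H matrix-vector product at every frequency, which is exactly what an einsum contraction can express. The sketch below assumes a single dense kernel and a causal convolution over the L axis; `fft_conv_sketch` and `direct_conv` are hypothetical helpers written for illustration, not part of lrnnx, and the einsum equation is hard-coded rather than passed in.

```python
import torch


def fft_conv_sketch(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: causal y[b, t, i] = sum_{s<=t} sum_j k[t-s, i, j] x[b, s, j].

    x: (B, L, H) input; k: (L, H, H) dense kernel.
    """
    L = x.shape[1]
    n = 2 * L  # pad to 2L so the circular FFT convolution equals the linear one
    x_f = torch.fft.rfft(x, n=n, dim=1)  # (B, n//2 + 1, H)
    k_f = torch.fft.rfft(k, n=n, dim=0)  # (n//2 + 1, H, H)
    # One H x H matrix-vector product per frequency, written as an einsum
    y_f = torch.einsum("fij,bfj->bfi", k_f, x_f)
    return torch.fft.irfft(y_f, n=n, dim=1)[:, :L]  # keep the first L (causal) outputs


def direct_conv(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """O(L^2) reference for the same causal convolution."""
    L = x.shape[1]
    y = torch.zeros_like(x)
    for t in range(L):
        for s in range(t + 1):
            y[:, t] += x[:, s] @ k[t - s].T
    return y


torch.manual_seed(0)
x = torch.randn(2, 16, 4, dtype=torch.float64)
k = torch.randn(16, 4, 4, dtype=torch.float64)
y_fft = fft_conv_sketch(x, k)
y_ref = direct_conv(x, k)
print(torch.allclose(y_fft, y_ref))  # expected: True
```

The FFT path scales as O(L log L) in the sequence length, against O(L²) for the double loop, which is the usual motivation for moving the contraction into the frequency domain.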

opt_ssm_forward(x: torch.Tensor, K: torch.Tensor, B_: torch.Tensor, C: torch.Tensor) → torch.Tensor

Optimized FFT convolution.

Parameters:
  • x (torch.Tensor) – Input tensor, shape (B, L, H).

  • K (torch.Tensor) – Kernel tensor, shape (L, H, H) or (L, N).

  • B_ (torch.Tensor) – Normalized input projection matrix, shape (N, H).

  • C (torch.Tensor) – Output projection matrix, shape (H, N).

Returns:

Output tensor, shape (B, L, H).

Return type:

torch.Tensor
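One plausible reading of the (K, B_, C) signature, assumed here rather than taken from the lrnnx source: B_ projects the H input channels into N state channels, each state channel is convolved with its own length-L scalar kernel K[:, n], and C projects back to H channels. That is algebraically the same as convolving with the dense kernel K_full[l] = C · diag(K[l]) · B_, but it needs only N single-channel FFT convolutions instead of an (L, H, H) kernel. Both helpers below are hypothetical and only check that equivalence.

```python
import torch


def ssm_conv_factored(x, K, B_, C):
    """Hypothetical: project with B_, convolve each state channel, project with C.

    x: (B, L, H); K: (L, N); B_: (N, H); C: (H, N). Returns (B, L, H).
    """
    L = x.shape[1]
    n = 2 * L
    u = torch.einsum("nh,blh->bln", B_, x)  # (B, L, N) inputs to the N state channels
    u_f = torch.fft.rfft(u, n=n, dim=1)     # (B, n//2 + 1, N)
    K_f = torch.fft.rfft(K, n=n, dim=0)     # (n//2 + 1, N), broadcast over the batch
    v = torch.fft.irfft(u_f * K_f, n=n, dim=1)[:, :L]  # depthwise causal conv, (B, L, N)
    return torch.einsum("hn,bln->blh", C, v)


def ssm_conv_dense(x, K, B_, C):
    """Hypothetical reference: materialise K_full[l] = C diag(K[l]) B_ and convolve."""
    L = x.shape[1]
    n = 2 * L
    K_full = torch.einsum("hn,ln,nj->lhj", C, K, B_)  # (L, H, H)
    y_f = torch.einsum(
        "fhj,bfj->bfh",
        torch.fft.rfft(K_full, n=n, dim=0),
        torch.fft.rfft(x, n=n, dim=1),
    )
    return torch.fft.irfft(y_f, n=n, dim=1)[:, :L]


torch.manual_seed(0)
Bsz, L, H, N = 2, 32, 8, 4
x = torch.randn(Bsz, L, H, dtype=torch.float64)
K = torch.randn(L, N, dtype=torch.float64)
B_ = torch.randn(N, H, dtype=torch.float64)
C = torch.randn(H, N, dtype=torch.float64)
y_fact = ssm_conv_factored(x, K, B_, C)
y_dense = ssm_conv_dense(x, K, B_, C)
```

The factored path never materialises H × H kernel spectra; the contraction order lrnnx actually uses may differ from this sketch.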

class FFTConvS4

Bases: Module

Implements an FFT convolution around a convolution kernel.

__init__(d_model, l_max=None, channels=1, swap_channels=False, transposed=True, dropout=0.0, tie_dropout=False, drop_kernel=0.0, kernel_type=None, param_config=None, kernel=None)

Initialize FFTConvS4.

Parameters:
  • d_model (int) – Model dimension (in CNN terminology, “channels”).

  • l_max (int, optional) – Maximum kernel length. None for a global kernel. Defaults to None.

  • channels (int, optional) – Number of “heads”; the SSM maps each 1-dimensional input channel to C dimensions. Defaults to 1.

  • swap_channels (bool, optional) – Whether to swap channel ordering. Defaults to False.

  • transposed (bool, optional) – Backbone axis ordering: if True, inputs have shape (B, D, L); otherwise (B, L, D). Defaults to True.

  • dropout (float, optional) – Dropout probability. Defaults to 0.0.

  • tie_dropout (bool, optional) – Whether to tie the dropout mask across the sequence length. Defaults to False.

  • drop_kernel (float, optional) – Kernel dropout probability. Defaults to 0.0.

  • kernel_type (str, optional) – Kernel algorithm ('s4' for DPLR, 's4d' for diagonal). Defaults to None.

  • param_config (dict, optional) – References to SSM parameters (A, B, C, dt, P, etc.). Defaults to None.

  • kernel (str, optional) – Alternative kernel specification. Defaults to None.
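With tie_dropout=True there is one dropout decision per (batch, channel) pair, reused at every position along the sequence, rather than an independent decision per element. A minimal sketch of that behavior, assuming the transposed (B, D, L) layout and inverted-dropout scaling; `tied_dropout` is a hypothetical helper, not the module's own dropout layer.

```python
import torch


def tied_dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Hypothetical: dropout whose mask is shared along the last (sequence) axis.

    x: (B, D, L), i.e. the transposed=True layout.
    """
    if not training or p == 0.0:
        return x
    # One Bernoulli draw per (batch, channel), broadcast over all L positions
    mask = (torch.rand(tuple(x.shape[:-1]) + (1,), device=x.device) >= p).to(x.dtype)
    return x * mask / (1.0 - p)  # inverted-dropout scaling keeps the expectation unchanged


torch.manual_seed(0)
x = torch.ones(4, 8, 32)
y = tied_dropout(x, p=0.5)
```

Every row of y along the last axis is either all zeros or all 2.0, so whole channels are silenced for a sequence instead of isolated positions.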

forward(x, state=None, rate=1.0)

Forward pass through FFTConvS4.

Parameters:
  • x (torch.Tensor) – Input tensor, shape (B, D, L) if self.transposed else (B, L, D).

  • state (torch.Tensor, optional) – Recurrent state. Defaults to None.

  • rate (float, optional) – Rate for kernel computation. Defaults to 1.0.

Returns:

A tuple containing:
  • y : Convolution output, shape (B, C, H, L).

  • next_state : State for recurrent mode, or None.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
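The documented output shape (B, C, H, L) corresponds to C independent kernels per model channel. The sketch below assumes the kernel has already been materialised with shape (C, H, L_k), an assumption modeled on the public S4 code, and shows only the depthwise causal FFT convolution plus the transposed/non-transposed input handling; state, rate, and dropout are left out, and `fftconv_forward_sketch` is hypothetical.

```python
import torch


def fftconv_forward_sketch(x, k, transposed=True):
    """Hypothetical: depthwise causal FFT convolution with C kernels per channel.

    x: (B, H, L) if transposed else (B, L, H); k: (C, H, L_k) with L_k <= L.
    Returns y: (B, C, H, L).
    """
    if not transposed:
        x = x.transpose(-1, -2)  # move to (B, H, L)
    L = x.shape[-1]
    n = L + k.shape[-1]  # linear-convolution length is L + L_k - 1, so no wrap-around
    x_f = torch.fft.rfft(x, n=n)  # (B, H, n//2 + 1)
    k_f = torch.fft.rfft(k, n=n)  # (C, H, n//2 + 1)
    y_f = torch.einsum("bhf,chf->bchf", x_f, k_f)  # broadcast the C kernels over the batch
    return torch.fft.irfft(y_f, n=n)[..., :L]


torch.manual_seed(0)
Bsz, C, H, L = 2, 3, 4, 16
x = torch.randn(Bsz, H, L, dtype=torch.float64)
k = torch.randn(C, H, 8, dtype=torch.float64)  # kernel shorter than the sequence
y = fftconv_forward_sketch(x, k)

# Direct reference: y[b, c, h, t] = sum_s k[c, h, s] * x[b, h, t - s]
y_ref = torch.zeros(Bsz, C, H, L, dtype=torch.float64)
for t in range(L):
    for s in range(min(t + 1, k.shape[-1])):
        y_ref[..., t] += k[..., s] * x[:, None, :, t - s]
```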

setup_step()

step(x, state)

Advance the model by one time step, running it as a recurrent model.

Intended to be used during validation.

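Parameters:
  • x (torch.Tensor) – Input for a single time step.

  • state (torch.Tensor) – Current recurrent state, shape (B, H, N).

Returns: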

A tuple containing:
  • y : Output, shape (B, C, H).

  • next_state : Updated state, shape (B, H, N).

Return type:

tuple[torch.Tensor, torch.Tensor]
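In recurrent mode the convolution kernel is replaced by a linear recurrence, so step can process one position at a time while carrying a state of shape (B, H, N). The sketch below assumes an already-discretized, real-valued diagonal SSM (S4D-style; the DPLR 's4' kernel and complex parameterisations are not covered) and checks that unrolling the recurrence reproduces convolution with the kernel k[c, h, l] = sum_n C[c, h, n] * dA[h, n]**l * dB[h, n]. The names `ssm_step_sketch`, `ssm_kernel`, dA and dB are hypothetical.

```python
import torch


def ssm_step_sketch(x, state, dA, dB, C):
    """Hypothetical single recurrent step of a discretized diagonal SSM.

    x: (B, H) input at one position; state: (B, H, N);
    dA, dB: (H, N) discretized diagonal parameters; C: (C, H, N) readout.
    Returns y: (B, C, H) and next_state: (B, H, N).
    """
    next_state = dA * state + dB * x.unsqueeze(-1)  # elementwise (diagonal) state update
    y = torch.einsum("chn,bhn->bch", C, next_state)  # one readout per output channel
    return y, next_state


def ssm_kernel(dA, dB, C, L):
    """Convolution kernel (C, H, L) equivalent to running the recurrence from a zero state."""
    powers = dA.unsqueeze(-1) ** torch.arange(L, dtype=dA.dtype)  # (H, N, L)
    return torch.einsum("chn,hnl->chl", C, powers * dB.unsqueeze(-1))


torch.manual_seed(0)
Bsz, Cn, H, N, L = 2, 2, 3, 4, 10
dA = 0.9 * torch.rand(H, N, dtype=torch.float64)  # |dA| < 1 keeps the recurrence stable
dB = torch.randn(H, N, dtype=torch.float64)
C_mat = torch.randn(Cn, H, N, dtype=torch.float64)
u = torch.randn(Bsz, H, L, dtype=torch.float64)

# Recurrent mode: one step per position, starting from a zero state
state = torch.zeros(Bsz, H, N, dtype=torch.float64)
outputs = []
for t in range(L):
    y_t, state = ssm_step_sketch(u[..., t], state, dA, dB, C_mat)
    outputs.append(y_t)
y_rec = torch.stack(outputs, dim=-1)  # (B, C, H, L)

# Convolutional mode with the equivalent kernel
k = ssm_kernel(dA, dB, C_mat, L)
y_conv = torch.zeros(Bsz, Cn, H, L, dtype=torch.float64)
for t in range(L):
    for s in range(t + 1):
        y_conv[..., t] += k[..., s] * u[:, None, :, t - s]
```

The recurrence costs O(N) per channel per step regardless of how many positions came before, which is why a step-by-step mode suits autoregressive validation.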

default_state(*batch_shape, device=None)
property d_output