lrnnx.core.convolution module
FFT convolution with optimized einsum contractions. Ref.: https://arxiv.org/abs/2409.03377
- fft_conv(equation: str, input: torch.Tensor) → torch.Tensor
FFT-based convolution operation.
- Parameters:
equation (str) – Einsum equation for the convolution.
input (torch.Tensor) – Input tensor, shape (B, L, H) or (B, L, N).
*args – Either a single kernel, shape (L, H, H), or (K, B_norm / B_bar, C) tensors.
- Returns:
Convolved output tensor, shape (B, L, H) or (B, L, N).
- Return type:
torch.Tensor
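The idea behind an FFT convolution with an einsum contraction can be sketched in plain PyTorch. This is an illustration, not the implementation of fft_conv: the helper names, the kernel layout k[l, g, h] and the equation string "bfh,fgh->bfg" are assumptions made for this example. The input and kernel are zero-padded to length 2L, multiplied per frequency, and the first L outputs are kept, which yields a causal convolution.

```python
import torch


def fft_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal convolution of u (B, L, H) with a kernel k (L, H, H) via FFT.

    Assumed kernel layout: k[l, g, h] maps input channel h to output channel g.
    """
    L = u.shape[1]
    n = 2 * L  # zero-pad so the circular FFT convolution does not wrap around
    u_f = torch.fft.rfft(u, n=n, dim=1)  # (B, L + 1, H), complex
    k_f = torch.fft.rfft(k, n=n, dim=0)  # (L + 1, H, H), complex
    # Per-frequency channel mixing; this equation string is an illustrative choice.
    y_f = torch.einsum("bfh,fgh->bfg", u_f, k_f)
    return torch.fft.irfft(y_f, n=n, dim=1)[:, :L]  # keep the first L outputs


def direct_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """O(L^2) reference: y[:, t] = sum over s <= t of k[t - s] applied to u[:, s]."""
    y = torch.zeros_like(u)
    for t in range(u.shape[1]):
        for s in range(t + 1):
            y[:, t] += u[:, s] @ k[t - s].T
    return y


torch.manual_seed(0)
u = torch.randn(2, 16, 3, dtype=torch.float64)
k = torch.randn(16, 3, 3, dtype=torch.float64)
y_fft = fft_causal_conv(u, k)
y_ref = direct_causal_conv(u, k)
print(torch.allclose(y_fft, y_ref, atol=1e-10))
```

The direct double loop is included only as a reference to check the FFT path against; it costs O(L²) per channel pair, whereas the FFT path costs O(L log L).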
- opt_ssm_forward(x: torch.Tensor, K: torch.Tensor, B_: torch.Tensor, C: torch.Tensor) → torch.Tensor
Optimized FFT convolution.
- Parameters:
x (torch.Tensor) – Input tensor, shape (B, L, H).
K (torch.Tensor) – Kernel tensor, shape (L, H, H) or (L, N).
B_ (torch.Tensor) – Normalized input projection matrix, shape (N, H).
C (torch.Tensor) – Output projection matrix, shape (H, N).
- Returns:
Output tensor, shape (B, L, H).
- Return type:
torch.Tensor
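For the (L, N) kernel case, the shapes above suggest a factored computation: project the input to the state dimension with B_, convolve each of the N state channels with its kernel column, then project back with C. The sketch below assumes that reading and is not taken from the library source; the function names are hypothetical. It checks the factored path against a dense (L, H, H) kernel built from the same three tensors.

```python
import torch


def fft_conv_diag(z: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal per-channel FFT convolution: z (B, L, N) with k (L, N)."""
    L = z.shape[1]
    n = 2 * L  # zero-pad to avoid circular wrap-around
    z_f = torch.fft.rfft(z, n=n, dim=1)
    k_f = torch.fft.rfft(k, n=n, dim=0)
    return torch.fft.irfft(z_f * k_f, n=n, dim=1)[:, :L]


def factored_ssm_forward(x, K, B_, C):
    """Hypothetical factored path: x (B, L, H), K (L, N), B_ (N, H), C (H, N)."""
    z = torch.einsum("blh,nh->bln", x, B_)    # project input to the state dim
    z = fft_conv_diag(z, K)                   # convolve each state channel
    return torch.einsum("bln,hn->blh", z, C)  # project back to the model dim


def dense_ssm_forward(x, K, B_, C):
    """Reference path: materialize the (L, H, H) kernel, then convolve."""
    K_full = torch.einsum("gn,ln,nh->lgh", C, K, B_)
    L = x.shape[1]
    n = 2 * L
    x_f = torch.fft.rfft(x, n=n, dim=1)
    k_f = torch.fft.rfft(K_full, n=n, dim=0)
    y_f = torch.einsum("bfh,fgh->bfg", x_f, k_f)
    return torch.fft.irfft(y_f, n=n, dim=1)[:, :L]


torch.manual_seed(0)
Bsz, L, H, N = 2, 32, 4, 6
x = torch.randn(Bsz, L, H, dtype=torch.float64)
K = torch.randn(L, N, dtype=torch.float64)
B_ = torch.randn(N, H, dtype=torch.float64)
C = torch.randn(H, N, dtype=torch.float64)
y_factored = factored_ssm_forward(x, K, B_, C)
y_dense = dense_ssm_forward(x, K, B_, C)
print(torch.allclose(y_factored, y_dense, atol=1e-10))
```

The factored path never materializes the (L, H, H) kernel. Whether it is cheaper overall depends on the relative sizes of N and H.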
- class FFTConvS4
Bases: Module
Implements an FFT convolution around a convolution kernel.
- __init__(d_model, l_max=None, channels=1, swap_channels=False, transposed=True, dropout=0.0, tie_dropout=False, drop_kernel=0.0, kernel_type=None, param_config=None, kernel=None)
Initialize FFTConvS4.
- Parameters:
d_model (int) – Model dimension (in CNN terminology, “channels”).
l_max (int, optional) – Maximum kernel length. None for a global kernel. Defaults to None.
channels (int, optional) – Number of “heads”; the SSM maps 1-dim to C-dim. Defaults to 1.
swap_channels (bool, optional) – Whether to swap channel ordering. Defaults to False.
transposed (bool, optional) – Backbone axis ordering. Defaults to True.
dropout (float, optional) – Dropout probability. Defaults to 0.0.
tie_dropout (bool, optional) – Tie dropout mask across sequence length. Defaults to False.
drop_kernel (float, optional) – Kernel dropout probability. Defaults to 0.0.
kernel_type (str, optional) – Kernel algorithm ('s4' for DPLR, 's4d' for diagonal). Defaults to None.
param_config (dict, optional) – References to SSM parameters (A, B, C, dt, P, etc.). Defaults to None.
kernel (str, optional) – Alternative kernel specification. Defaults to None.
- forward(x, state=None, rate=1.0)
Forward pass through FFTConvS4.
- Parameters:
x (torch.Tensor) – Input tensor, shape (B, D, L) if self.transposed else (B, L, D).
state (torch.Tensor, optional) – Recurrent state. Defaults to None.
rate (float, optional) – Rate for kernel computation. Defaults to 1.0.
- Returns:
- A tuple containing:
y : Convolution output, shape (B, C, H, L).
next_state : State for recurrent mode, or None.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
- step(x, state)
Step one time step as a recurrent model.
Intended to be used during validation.
- Parameters:
x (torch.Tensor) – Input tensor, shape (B, H).
state (torch.Tensor) – Recurrent state, shape (B, H, N).
- Returns:
- A tuple containing:
y : Output, shape (B, C, H).
next_state : Updated state, shape (B, H, N).
- Return type:
tuple[torch.Tensor, torch.Tensor]
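For a linear time-invariant SSM with a zero initial state, stepping one input at a time produces the same outputs as convolving the whole sequence with the SSM kernel. The sketch below shows this for a diagonal, single-channel SSM, using the state shape (B, H, N) documented above. The parameter names a_bar, b_bar and c_out and the helper functions are hypothetical and do not reflect how FFTConvS4 stores its parameters.

```python
import torch


def ssm_step(x, state, a_bar, b_bar, c_out):
    """One recurrent step. x: (B, H), state: (B, H, N), parameters: (H, N)."""
    next_state = a_bar * state + b_bar * x.unsqueeze(-1)  # (B, H, N)
    y = (c_out * next_state).sum(dim=-1)                  # (B, H)
    return y, next_state


def ssm_kernel(a_bar, b_bar, c_out, L):
    """Convolution kernel k[l, h] = sum over n of c * a**l * b, shape (L, H)."""
    exponents = torch.arange(L, dtype=a_bar.dtype).view(L, 1, 1)
    powers = a_bar.unsqueeze(0) ** exponents  # (L, H, N)
    return (c_out * powers * b_bar).sum(dim=-1)


def fft_causal_conv(u, k):
    """Causal per-channel FFT convolution: u (B, L, H) with k (L, H)."""
    L = u.shape[1]
    n = 2 * L  # zero-pad to avoid circular wrap-around
    u_f = torch.fft.rfft(u, n=n, dim=1)
    k_f = torch.fft.rfft(k, n=n, dim=0)
    return torch.fft.irfft(u_f * k_f, n=n, dim=1)[:, :L]


torch.manual_seed(0)
Bsz, L, H, N = 2, 20, 3, 4
a_bar = 0.9 * torch.rand(H, N, dtype=torch.float64)  # stable: entries in [0, 0.9)
b_bar = torch.randn(H, N, dtype=torch.float64)
c_out = torch.randn(H, N, dtype=torch.float64)
u = torch.randn(Bsz, L, H, dtype=torch.float64)

# Recurrent mode: feed one time step at a time, starting from a zero state.
state = torch.zeros(Bsz, H, N, dtype=torch.float64)
outputs = []
for t in range(L):
    y_t, state = ssm_step(u[:, t], state, a_bar, b_bar, c_out)
    outputs.append(y_t)
y_recurrent = torch.stack(outputs, dim=1)  # (B, L, H)

# Convolution mode: build the kernel once and convolve the full sequence.
y_conv = fft_causal_conv(u, ssm_kernel(a_bar, b_bar, c_out, L))
print(torch.allclose(y_recurrent, y_conv, atol=1e-10))
```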
- property d_output