# Source code for lrnnx.core.convolution

"""
FFT convolution with optimized einsum contractions.
Ref.: https://arxiv.org/abs/2409.03377
"""

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F

from lrnnx.ops.s4_kernel_interface import kernel_registry
from lrnnx.ops.s4_utils import DropoutNd

# Module-level alias for ``torch.einsum``.
contract = torch.einsum


@torch.compiler.disable
def fft_conv(equation: str, input: Tensor, *args) -> Tensor:
    """
    FFT based convolution operation.

    Args:
        equation (str): Einsum equation for the convolution.
        input (torch.Tensor): Input tensor, shape ``(B, L, H)`` or ``(B, L, N)``.
        *args: Either single kernel ``(L, H, H)`` or ``(K, B_norm / B_bar, C)`` tensors.

    Returns:
        torch.Tensor: Convolved output tensor, shape ``(B, L, H)`` or ``(B, L, N)``.
    """
    L = input.shape[1]

    input_f = torch.fft.fft(input, 2 * L, dim=1)  # (B, 2L, H)
    args = tuple(arg.cfloat() for arg in args)

    if len(args) == 1:
        kernel = args[0]
        kernel_f = torch.fft.fft(kernel, 2 * L, dim=0)  # (2L, H, H)
        output_f = torch.einsum(equation, input_f, kernel_f)  # (B, 2L, H)

    else:
        K, B_norm, C = args
        K_f = torch.fft.fft(K, 2 * L, dim=1)  # (N, 2L)
        output_f = torch.einsum(
            equation, input_f, K_f, B_norm, C
        )  # (B, 2L, H)

    output = torch.fft.ifft(output_f, dim=1)  # (B, 2L, H)
    return output[:, :L, :]  # (B, L, H)


def opt_ssm_forward(x: Tensor, K: Tensor, B_: Tensor, C: Tensor) -> Tensor:
    """
    Optimized FFT convolution.

    Picks one of three einsum contraction orders based on the batch size,
    the input/output widths and the state size, then runs the convolution
    through ``fft_conv``. Only the real part of the result is returned.

    Args:
        x (torch.Tensor): Input tensor, shape ``(B, L, H_in)``.
        K (torch.Tensor): Basis kernels, shape ``(N, L)`` (contracted with
            einsum labels ``nl``).
        B_ (torch.Tensor): Normalized input projection matrix, shape ``(N, H_in)``.
        C (torch.Tensor): Output projection matrix, shape ``(H_out, N)``.

    Returns:
        torch.Tensor: Real output tensor, shape ``(B, L, H_out)``.
    """
    B, _, H_in = x.shape
    H_out, N = C.shape

    if (1 / H_in + 1 / H_out) > (1 / B + 1 / N):
        if H_in * H_out <= N:
            # strategy 1: materialize the full (L, H_out, H_in) kernel first,
            # then run a single-kernel convolution.
            kernel = torch.einsum("on,nl,ni->loi", C, K, B_).real  # (L, H, H)
            return fft_conv("bli,loi->blo", x, kernel).real  # (B, L, H)
    elif N <= H_in:
        # strategy 2: project the input into state space, convolve each state
        # channel independently, then project back out.
        x_proj = torch.einsum(
            "blh,nh->bln", x.to(B_.dtype), B_
        )  # (B, L, N)
        x_conv = fft_conv("bln,ln->bln", x_proj, K.T)  # (B, L, N)
        return torch.einsum("bln,hn->blh", x_conv, C).real  # (B, L, H)

    # fallback: one fused contraction over all operands in frequency space.
    return fft_conv("blh,nl,nh,on->blo", x, K, B_, C).real  # (B, L, H)
class FFTConvS4(nn.Module):
    """Implements an FFT Convolution around a convolution kernel."""

    def __init__(
        self,
        d_model,
        l_max=None,
        channels=1,
        swap_channels=False,
        transposed=True,
        dropout=0.0,
        tie_dropout=False,
        drop_kernel=0.0,
        kernel_type=None,
        param_config=None,
        kernel=None,
        **kernel_args,
    ):
        """
        Initialize FFTConvS4.

        Args:
            d_model (int): Model dimension (in CNN terminology, "channels").
            l_max (int, optional): Maximum kernel length. ``None`` for a global kernel. Defaults to None.
            channels (int, optional): Number of "heads"; SSM maps 1-dim to C-dim. Defaults to 1.
            swap_channels (bool, optional): Whether to swap channel ordering. Stored on the
                instance; ``forward`` in this class does not read it. Defaults to False.
            transposed (bool, optional): Backbone axis ordering. Defaults to True.
            dropout (float, optional): Dropout probability. Accepted but not used by this
                class (no output dropout layer is created). Defaults to 0.0.
            tie_dropout (bool, optional): Tie dropout mask across sequence length. Accepted
                but not used by this class. Defaults to False.
            drop_kernel (float, optional): Kernel dropout probability. Defaults to 0.0.
            kernel_type (str, optional): Kernel registry key used together with ``param_config``
                (``'s4'`` for DPLR, ``'s4d'`` for diagonal). Defaults to None.
            param_config (dict, optional): References to SSM parameters (A, B, C, dt, P, etc.). Defaults to None.
            kernel (str, optional): Kernel registry key used when ``param_config`` is not given. Defaults to None.
            **kernel_args: Additional arguments forwarded to the kernel class. Only forwarded
                when the kernel is selected through ``kernel``, not through ``param_config``.

        Raises:
            ValueError: If ``param_config`` is given without ``kernel_type``, or if neither
                ``param_config`` nor ``kernel`` is given.
        """
        super().__init__()
        self.d_model = d_model
        self.L = self.l_max = l_max
        self.channels = channels
        self.transposed = transposed
        self.swap_channels = swap_channels

        if param_config is not None:
            # Kernel built on top of externally supplied SSM parameters.
            if kernel_type is None:
                raise ValueError(
                    "kernel_type must be provided with param_config"
                )
            kernel_cls = kernel_registry[kernel_type]
            self.kernel = kernel_cls(
                d_model=d_model,
                l_max=l_max,
                channels=channels,
                param_config=param_config,
            )
        else:
            # Kernel looked up by name and configured through **kernel_args.
            if kernel is None:
                raise ValueError(
                    "Either param_config or kernel must be provided"
                )
            kernel_cls = kernel_registry[kernel]
            self.kernel = kernel_cls(
                d_model=self.d_model,
                l_max=self.l_max,
                channels=channels,
                **kernel_args,
            )

        self.drop_kernel = (
            nn.Dropout(drop_kernel) if drop_kernel > 0.0 else nn.Identity()
        )

    def forward(
        self, x, state=None, rate=1.0, **kwargs
    ):  # absorbs return_output and transformer src mask
        """
        Forward pass through FFTConvS4.

        Args:
            x (torch.Tensor): Input tensor, shape ``(B, D, L)`` if ``self.transposed`` else ``(B, L, D)``.
            state (torch.Tensor, optional): Recurrent state. Defaults to None.
            rate (float, optional): Rate for kernel computation. Defaults to 1.0.
            **kwargs: Additional keyword arguments (absorbs return_output, src mask, etc.).

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]: A tuple containing:

            - y : Convolution output, shape ``(B, C, H, L)``. It is not transposed
              back when ``self.transposed`` is False.
            - next_state : State for recurrent mode, or ``None`` when no ``state`` was passed.
        """
        # Always work with (B D L) dimension in this module
        if not self.transposed:
            x = x.transpose(-1, -2)
        L = x.size(-1)

        # Compute SS Kernel
        l_kernel = L if self.L is None else min(L, round(self.L / rate))
        k, k_state = self.kernel(
            L=l_kernel, rate=rate, state=state
        )  # (C H L) (B C H L)

        # Kernel dropout
        k = self.drop_kernel(k)

        # FFT convolution (core operation); both operands use FFT size
        # l_kernel + L and the result is cut back to the first L steps.
        k_f = torch.fft.rfft(k, n=l_kernel + L)  # (C H L)
        x_f = torch.fft.rfft(x, n=l_kernel + L)  # (B H L)
        y_f = contract("bhl,chl->bchl", x_f, k_f)
        y = torch.fft.irfft(y_f, n=l_kernel + L)[..., :L]  # (B C H L)

        # Compute state update
        if state is not None:
            y = y + k_state
            next_state = self.kernel.forward_state(x, state)
        else:
            next_state = None

        return y, next_state

    def setup_step(self, **kwargs):
        """Forward ``**kwargs`` to the kernel's ``_setup_step``."""
        self.kernel._setup_step(**kwargs)

    def step(self, x, state):
        """
        Step one time step as a recurrent model. Intended to be used during validation.

        Args:
            x (torch.Tensor): Input tensor, shape ``(B, H)``.
            state (torch.Tensor): Recurrent state, shape ``(B, H, N)``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:

            - y : Output, shape ``(B, C, H)``.
            - next_state : Updated state, shape ``(B, H, N)``.
        """
        y, next_state = self.kernel.step(x, state)  # (B C H)
        return y, next_state

    def default_state(self, *batch_shape, device=None):
        """
        Return the kernel's default state for ``batch_shape``.

        ``device`` is accepted but not forwarded to the kernel.
        """
        return self.kernel.default_state(*batch_shape)

    @property
    def d_output(self):
        """Output feature size: ``d_model * channels``."""
        return self.d_model * self.channels