# Source code for lrnnx.models.lti.s4

"""
Taken from the original S4 implementation and modified to fit into the LRNNX framework.
https://github.com/state-spaces/s4
"""

import math
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from lrnnx.core.convolution import FFTConvS4
from lrnnx.models.lti.base import LTI_LRNN
from lrnnx.ops.s4_utils import (
    DropoutNd,
    LinearActivation,
    init_dt,
    init_ssm_dplr,
    register_ssm_params,
)

# Short alias for ``torch.einsum``; used below for the einsum-style
# contraction that applies the D (skip) term in ``S4.forward``.
contract = torch.einsum


class S4(LTI_LRNN):
    """
    General block design wrapping an inner layer. Currently only
    layer=FFTConv is supported, but easy to incorporate others.

    Other options are all experimental and should not need to be configured.

    ``forward`` returns a ``(y, state)`` tuple, and with the default
    ``transposed=True`` the input layout is ``(B, H, L)``.

    Example:
        >>> model = S4(d_model=64, d_state=64, l_max=1024)
        >>> x = torch.randn(2, 64, 1024)  # (B, H, L) because transposed=True
        >>> y, state = model(x)
        >>> y.shape
        torch.Size([2, 64, 1024])
    """

    def __init__(
        self,
        d_model,
        bottleneck=None,
        gate=None,
        final_act="glu",
        postact=None,
        dropout=0.0,
        tie_dropout=False,
        transposed=True,
        # Kernel/SSM configuration args
        l_max=None,
        channels=1,
        d_state=64,
        dt_min=0.001,
        dt_max=0.1,
        dt_tie=True,
        dt_transform="exp",
        dt_fast=False,
        rank=1,
        n_ssm=None,
        init="legs",
        deterministic=False,
        real_transform="exp",
        imag_transform="none",
        is_real=False,
        lr=None,
        wd=0.0,
        verbose=True,
        **layer_args,  # Any remaining args for FFTConv
    ):
        """
        Initialize S4 block.

        Args:
            d_model (int): Model dimension.
            bottleneck (int, optional): Reduce dimension of inner layer
                (e.g. used in GSS). Defaults to None.
            gate (int, optional): Add multiplicative gating (e.g. used in
                GSS). Defaults to None.
            final_act (str, optional): Activation after final linear layer.
                ``'id'`` for no activation, ``None`` for no linear layer at
                all. Defaults to ``"glu"``.
                NOTE(review): this constructor only checks whether the value
                is ``None``; the string itself is not forwarded to
                ``LinearActivation`` here -- confirm that is intended.
            postact (str, optional): Deprecated, use *final_act*.
                Defaults to None.
            dropout (float, optional): Standard dropout argument.
                Defaults to 0.0.
            tie_dropout (bool, optional): Tie dropout mask across sequence
                length, emulating ``nn.Dropout1d``. Defaults to False.
            transposed (bool, optional): Backbone axis ordering
                ``(B, L, H)`` (False) or ``(B, H, L)`` (True).
                Defaults to True.
            l_max (int, optional): Maximum sequence length for the kernel.
                Defaults to None.
            channels (int, optional): Number of channels/heads.
                Defaults to 1.
            d_state (int, optional): State dimension (N). Defaults to 64.
            dt_min (float, optional): Minimum value for dt initialization.
                Defaults to 0.001.
            dt_max (float, optional): Maximum value for dt initialization.
                Defaults to 0.1.
            dt_tie (bool, optional): Tie dt across channels.
                Defaults to True.
            dt_transform (str, optional): Transformation to apply to dt.
                Defaults to ``"exp"``.
            dt_fast (bool, optional): Fast dt initialization.
                Defaults to False.
            rank (int, optional): Rank of the low-rank correction for DPLR.
                Defaults to 1.
            n_ssm (int, optional): Number of independent SSMs.
                Defaults to None (one per feature, i.e. ``d_model``).
            init (str, optional): Initialization method for the A matrix
                (e.g., ``"legs"``). Defaults to ``"legs"``.
            deterministic (bool, optional): Use deterministic
                initialization. Defaults to False.
            real_transform (str, optional): Transformation for the real
                part of A. Defaults to ``"exp"``.
            imag_transform (str, optional): Transformation for the imaginary
                part of A. Defaults to ``"none"``.
            is_real (bool, optional): Whether to use real-valued SSMs.
                Defaults to False.
            lr (float, optional): Specific learning rate for SSM parameters.
                Defaults to None.
                NOTE(review): not referenced anywhere in this constructor.
            wd (float, optional): Specific weight decay for SSM parameters.
                Defaults to 0.0.
                NOTE(review): not referenced anywhere in this constructor.
            verbose (bool, optional): Print initialization information.
                Defaults to True.
            **layer_args: Any remaining args passed directly to FFTConv.

        Raises:
            AssertionError: If the deprecated ``postact`` is given while
                ``final_act`` is not ``None``.
        """
        super().__init__(
            discretization="no_discretization"
        )  # discretization is unused in S4

        self.d_model = d_model
        self.transposed = transposed
        self.gate = gate
        self.bottleneck = bottleneck

        # Store config needed for kernel
        self.l_max = l_max
        self.channels = channels
        self.N = d_state
        self.dtype, self.cdtype = torch.float, torch.cfloat
        self.dt_fast = dt_fast
        self.real_transform = real_transform
        self.imag_transform = imag_transform
        self.is_real = is_real
        self.deterministic = deterministic
        self.dt_min = dt_min
        self.dt_max = dt_max
        self.dt_tie = dt_tie
        self.dt_transform = dt_transform
        self.rank = rank
        self.H = d_model
        self.n_ssm = n_ssm if n_ssm is not None else self.H
        self.init = init
        self.verbose = verbose

        if bottleneck is not None:
            # NOTE(review): after this reduction ``self.d_model`` differs from
            # ``self.H`` / the ``d_model`` handed to FFTConvS4 and ``self.D``
            # below, and ``input_linear`` is built reduced -> reduced although
            # it is applied to the un-reduced input in ``forward``. The
            # bottleneck path looks dimensionally inconsistent; left as-is
            # because the intended output width cannot be determined here.
            self.d_model = self.d_model // bottleneck
            self.input_linear = LinearActivation(
                self.d_model,
                self.d_model,
                transposed=False,
                activate=False,
            )

        # Initialize dt
        inv_dt = init_dt(
            self.H,
            self.N,
            self.dt_min,
            self.dt_max,
            self.dt_tie,
            self.dt_transform,
            self.deterministic,
            self.dtype,
        )

        # Keys that belong to the convolution/kernel configuration; anything
        # else in ``layer_args`` is forwarded to ``init_ssm_dplr``.
        _fft_conv_keys = {
            "kernel", "swap_channels", "drop_kernel",  # FFTConvS4
            "d_model", "l_max", "channels", "transposed",  # also FFTConvS4
            "d_state", "dt_min", "dt_max", "dt_tie", "dt_transform",  # S4KernelBase (dt)
            "dt_fast", "rank", "n_ssm", "init",  # S4KernelBase (A,B,C)
            "real_transform", "imag_transform", "is_real",  # S4KernelBase (transforms)
            "deterministic", "verbose", "lr", "wd",  # S4KernelBase (misc)
            "disc",  # S4DKernel
        }
        init_args = {
            k: v for k, v in layer_args.items() if k not in _fft_conv_keys
        }

        # Initialize A, P, B, C
        A, P, B, C = init_ssm_dplr(
            self.N,
            self.H,
            self.n_ssm,
            self.channels,
            self.rank,
            self.init,
            self.deterministic,
            self.cdtype,
            **init_args,
        )

        # Halve N for conjugate symmetry
        self.N //= 2

        self.repeat = register_ssm_params(
            self,  # Register on S4 module
            A,
            B,
            C,
            inv_dt,
            P,
            self.H,
            self.n_ssm,
            self.N,
            self.channels,
            self.rank,
            self.dt_fast,
            self.real_transform,
            self.imag_transform,
            self.is_real,
            self.verbose,
            self.l_max,
            diag=False,  # S4 uses DPLR (not diagonal)
        )

        if gate is not None:
            self.input_gate = LinearActivation(
                self.d_model,
                self.d_model * gate,
                transposed=False,
                activate=True,
            )

        # D parameter and convolution activation
        self.D = nn.Parameter(torch.randn(channels, d_model))
        self.conv_activation = nn.GELU()

        # Create FFTConvS4 with parameter references
        self.layer = FFTConvS4(
            d_model=d_model,
            l_max=l_max,
            channels=channels,
            transposed=False,
            dropout=dropout,
            tie_dropout=tie_dropout,
            kernel_type="s4",
            param_config={
                "A_real": self.A_real,
                "A_imag": self.A_imag if not self.is_real else None,
                "B": self.B,
                "C": self.C,
                "P": self.P,
                "inv_dt": self.inv_dt,
                "N": self.N,
                "H": self.H,
                "channels": self.channels,
                "rank": self.rank,
                "repeat": self.repeat,
                "dt_fast": self.dt_fast,
                "real_transform": self.real_transform,
                "imag_transform": self.imag_transform,
                "dt_transform": self.dt_transform,
                "is_real": self.is_real,
                "deterministic": self.deterministic,
                "verbose": self.verbose,
            },
            **layer_args,
        )

        # Check if we need output_gate
        if gate is not None:
            if self.layer.d_output != self.d_model * gate:
                # The gate projection consumes the flattened layer output, the
                # same tensor ``output_linear`` sizes with ``layer.d_output``
                # when no gate is used, so its input width must match that
                # (they differ from ``self.d_model`` when channels > 1).
                self.output_gate = LinearActivation(
                    self.layer.d_output,
                    self.d_model * gate,
                    transposed=False,
                    activate=False,
                )
            else:
                self.output_gate = nn.Identity()

        # Activation after (optional) multiplication by gate branch
        self.mult_activation = nn.GELU()
        dropout_fn = (
            partial(DropoutNd, transposed=False) if tie_dropout else nn.Dropout
        )
        self.drop = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        if postact is not None:
            assert final_act is None, (
                "'postact' is deprecated; pass final_act=None when using it"
            )
            import warnings

            warnings.warn(
                "Warning: 'postact' option changed to 'final_act' and will be removed in a future version.",
                FutureWarning,
                stacklevel=2,
            )
            final_act = postact

        if final_act is None:
            self.output_linear = nn.Identity()
        else:
            self.output_linear = LinearActivation(
                (
                    self.d_model * gate
                    if gate is not None
                    else self.layer.d_output
                ),
                self.d_model,
                transposed=False,
                activate=True,
            )

    def forward(self, x, lengths=None, **kwargs):
        """
        Forward pass of the S4 block.

        Args:
            x (torch.Tensor): Input tensor of shape ``(B, H, L)`` if
                ``self.transposed`` else ``(B, L, H)``.
            lengths (torch.Tensor | int, optional): Lengths of the sequences
                in the batch for padding masking. An int is applied to every
                sequence; a tensor must be 1-D with 1 or ``B`` entries.
                Defaults to None.
            **kwargs: Additional arguments absorbing return_output and
                transformer src mask; forwarded to the inner layer.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]: A tuple containing:
                - y : Output tensor of the same shape as x.
                - state : The next recurrent state, or None.
        """
        if self.transposed:
            x = rearrange(x, "b d ... -> b ... d")
        L = x.size(1)

        # Mask out padding tokens
        if isinstance(lengths, int):
            if lengths != L:
                # Wrap in a list so the tensor is 1-D (shape (1,)); a 0-dim
                # tensor would fail the rank check below.
                lengths = torch.tensor(
                    [lengths], dtype=torch.long, device=x.device
                )
            else:
                lengths = None
        if lengths is not None:
            assert (
                isinstance(lengths, torch.Tensor)
                and lengths.ndim == 1
                and lengths.size(0) in [1, x.size(0)]
            ), "lengths must be an int or a 1-D tensor with 1 or batch-size entries"
            # Build the mask on x's device and in x's dtype so the product
            # neither fails on a device mismatch nor changes x's precision.
            lengths = lengths.to(x.device)
            positions = torch.arange(L, device=x.device)[:, None]
            mask = (positions < lengths[:, None, None]).to(x.dtype)  # (B|1, L, 1)
            x = x * mask

        if self.gate is not None:
            v = self.input_gate(x)
        if self.bottleneck is not None:
            x = self.input_linear(x)

        y, state = self.layer(
            x, **kwargs
        )  # (B C H L) in transposed=False mode

        # Post-convolution operations
        # Add D term
        x_for_D = x.transpose(-1, -2)  # (B L H) -> (B H L)
        y = y + contract("bhl,ch->bchl", x_for_D, self.D)

        # Reshape to flatten channels
        if self.layer.swap_channels:
            y = rearrange(y, "b c h l -> b (h c) l")
        else:
            y = rearrange(y, "b c h l -> b (c h) l")

        # Transpose back to (B L H) format
        y = y.transpose(-1, -2)  # (B C*H L) -> (B L C*H)

        # Apply convolution activation
        y = self.conv_activation(y)

        if self.gate is not None:
            y = self.output_gate(y)
            y = y * v
        y = self.mult_activation(y)
        y = self.drop(y)
        y = self.output_linear(y)

        if self.transposed:
            y = rearrange(y, "b d ... -> b ... d")

        return y, state

    def step(
        self,
        x: torch.Tensor,
        inference_cache: dict,
        **kwargs,
    ) -> tuple:
        """
        Perform a single recurrent step of the S4 model.

        Args:
            x (torch.Tensor): Input at current timestep, shape ``(B, H)``.
            inference_cache (dict): Cache from ``allocate_inference_cache()``.
            **kwargs: Additional keyword arguments.

        Returns:
            tuple[torch.Tensor, dict]: A tuple containing:
                - y_t : Output tensor at the current timestep of shape
                  ``(B, H)``.
                - inference_cache : Updated cache dictionary (the
                  ``"lrnn_state"`` tensor is overwritten in place).
        """
        state = inference_cache["lrnn_state"]

        if self.gate is not None:
            v = self.input_gate(x)
        if self.bottleneck is not None:
            x = self.input_linear(x)

        y, next_state = self.layer.step(x, state)  # (B C H)

        # Post-convolution operations
        # Add D term: (B 1 H) * (C H) -> (B C H)
        y = y + x.unsqueeze(-2) * self.D

        # Reshape to flatten channels, using the same channel ordering as
        # ``forward`` so both code paths feed identical layouts downstream.
        if self.layer.swap_channels:
            y = rearrange(y, "b c h -> b (h c)")
        else:
            y = rearrange(y, "b c h -> b (c h)")

        # Apply convolution activation
        y = self.conv_activation(y)

        if self.gate is not None:
            y = self.output_gate(y)
            y = y * v
        y = self.mult_activation(y)
        y = self.drop(y)
        y = self.output_linear(y)

        inference_cache["lrnn_state"].copy_(next_state)
        return y, inference_cache

    def allocate_inference_cache(
        self,
        batch_size: int,
        max_seqlen: int = 1,
        dtype=None,
        **kwargs,
    ) -> dict:
        """
        Allocate cache for step-by-step inference.

        Calls ``setup_step()`` on the inner layer, then creates the initial
        hidden state via ``default_state``.

        Args:
            batch_size (int): Batch size for inference.
            max_seqlen (int, optional): Unused, kept for interface
                consistency. Defaults to 1.
            dtype (torch.dtype, optional): Unused. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            dict: Cache dict with "lrnn_state" key.
        """
        self.layer.setup_step()
        state = self.default_state(
            batch_size, device=next(self.parameters()).device
        )
        return {"lrnn_state": state}

    def default_state(self, *batch_shape, device=None):
        """
        Return the inner layer's default recurrent state.

        Args:
            *batch_shape: Leading batch dimensions of the state.
            device (torch.device, optional): If given, the state is moved to
                this device. Defaults to None (keep the layer's device).

        Returns:
            torch.Tensor: The initial state produced by the inner layer.
        """
        state = self.layer.default_state(*batch_shape)
        if device is not None:
            # Honour the requested device instead of silently ignoring it.
            state = state.to(device)
        return state

    @property
    def d_output(self):
        """Feature dimension produced by this block (``self.d_model``)."""
        return self.d_model