Source code for lrnnx.models.ltv.mamba

# Copyright (c) 2023, Tri Dao, Albert Gu.
# Modified to incorporate different discretizations and event-based processing.
# Reference: https://github.com/Efficient-Scalable-Machine-Learning/event-based-mamba.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from lrnnx.ops.selective_scan import mamba_inner_fn, selective_scan_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None  # type: ignore[assignment]

try:
    from lrnnx.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None  # type: ignore[assignment]

from typing import Any, Dict, Optional, Tuple, Union

from lrnnx.models.ltv.base import LTV_LRNN


[docs] class Mamba(LTV_LRNN): """ Mamba: Selective State Space Model with optional event-based processing. When integration_timesteps is provided in forward(), uses asymmetric discretization (separate dtA and dtB) for event-driven processing. Otherwise, uses standard Mamba discretization. Example: >>> model = Mamba(d_model=64, d_state=16, d_conv=4) >>> x = torch.randn(2, 128, 64) >>> y = model(x) >>> y.shape torch.Size([2, 128, 64]) """
[docs] def __init__( self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=True, # Fused kernel options layer_idx=None, device=None, dtype=None, discretization="mamba", ): """ Initialize Mamba model. Args: d_model (int): Model dimension. d_state (int, optional): SSM state dimension (N). Defaults to 16. d_conv (int, optional): Convolution kernel size. Defaults to 4. expand (int, optional): Expansion factor for inner dimension. Defaults to 2. dt_rank (Union[int, str], optional): Rank for delta projection, ``"auto"`` = ``ceil(d_model / 16)``. Defaults to ``"auto"``. dt_min (float, optional): Minimum value for delta initialization. Defaults to 0.001. dt_max (float, optional): Maximum value for delta initialization. Defaults to 0.1. dt_init (str, optional): Initialization method (``"random"`` or ``"constant"``). Defaults to ``"random"``. dt_scale (float, optional): Scale factor for dt initialization. Defaults to 1.0. dt_init_floor (float, optional): Floor value for dt initialization. Defaults to 1e-4. conv_bias (bool, optional): Whether to use bias in convolution. Defaults to True. bias (bool, optional): Whether to use bias in linear projections. Defaults to False. use_fast_path (bool, optional): Whether to use fused CUDA kernels. Defaults to True. layer_idx (int, optional): Layer index for multi-layer caching. Defaults to None. device (torch.device, optional): Device for parameters. Defaults to None. dtype (torch.dtype, optional): Data type for parameters. Defaults to None. discretization (str, optional): Discretization type. Defaults to ``"mamba"``. """ # pass None to base class since Mamba handles discretization via CUDA kernel, not discretize_fn super().__init__(discretization=None) factory_kwargs = {"device": device, "dtype": dtype} self.d_model = d_model self.d_state = d_state self.d_conv = d_conv self.expand = expand self.d_inner = int(self.expand * self.d_model) self.dt_rank = ( math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank ) self.use_fast_path = use_fast_path self.layer_idx = layer_idx self.discretization = discretization self.in_proj = nn.Linear( self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs ) self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1, **factory_kwargs, ) self.activation = "silu" self.act = nn.SiLU() self.x_proj = nn.Linear( self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs, ) # for standard Mamba: single dt_proj (used as dtB in event mode) self.dt_proj = nn.Linear( self.dt_rank, self.d_inner, bias=True, **factory_kwargs ) # for event mode: separate dtA_proj self.dtA_proj = nn.Linear( self.dt_rank, self.d_inner, bias=True, **factory_kwargs ) # Initialize special dt projection to preserve variance at initialization dt_init_std = self.dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(self.dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit self.dt_proj.bias._no_reinit = True # S4D real initialization A = repeat( torch.arange( 1, self.d_state + 1, dtype=torch.float32, device=device ), "n -> d n", d=self.d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # D "skip" parameter self.D = nn.Parameter( torch.ones(self.d_inner, device=device) ) # Keep in fp32 self.D._no_weight_decay = True self.out_proj = nn.Linear( self.d_inner, self.d_model, bias=bias, **factory_kwargs )
[docs] def forward( self, hidden_states, integration_timesteps: Optional[torch.Tensor] = None, lengths: Optional[torch.Tensor] = None, inference_cache: Optional[Dict[str, Any]] = None, ): """ Forward pass through Mamba. Args: hidden_states (torch.Tensor): Input tensor, shape ``(B, L, D)``. integration_timesteps (torch.Tensor, optional): Time intervals between events. Shape ``(B, L)``. When provided, uses asymmetric discretization with separate dtA and dtB for event-driven processing. Defaults to None. lengths (torch.Tensor, optional): Not used by Mamba currently. Defaults to None. inference_cache (dict, optional): Cache for autoregressive generation. If provided, contains "conv_state" and "lrnn_state" tensors. Defaults to None. Returns: torch.Tensor: Output tensor, shape ``(B, L, D)``. """ batch, seqlen, dim = hidden_states.shape # event mode: use asymmetric discretization use_event_mode = integration_timesteps is not None conv_state, ssm_state = None, None if inference_cache is not None: conv_state = inference_cache.get("conv_state") ssm_state = inference_cache.get("lrnn_state") seqlen_offset = inference_cache.get("seqlen_offset", 0) if seqlen_offset > 0: # Use step() for autoregressive decoding # inference_cache is updated in-place out, inference_cache = self.step( hidden_states, inference_cache, integration_timesteps=integration_timesteps, ) return out # We do matmul and transpose BLH -> HBL at the same time xz = rearrange( self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), "d (b l) -> b d l", l=seqlen, ) if self.in_proj.bias is not None: xz = xz + rearrange( self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1" ) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat if ( self.use_fast_path and causal_conv1d_fn is not None and inference_cache is None and not use_event_mode # can't use fast path for event mode ): # Doesn't support outputting the states out = mamba_inner_fn( xz, self.conv1d.weight, self.conv1d.bias, self.x_proj.weight, self.dt_proj.weight, self.out_proj.weight, self.out_proj.bias, A, None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, ) else: x, z = xz.chunk(2, dim=1) # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. conv_state.copy_( F.pad(x, (self.d_conv - x.shape[-1], 0)) ) # Update state (B D W) if causal_conv1d_fn is None: x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( x=x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, ) # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) B = x_dbl[:, self.dt_rank : self.dt_rank + self.d_state] B = rearrange( B, "(b l) dstate -> b dstate l", l=seqlen ).contiguous() C = x_dbl[:, -self.d_state :] C = rearrange( C, "(b l) dstate -> b dstate l", l=seqlen ).contiguous() # compute dt from x_dbl dt, _, _ = torch.split( x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1 ) dt = self.dt_proj.weight @ dt.t() dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) if use_event_mode: # event mode: asymmetric discretization with separate dtA and dtB # compute dtA from dtA_proj bias scaled by integration timesteps dtA = repeat( self.dtA_proj.bias, "d -> b d l", b=batch, l=seqlen ) dtA = integration_timesteps.unsqueeze(1) * F.softplus(dtA) # type: ignore[union-attr] # apply softplus to dt (with bias) for B discretization dt = F.softplus(dt + self.dt_proj.bias.float()[:, None]) # run selective scan with separate deltaA for asymmetric discretization assert self.activation in ["silu", "swish"] y = selective_scan_fn( x, dt, A, B, C, self.D.float(), z=z, delta_bias=None, deltaA=dtA, delta_softplus=False, return_last_state=ssm_state is not None, discretization=self.discretization, ) else: # standard Mamba mode assert self.activation in ["silu", "swish"] y = selective_scan_fn( x, dt, A, B, C, self.D.float(), z=z, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, discretization=self.discretization, ) if ssm_state is not None: y, last_state = y ssm_state.copy_(last_state) y = rearrange(y, "b d l -> b l d") out = self.out_proj(y) return out
[docs] def step( self, x: torch.Tensor, inference_cache: Dict[str, Any], integration_timesteps: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Dict[str, Any]]: """ Performs a single recurrent step of Mamba. Args: x (torch.Tensor): Input at current timestep, shape ``(B, 1, D)``. inference_cache (Dict[str, Any]): Cache dictionary containing: - "conv_state": Convolution state, shape ``(B, D_inner, d_conv)`` - "lrnn_state": SSM state, shape ``(B, D_inner, N)`` - "seqlen_offset": Current position in sequence integration_timesteps (torch.Tensor, optional): Integration timestep, shape ``(B, 1)`` or ``(B,)``. When provided, uses event-based asymmetric discretization. Defaults to None. **kwargs: Additional keyword arguments. Returns: tuple[torch.Tensor, Dict[str, Any]]: A tuple containing: - out : Output at current timestep, shape ``(B, 1, D)``. - inference_cache : Updated cache dictionary. """ conv_state = inference_cache["conv_state"] ssm_state = inference_cache["lrnn_state"] hidden_states = x use_event_mode = integration_timesteps is not None dtype = hidden_states.dtype assert ( hidden_states.shape[1] == 1 ), "Only support decoding with 1 token at a time for now" xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) x, z = xz.chunk(2, dim=-1) # (B D) # Conv step if causal_conv1d_update is None: conv_state.copy_( torch.roll(conv_state, shifts=-1, dims=-1) ) # Update state (B D W) conv_state[:, :, -1] = x x = torch.sum( conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1, ) # (B D) if self.conv1d.bias is not None: x = x + self.conv1d.bias x = self.act(x).to(dtype=dtype) else: x = causal_conv1d_update( x, conv_state, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, self.activation, ) x_db = self.x_proj(x) # (B dt_rank+2*d_state) B = x_db[:, self.dt_rank : self.dt_rank + self.d_state] C = x_db[:, -self.d_state :] A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # compute dt for B matrix discretization dt = x_db[:, : self.dt_rank] # don't add dt_bias here dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) # compute deltaA for event mode deltaA = None if use_event_mode: # compute dtA from dtA_proj bias scaled by integration timesteps dtA = self.dtA_proj.bias.expand(x.shape[0], -1) timestep = ( integration_timesteps.view(-1, 1) # type: ignore[union-attr] if integration_timesteps.dim() > 1 # type: ignore[union-attr] else integration_timesteps.unsqueeze(-1) # type: ignore[union-attr] ) deltaA = timestep * F.softplus(dtA) if selective_state_update is not None: y = selective_state_update( ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True, deltaA=deltaA, discretization=self.discretization, ) else: dt_with_bias = F.softplus( dt + self.dt_proj.bias.to(dtype=dt.dtype) ) # discretize A: use deltaA if provided, otherwise use dt if deltaA is not None: dA = torch.exp(torch.einsum("bd,dn->bdn", deltaA, A)) else: dA = torch.exp(torch.einsum("bd,dn->bdn", dt_with_bias, A)) # discretize B based on discretization method if self.discretization == "zoh": A_dt = torch.einsum("bd,dn->bdn", dt_with_bias, A) expm1_A_dt = torch.exp(A_dt) - 1.0 B_tilde = expm1_A_dt / A.unsqueeze(0) dB = B.unsqueeze(1) * B_tilde elif self.discretization == "bilinear": v = 0.5 * torch.einsum("bd,dn->bdn", dt_with_bias, A) den_inv = 1.0 / (1.0 - v) dB = B.unsqueeze(1) * den_inv * dt_with_bias.unsqueeze(-1) elif self.discretization == "dirac": dB = B.unsqueeze(1).expand_as(dA) else: # "mamba" dB = torch.einsum("bd,bn->bdn", dt_with_bias, B) ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) y = y + self.D.to(dtype) * x y = y * self.act(z) out = self.out_proj(y) # Update cache in-place and return inference_cache["conv_state"] = conv_state inference_cache["lrnn_state"] = ssm_state inference_cache["seqlen_offset"] = ( inference_cache.get("seqlen_offset", 0) + 1 ) return out.unsqueeze(1), inference_cache
[docs] def allocate_inference_cache( self, batch_size: int, max_seqlen: int, dtype: Optional[torch.dtype] = None, **kwargs, ) -> Dict[str, Any]: """ Allocates cache for Mamba autoregressive inference. Args: batch_size (int): The batch size for inference. max_seqlen (int): Maximum sequence length (not used by Mamba, but kept for interface consistency). dtype (torch.dtype, optional): Data type for allocated tensors. Defaults to None. **kwargs: Additional arguments (unused). Returns: Dict[str, Any]: Cache dictionary containing: - "conv_state": Convolution state, shape ``(B, D_inner, d_conv)``. - "lrnn_state": SSM state, shape ``(B, D_inner, N)``. - "seqlen_offset": Current position in the sequence (starts at 0). """ device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype conv_state = torch.zeros( batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype, ) ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype # ssm_dtype = torch.float32 ssm_state = torch.zeros( batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype, ) return { "conv_state": conv_state, "lrnn_state": ssm_state, "seqlen_offset": 0, }