Source code for lrnnx.ops.selective_scan

"""
Original Mamba SSM Scan operation,
modified from
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
"""

import selective_scan_cuda
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    from causal_conv1d.cpp_functions import (
        causal_conv1d_bwd_function,
        causal_conv1d_fwd_function,
        causal_conv1d_update_function,
    )
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_fwd_function = None
    causal_conv1d_bwd_function = None
    causal_conv1d_update_function = None

from lrnnx.ops.torch import custom_bwd, custom_fwd
from lrnnx.ops.triton.layer_norm import _layer_norm_fwd


class SelectiveScanFn(torch.autograd.Function):
    """Autograd bridge to the ``selective_scan_cuda`` forward/backward kernels."""

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        deltaA=None,
        delta_softplus=False,
        return_last_state=False,
        discretization=None,
    ):
        """
        Run the selective scan kernel and stash what the backward pass needs.

        Args:
            ctx (Any): Autograd context.
            u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)``.
            delta (torch.Tensor): Delta tensor of shape ``(batch, dim, seqlen)``.
            A (torch.Tensor): State matrix A of shape ``(dim, dstate)``.
            B (torch.Tensor): Input matrix B of shape ``(batch, dstate, seqlen)``
                or ``(dim, dstate)``. A 3-D tensor gets a singleton axis inserted
                at position 1 before it is handed to the kernel.
            C (torch.Tensor): Output matrix C, same conventions as ``B``.
            D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)``.
                Defaults to None.
            z (torch.Tensor, optional): Gating tensor of shape
                ``(batch, dim, seqlen)``. Defaults to None.
            delta_bias (torch.Tensor, optional): Bias for delta of shape ``(dim,)``.
                Defaults to None.
            deltaA (torch.Tensor, optional): Separate delta used for A, of shape
                ``(batch, dim, seqlen)``. Defaults to None.
            delta_softplus (bool, optional): Forwarded to the kernel; whether
                softplus is applied to delta. Defaults to False.
            return_last_state (bool, optional): Also return the final state.
                Defaults to False.
            discretization (str, optional): Discretization method forwarded to the
                kernel. Defaults to None.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The scan output
            (the gated output when ``z`` is given), optionally paired with the
            last state.
        """
        # The kernel needs fully contiguous inputs; checking only stride(-1)
        # would miss tensors that are strided in an outer dimension.
        # ``contiguous()`` returns the tensor itself when nothing has to move.
        u = u.contiguous()
        delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        B = B.contiguous()
        C = C.contiguous()
        if z is not None:
            z = z.contiguous()
        if deltaA is not None:
            deltaA = deltaA.contiguous()

        # Remember which operands received an extra axis so that backward can
        # squeeze the matching gradients again.
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True

        scan_out, states, *extra = selective_scan_cuda.fwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            deltaA,
            delta_softplus,
            discretization,
        )

        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        ctx.has_deltaA = deltaA is not None
        ctx.disc_method = discretization

        last_state = states[:, :, -1, 1::2]  # (batch, dim, dstate)
        if return_last_state and B.dim() < 3:
            # Constant B of shape (dim, dstate) broadcasts against the
            # (batch, dim, dstate) state.
            last_state = last_state * B

        if ctx.has_z:
            ctx.save_for_backward(
                u, delta, A, B, C, D, z, delta_bias, deltaA, states, scan_out
            )
            result = extra[0]  # gated output produced by the kernel
        else:
            ctx.save_for_backward(
                u, delta, A, B, C, D, delta_bias, deltaA, states
            )
            result = scan_out

        return (result, last_state) if return_last_state else result

    @staticmethod
    def backward(ctx, dout, *args):
        """
        Compute input gradients with the selective scan backward kernel.

        Args:
            ctx (Any): Autograd context populated by :meth:`forward`.
            dout (torch.Tensor): Gradient of the output tensor.
            *args: Gradients of any further forward outputs (e.g. the last
                state); they are ignored.

        Returns:
            tuple: One entry per forward argument after ``ctx``; the three
            non-tensor options (``delta_softplus``, ``return_last_state``,
            ``discretization``) receive ``None``.
        """
        if ctx.has_z:
            u, delta, A, B, C, D, z, delta_bias, deltaA, x, out = (
                ctx.saved_tensors
            )
        else:
            u, delta, A, B, C, D, delta_bias, deltaA, x = ctx.saved_tensors
            z = out = None

        if dout.stride(-1) != 1:
            dout = dout.contiguous()

        grads = selective_scan_cuda.bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            deltaA,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,  # option to recompute out_z, not used here
            ctx.disc_method,
        )
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *trailing = grads

        # After the seven fixed gradients the kernel appends ddeltaA (only when
        # deltaA was given) and then dz (only when z was given). Consume them
        # in that order; a missing entry becomes None.
        remaining = iter(trailing)
        ddeltaA = next(remaining, None) if ctx.has_deltaA else None
        dz = next(remaining, None) if ctx.has_z else None

        if getattr(ctx, "squeeze_B", False):
            dB = dB.squeeze(1)
        if getattr(ctx, "squeeze_C", False):
            dC = dC.squeeze(1)

        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            ddeltaA if deltaA is not None else None,
            None,
            None,
            None,
        )
def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    deltaA=None,
    delta_softplus=False,
    return_last_state=False,
    discretization="mamba",
):
    """
    Run the selective scan through :class:`SelectiveScanFn`.

    Args:
        u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)``.
        delta (torch.Tensor): Delta tensor of shape ``(batch, dim, seqlen)``.
        A (torch.Tensor): State matrix A of shape ``(dim, dstate)``.
        B (torch.Tensor): Input matrix B of shape ``(batch, dstate, seqlen)``
            or ``(dim, dstate)``.
        C (torch.Tensor): Output matrix C of shape ``(batch, dstate, seqlen)``
            or ``(dim, dstate)``.
        D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)``.
            Defaults to None.
        z (torch.Tensor, optional): Gating tensor of shape
            ``(batch, dim, seqlen)``. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta of shape ``(dim,)``.
            Defaults to None.
        deltaA (torch.Tensor, optional): Separate delta used for A, of shape
            ``(batch, dim, seqlen)``. Defaults to None.
        delta_softplus (bool, optional): Whether to apply softplus to delta.
            Defaults to False.
        return_last_state (bool, optional): If True, returns
            ``(out, last_state)`` with ``last_state`` of shape
            ``(batch, dim, dstate)``. The gradient of the last state is not
            considered in the backward pass. Defaults to False.
        discretization (str, optional): Discretization method.
            Defaults to "mamba".

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The output tensor,
        and optionally the last state.
    """
    # ``Function.apply`` only takes positional arguments, so the full argument
    # list is spelled out in the order ``SelectiveScanFn.forward`` declares it.
    scan_args = (
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        deltaA,
        delta_softplus,
        return_last_state,
        discretization,
    )
    return SelectiveScanFn.apply(*scan_args)
def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    deltaA=None,
    delta_softplus=False,
    return_last_state=False,
    discretization="mamba",
):
    """
    Reference (pure PyTorch) implementation of the selective scan.

    Args:
        u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)``.
        delta (torch.Tensor): Delta tensor of shape ``(batch, dim, seqlen)``.
        A (torch.Tensor): State matrix A of shape ``(dim, dstate)``.
        B (torch.Tensor): Matrix B. Either constant with shape
            ``(dim, dstate)``, or input-dependent with shape
            ``(batch, dstate, seqlen)`` or ``(batch, groups, dstate, seqlen)``.
            When ``A`` is complex, an input-dependent ``B`` stores interleaved
            real/imaginary pairs along its last axis (length ``2 * seqlen``).
        C (torch.Tensor): Matrix C, same conventions as ``B``.
        D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)``.
            Defaults to None.
        z (torch.Tensor, optional): Gating tensor of shape
            ``(batch, dim, seqlen)``. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta of shape ``(dim,)``.
            Defaults to None.
        deltaA (torch.Tensor, optional): Separate delta used only to
            discretize A, of shape ``(batch, dim, seqlen)``. It is used as
            given (no bias, no softplus). Defaults to None.
        delta_softplus (bool, optional): Whether to apply softplus to delta.
            Defaults to False.
        return_last_state (bool, optional): Whether to return the final state.
            Defaults to False.
        discretization (str, optional): One of ``"mamba"``, ``"zoh"``,
            ``"bilinear"`` or ``"dirac"``. Defaults to "mamba".

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The output tensor of
        shape ``(batch, dim, seqlen)`` in the dtype of ``u``, and optionally the
        last state of shape ``(batch, dim, dstate)``. For ``seqlen == 0`` the
        output is empty and the last state is the all-zero initial state.

    Raises:
        ValueError: If ``discretization`` is not a known method.
    """
    dtype_in = u.dtype
    u_for_scan = u.float()
    delta = delta.float()
    if delta_bias is not None:
        # delta_bias: (dim) -> (dim, 1)
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    if deltaA is not None:
        # Same float32 promotion as delta so every operand shares one dtype.
        deltaA = deltaA.float()

    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    seqlen = u.shape[2]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(
                rearrange(B.float(), "... (L two) -> ... L two", two=2)
            )
        if is_variable_C:
            C = torch.view_as_complex(
                rearrange(C.float(), "... (L two) -> ... L two", two=2)
            )
    else:
        B = B.float()
        C = C.float()

    # Vectorized discretization. Both deltaA_disc and deltaB_u end up with
    # shape (batch, dim, L, dstate).
    if discretization == "mamba":
        # Use deltaA for the A discretization if provided, otherwise delta.
        delta_for_A = deltaA if deltaA is not None else delta
        deltaA_disc = torch.exp(
            torch.einsum("bdl,dn->bdln", delta_for_A, A)
        )
        if not is_variable_B:
            deltaB_u = torch.einsum(
                "bdl,dn,bdl->bdln", delta, B, u_for_scan
            )
        elif B.dim() == 3:
            deltaB_u = torch.einsum(
                "bdl,bnl,bdl->bdln", delta, B, u_for_scan
            )
        else:
            B_grouped = repeat(
                B, "B G N L -> B (G H) N L", H=dim // B.shape[1]
            )
            deltaB_u = torch.einsum(
                "bdl,bdnl,bdl->bdln", delta, B_grouped, u_for_scan
            )
    else:
        u_b_d_l_1 = u_for_scan.unsqueeze(-1)  # (b, d, L, 1)
        delta_b_d_l_1 = delta.unsqueeze(-1)  # (b, d, L, 1)
        A_1_d_1_n = A.unsqueeze(0).unsqueeze(2)  # (1, d, 1, n)
        # deltaA lets the state dynamics evolve with a different time step
        # than the input (asynchronous / event-based processing).
        delta_for_A_b_d_l_1 = (
            deltaA.unsqueeze(-1) if deltaA is not None else delta_b_d_l_1
        )
        if not is_variable_B:
            B_broadcast = B.unsqueeze(0).unsqueeze(2)  # (1, d, 1, n)
        elif B.dim() == 3:
            B_broadcast = B.unsqueeze(1).permute(0, 1, 3, 2)  # (b, 1, L, n)
        else:
            B_broadcast = repeat(
                B, "B G N L -> B (G H) N L", H=dim // B.shape[1]
            ).permute(0, 1, 3, 2)  # (b, d, L, n)

        if discretization == "zoh":
            # A discretization uses deltaA if provided.
            deltaA_disc = torch.exp(delta_for_A_b_d_l_1 * A_1_d_1_n)
            # B discretization always uses delta (not deltaA).
            A_del_for_B = delta_b_d_l_1 * A_1_d_1_n
            dB = torch.expm1(A_del_for_B) / A_1_d_1_n
            deltaB_u = dB * u_b_d_l_1 * B_broadcast
        elif discretization == "bilinear":
            one = torch.tensor(1.0, dtype=A.dtype, device=A.device)
            # A discretization uses deltaA if provided.
            a_half_for_A = 0.5 * delta_for_A_b_d_l_1 * A_1_d_1_n
            deltaA_disc = (one + a_half_for_A) / (one - a_half_for_A)
            # B discretization always uses delta (not deltaA).
            a_half_for_B = 0.5 * delta_b_d_l_1 * A_1_d_1_n
            dB = delta_b_d_l_1 / (one - a_half_for_B)
            deltaB_u = dB * u_b_d_l_1 * B_broadcast
        elif discretization == "dirac":
            # A discretization uses deltaA if provided.
            deltaA_disc = torch.exp(delta_for_A_b_d_l_1 * A_1_d_1_n)
            # B discretization: B_bar = B (no delta).
            deltaB_u = u_b_d_l_1 * B_broadcast
        else:
            raise ValueError(
                f"Unknown discretization method: {discretization}"
            )

    # Sequential scan.
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    x = A.new_zeros((batch, dim, dstate))  # hidden state, shape (b, d, n)
    ys = []
    for i in range(seqlen):
        # x' = A_bar * x + B_bar * u
        x = deltaA_disc[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        elif C.dim() == 3:
            y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
        else:
            y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    # After the loop x is the state of the final step, or the untouched zero
    # state when the sequence is empty.
    last_state = x
    if ys:
        y = torch.stack(ys, dim=2)  # (b, d, L)
    else:
        # torch.stack rejects an empty list; emit an empty float32 output.
        y = u_for_scan.new_zeros((batch, dim, 0))

    # Final output processing.
    out = y if D is None else y + u * D[..., None]
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)  # (b, d, L)
    return out if not return_last_state else (out, last_state)
def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    """
    Normalize ``x`` with the ``_layer_norm_fwd`` kernel and return its output.

    Args:
        x (torch.Tensor): Input tensor of shape ``(batch * seqlen, dim)``.
            Copied to contiguous memory when its last axis is not unit-stride.
        weight (torch.Tensor): Weight tensor of shape ``(dim,)``.
        bias (torch.Tensor | None): Bias tensor of shape ``(dim,)`` or None.
        eps (float, optional): Epsilon for numerical stability.
            Defaults to 1e-6.
        is_rms_norm (bool, optional): Forwarded to the kernel; selects RMS norm
            over layer norm. Defaults to True.

    Returns:
        torch.Tensor: Normalized tensor of shape ``(batch * seqlen, dim)``
        (the first element of what ``_layer_norm_fwd`` returns).
    """
    # x: (b l) d
    x = x if x.stride(-1) == 1 else x.contiguous()
    weight = weight.contiguous()
    bias = None if bias is None else bias.contiguous()
    kernel_outputs = _layer_norm_fwd(
        x,
        weight,
        bias,
        eps,
        None,
        residual_dtype=None,
        is_rms_norm=is_rms_norm,
    )
    # Only the normalized tensor is needed; the remaining outputs are dropped.
    return kernel_outputs[0]
class MambaInnerFn(torch.autograd.Function):
    """Autograd function for the fused Mamba inner loop.

    The forward pass chains the causal conv1d kernel, the ``x_proj`` /
    ``delta_proj`` projections, optional RMS normalization of B, C and delta,
    the selective scan kernel and the output projection.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        Forward pass of the fused Mamba inner function.

        Args:
            ctx (Any): Autograd context.
            xz (torch.Tensor): Input tensor of shape ``(batch, 2 * dim, seqlen)``;
                it is split in half along axis 1 into the conv input ``x`` and
                the gate ``z``.
            conv1d_weight (torch.Tensor): Conv1d weights of shape
                ``(dim, 1, kernel_size)``.
            conv1d_bias (torch.Tensor | None): Conv1d biases of shape ``(dim,)``.
            x_proj_weight (torch.Tensor): Projection weights for delta, B, C.
                Shape ``(delta_rank + 2*dstate, dim)``.
            delta_proj_weight (torch.Tensor): Projection weights for delta.
                Shape ``(dim, delta_rank)``.
            out_proj_weight (torch.Tensor): Output projection weights.
                Shape ``(d_model, dim)``.
            out_proj_bias (torch.Tensor | None): Output projection biases.
                Shape ``(d_model,)``.
            A (torch.Tensor): State matrix A. Shape ``(dim, dstate)``.
            B (torch.Tensor, optional): State matrix B. When None it is read
                from the ``x_proj`` output (input-dependent B).
                Defaults to None.
            C (torch.Tensor, optional): State matrix C. When None it is read
                from the ``x_proj`` output (input-dependent C).
                Defaults to None.
            D (torch.Tensor, optional): Skip connection matrix D.
                Defaults to None.
            delta_bias (torch.Tensor, optional): Bias for delta.
                Defaults to None.
            B_proj_bias (torch.Tensor, optional): Bias added to the projected
                B. Defaults to None.
            C_proj_bias (torch.Tensor, optional): Bias added to the projected
                C. Defaults to None.
            delta_softplus (bool, optional): Whether to apply softplus to delta.
                Defaults to True.
            checkpoint_lvl (int, optional): Gradient checkpointing level
                (0 or 1). At level 1 ``conv1d_out`` and ``delta`` are not saved
                and get recomputed in backward. Defaults to 1.
            b_rms_weight (torch.Tensor, optional): RMS norm weights for B.
                Defaults to None.
            c_rms_weight (torch.Tensor, optional): RMS norm weights for C.
                Defaults to None.
            dt_rms_weight (torch.Tensor, optional): RMS norm weights for dt.
                Defaults to None.
            b_c_dt_rms_eps (float, optional): RMS norm epsilon.
                Defaults to 1e-6.

        Returns:
            torch.Tensor: The projected output tensor, laid out as
            ``(batch, seqlen, d_model)``.
        """
        assert (
            causal_conv1d_fwd_function is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_gpu_dtype()
            x_proj_weight = x_proj_weight.to(dtype=autocast_dtype)
            delta_proj_weight = delta_proj_weight.to(dtype=autocast_dtype)
            out_proj_weight = out_proj_weight.to(dtype=autocast_dtype)
            out_proj_bias = (
                out_proj_bias.to(dtype=autocast_dtype)
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = (
            conv1d_bias.contiguous() if conv1d_bias is not None else None
        )
        conv1d_out = causal_conv1d_fwd_function(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra
        # transposes. We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the
        # ssm_scan kernel expects.
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )  # (bl d)
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(),
            "d (b l) -> b d l",
            l=L,
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(
                    B, "(b l) dstate -> b 1 dstate l", l=L
                ).contiguous()
            else:
                B = rearrange(
                    B,
                    "(b l) (dstate two) -> b 1 dstate (l two)",
                    l=L,
                    two=2,
                ).contiguous()
        elif B.stride(-1) != 1:
            B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(
                    C, "(b l) dstate -> b 1 dstate l", l=L
                ).contiguous()
            else:
                C = rearrange(
                    C,
                    "(b l) (dstate two) -> b 1 dstate (l two)",
                    l=L,
                    two=2,
                ).contiguous()
        elif C.stride(-1) != 1:
            C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(
                B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(
                C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        # Same positional layout as the SelectiveScanFn kernel call in this
        # module: a deltaA slot follows delta_bias and the discretization name
        # comes last. The fused path has no separate deltaA and uses the
        # "mamba" discretization, matching selective_scan_fn's default.
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            None,  # deltaA
            delta_softplus,
            "mamba",
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_rms_weight = b_rms_weight
        ctx.c_rms_weight = c_rms_weight
        ctx.dt_rms_weight = dt_rms_weight
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if checkpoint_lvl >= 1:
            # Will recompute conv1d_out and delta in the backward pass.
            conv1d_out, delta = None, None
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """
        Backward pass for the fused Mamba inner function.

        Args:
            ctx (Any): Autograd context populated by :meth:`forward`.
            dout (torch.Tensor): Gradient of the output tensor.
                Shape ``(batch, seqlen, d_model)``.

        Returns:
            tuple: One gradient per forward argument after ``ctx``. The
            non-tensor options and the three RMS-norm weights receive ``None``.
        """
        # dout: (batch, seqlen, dim)
        assert (
            causal_conv1d_fwd_function is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # conv1d_out and delta were dropped in forward; rebuild them.
            conv1d_out = causal_conv1d_fwd_function(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                "d (b l) -> b d l",
                l=L,
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            # NOTE(review): forward saves B and C *after* their RMS norm, so
            # the two blocks below normalize them a second time, and only at
            # checkpoint_lvl == 1. Behavior is kept as-is; confirm whether the
            # re-normalization is intended.
            if b_rms_weight is not None:
                # Recompute & RMSNorm B
                B = rearrange(
                    B, "b 1 dstate l -> (b l) dstate", l=L
                ).contiguous()
                B = rms_norm_forward(
                    B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                B = rearrange(
                    B, "(b l) dstate -> b 1 dstate l", l=L
                ).contiguous()
            if c_rms_weight is not None:
                # Recompute & RMSNorm C
                C = rearrange(
                    C, "b 1 dstate l -> (b l) dstate", l=L
                ).contiguous()
                C = rms_norm_forward(
                    C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                C = rearrange(
                    C, "(b l) dstate -> b 1 dstate l", l=L
                ).contiguous()

        # The kernel supports passing in a pre-allocated dz (e.g., in case we
        # want to fuse the backward of selective_scan_cuda with the backward
        # of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        # Same positional layout as the SelectiveScanFn backward call in this
        # module (deltaA slot after delta_bias, discretization name last).
        # dz and the recomputed out_z are the last two results; anything the
        # kernel may place between ddelta_bias and dz is ignored.
        (
            dconv1d_out,
            ddelta,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias,
            *_unused,
            dz,
            out_z,
        ) = selective_scan_cuda.bwd(
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            None,  # deltaA
            dout_y,
            scan_intermediates,
            out,
            dz,
            ctx.delta_softplus,
            True,  # option to recompute out_z
            "mamba",
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        # dout is (d_model, batch * seqlen) at this point, so the bias gradient
        # sums over the flattened token axis only and keeps shape (d_model,).
        dout_proj_bias = (
            dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        )
        dD = dD if D is not None else None
        # When B or C was supplied by the caller, its columns of dx_dbl are
        # never written below; start from zeros in that case so uninitialized
        # memory cannot reach dx_proj_weight / dconv1d_out.
        if ctx.is_variable_B and ctx.is_variable_C:
            dx_dbl = torch.empty_like(x_dbl)
        else:
            dx_dbl = torch.zeros_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum(
            "dB,Br->dr", ddelta, x_dbl[:, :delta_rank]
        )
        dx_dbl[:, :delta_rank] = torch.einsum(
            "dB,dr->Br", ddelta, delta_proj_weight
        )
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
        # The kernel supports passing in a pre-allocated dx (e.g., in case we
        # want to fuse the backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            dB_proj_bias,
            dC_proj_bias,
            # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight,
            # c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
            None,
            None,
            None,
            None,
            None,
            None,
        )
def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    """
    Run the fused Mamba inner function through :class:`MambaInnerFn`.

    Args:
        xz (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)``.
        conv1d_weight (torch.Tensor): Conv1d weights of shape
            ``(dim, 1, kernel_size)``.
        conv1d_bias (torch.Tensor | None): Conv1d biases of shape ``(dim,)``.
        x_proj_weight (torch.Tensor): Projection weights for B, C, delta.
            Shape ``(delta_rank + 2*dstate, dim)``.
        delta_proj_weight (torch.Tensor): Projection weights for delta.
            Shape ``(dim, delta_rank)``.
        out_proj_weight (torch.Tensor): Output projection weights.
            Shape ``(d_model, dim)``.
        out_proj_bias (torch.Tensor | None): Output projection biases.
            Shape ``(d_model,)``.
        A (torch.Tensor): State matrix A. Shape ``(dim, dstate)``.
        B (torch.Tensor, optional): State matrix B. Defaults to None.
        C (torch.Tensor, optional): State matrix C. Defaults to None.
        D (torch.Tensor, optional): Skip connection matrix D. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta. Defaults to None.
        B_proj_bias (torch.Tensor, optional): Bias for B projection.
            Defaults to None.
        C_proj_bias (torch.Tensor, optional): Bias for C projection.
            Defaults to None.
        delta_softplus (bool, optional): Whether to apply softplus to delta.
            Defaults to True.
        checkpoint_lvl (int, optional): Gradient checkpointing level (0 or 1).
            Defaults to 1.
        b_rms_weight (torch.Tensor, optional): RMS norm weights for B.
            Defaults to None.
        c_rms_weight (torch.Tensor, optional): RMS norm weights for C.
            Defaults to None.
        dt_rms_weight (torch.Tensor, optional): RMS norm weights for dt.
            Defaults to None.
        b_c_dt_rms_eps (float, optional): RMS norm epsilon. Defaults to 1e-6.

    Returns:
        torch.Tensor: The projected output tensor.
    """
    # ``Function.apply`` only takes positional arguments, so they are listed in
    # the exact order ``MambaInnerFn.forward`` declares them.
    inner_args = (
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )
    return MambaInnerFn.apply(*inner_args)
def _ref_project_ssm_input(columns, proj_bias, seqlen, is_complex):
    """Add the optional bias to projected B/C columns and reshape for the scan."""
    if proj_bias is not None:
        columns = columns + proj_bias.to(dtype=columns.dtype)
    if is_complex:
        # Real/imaginary pairs stay interleaved along the sequence axis.
        return rearrange(
            columns,
            "(b l) (dstate two) -> b dstate (l two)",
            l=seqlen,
            two=2,
        ).contiguous()
    return rearrange(
        columns, "(b l) dstate -> b dstate l", l=seqlen
    ).contiguous()


def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
    """
    Reference implementation of the Mamba inner function.

    Composes ``causal_conv1d_fn``, the ``x_proj`` / ``delta_proj`` projections,
    ``selective_scan_fn`` and the output projection without the fused autograd
    function.

    Args:
        xz (torch.Tensor): Input tensor of shape ``(batch, 2 * dim, seqlen)``;
            it is split in half along axis 1 into the conv input ``x`` and the
            gate ``z``.
        conv1d_weight (torch.Tensor): Conv1d weights of shape
            ``(dim, 1, kernel_size)``.
        conv1d_bias (torch.Tensor | None): Conv1d biases of shape ``(dim,)``.
        x_proj_weight (torch.Tensor): Projection weights for delta, B, C.
            Shape ``(delta_rank + 2*dstate, dim)``.
        delta_proj_weight (torch.Tensor): Projection weights for delta.
            Shape ``(dim, delta_rank)``.
        out_proj_weight (torch.Tensor): Output projection weights.
            Shape ``(d_model, dim)``.
        out_proj_bias (torch.Tensor | None): Output projection biases.
            Shape ``(d_model,)``.
        A (torch.Tensor): State matrix A. Shape ``(dim, dstate)``.
        B (torch.Tensor, optional): State matrix B. When None it is read from
            the ``x_proj`` output. Defaults to None.
        C (torch.Tensor, optional): State matrix C. When None it is read from
            the ``x_proj`` output. Defaults to None.
        D (torch.Tensor, optional): Skip connection matrix D. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta. Defaults to None.
        B_proj_bias (torch.Tensor, optional): Bias for B projection.
            Defaults to None.
        C_proj_bias (torch.Tensor, optional): Bias for C projection.
            Defaults to None.
        delta_softplus (bool, optional): Whether to apply softplus to delta;
            forwarded to ``selective_scan_fn``. Defaults to True.

    Returns:
        torch.Tensor: The projected output tensor, laid out as
        ``(batch, seqlen, d_model)``.
    """
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    is_complex = A.is_complex()
    d_state = A.shape[-1] * (2 if is_complex else 1)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x,
        rearrange(conv1d_weight, "d 1 w -> d w"),
        conv1d_bias,
        activation="silu",
    )
    # We're being very careful here about the layout, to avoid extra
    # transposes. We want delta to have d as the slowest moving dimension and
    # L as the fastest moving dimension, since those are what the ssm_scan
    # kernel expects.
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = _ref_project_ssm_input(
            x_dbl[:, delta_rank : delta_rank + d_state],
            B_proj_bias,
            L,
            is_complex,
        )
    if C is None:  # variable C
        C = _ref_project_ssm_input(
            x_dbl[:, -d_state:], C_proj_bias, L, is_complex
        )
    # Honor the caller's delta_softplus instead of always forcing True.
    y = selective_scan_fn(
        x,
        delta,
        A,
        B,
        C,
        D,
        z=z,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
    )
    return F.linear(
        rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias
    )