"""
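Mamba selective-scan (SSM) operations, modified from
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py

This version's selective scan additionally accepts an optional ``deltaA``
(a separate time step used only to discretize ``A``) and a
``discretization`` argument; the pure-PyTorch reference implements
``"mamba"``, ``"zoh"``, ``"bilinear"`` and ``"dirac"``.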
"""
import selective_scan_cuda
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn
from causal_conv1d.cpp_functions import (
causal_conv1d_bwd_function,
causal_conv1d_fwd_function,
causal_conv1d_update_function,
)
except ImportError:
causal_conv1d_fn = None
causal_conv1d_fwd_function = None
causal_conv1d_bwd_function = None
causal_conv1d_update_function = None
from lrnnx.ops.torch import custom_bwd, custom_fwd
from lrnnx.ops.triton.layer_norm import _layer_norm_fwd
class SelectiveScanFn(torch.autograd.Function):
    """Autograd function for the Mamba Selective Scan CUDA kernel.

    Wraps ``selective_scan_cuda.fwd`` / ``selective_scan_cuda.bwd``. Besides
    the usual Mamba inputs, the kernel here also receives an optional
    ``deltaA`` tensor and a ``discretization`` selector.
    """

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        deltaA=None,
        delta_softplus=False,
        return_last_state=False,
        discretization=None,
    ):
        """
        Forward pass of the selective scan.

        Args:
            ctx (Any): Autograd context.
            u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)``.
            delta (torch.Tensor): Delta tensor of shape ``(batch, dim, seqlen)``.
            A (torch.Tensor): State matrix A of shape ``(dim, dstate)``.
            B (torch.Tensor): Input matrix B of shape ``(dim, dstate)``,
                ``(batch, dstate, seqlen)`` or ``(batch, groups, dstate, seqlen)``.
                3-D inputs are lifted to a single group before the kernel call.
            C (torch.Tensor): Output matrix C; same shape conventions as ``B``.
            D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)``. Defaults to None.
            z (torch.Tensor, optional): Gating tensor of shape ``(batch, dim, seqlen)``. Defaults to None.
            delta_bias (torch.Tensor, optional): Bias for delta of shape ``(dim,)``. Defaults to None.
            deltaA (torch.Tensor, optional): Separate delta used for ``A`` of shape
                ``(batch, dim, seqlen)``. Defaults to None.
            delta_softplus (bool, optional): Whether to apply softplus to delta. Defaults to False.
            return_last_state (bool, optional): Whether to return the final state. Defaults to False.
            discretization (str, optional): Discretization method forwarded to the kernel.
                Defaults to None.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The output (the
            ``z``-gated output when ``z`` is given), and optionally the last
            state of shape ``(batch, dim, dstate)``.
        """
        # Use is_contiguous() instead of a stride(-1) != 1 check: a tensor can
        # have stride(-1) == 1 but still be non-contiguous in other dimensions.
        if not u.is_contiguous():
            u = u.contiguous()
        if not delta.is_contiguous():
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if not B.is_contiguous():
            B = B.contiguous()
        if not C.is_contiguous():
            C = C.contiguous()
        if z is not None and not z.is_contiguous():
            z = z.contiguous()
        if deltaA is not None and not deltaA.is_contiguous():
            deltaA = deltaA.contiguous()

        # 3-D input-dependent B/C are given a singleton group axis for the
        # kernel; backward squeezes the matching gradients back.
        ctx.squeeze_B = B.dim() == 3
        if ctx.squeeze_B:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
        ctx.squeeze_C = C.dim() == 3
        if ctx.squeeze_C:
            C = rearrange(C, "b dstate l -> b 1 dstate l")

        out, x, *rest = selective_scan_cuda.fwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            deltaA,
            delta_softplus,
            discretization,
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        ctx.has_deltaA = deltaA is not None
        ctx.disc_method = discretization

        # NOTE(review): assumes the odd entries of the last chunk of the kernel
        # intermediates hold the final state -- TODO confirm against the kernel.
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        is_variable_B = B.dim() >= 3
        if return_last_state and not is_variable_B:
            # For an input-independent B the state read back above is scaled
            # by B here (presumably the kernel folds a constant B into C, so
            # its scanned state excludes B). B is (dim, dstate) and broadcasts
            # against (batch, dim, dstate).
            last_state = last_state * B

        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, deltaA, x)
            return out if not return_last_state else (out, last_state)

        ctx.save_for_backward(
            u, delta, A, B, C, D, z, delta_bias, deltaA, x, out
        )
        out_z = rest[0]
        return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        """
        Backward pass for the selective scan.

        Args:
            ctx (Any): Autograd context.
            dout (torch.Tensor): Gradient of the output tensor.
            *args: Gradients of extra outputs (the last state); ignored.

        Returns:
            tuple: Gradients for ``(u, delta, A, B, C, D, z, delta_bias,
            deltaA)`` followed by ``None`` for ``delta_softplus``,
            ``return_last_state`` and ``discretization``.
        """
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, deltaA, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, deltaA, x, out = (
                ctx.saved_tensors
            )
        # Same contiguity policy as forward (stride(-1) == 1 is not enough).
        if not dout.is_contiguous():
            dout = dout.contiguous()
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = (
            selective_scan_cuda.bwd(
                u,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                deltaA,
                dout,
                x,
                out,
                None,  # no pre-allocated dz buffer
                ctx.delta_softplus,
                False,  # option to recompute out_z, not used here
                ctx.disc_method,
            )
        )
        # Optional trailing kernel outputs, in this order:
        #   [ddeltaA] -- present only if deltaA was given
        #   [dz]      -- present only if z was given
        rest_idx = 0
        ddeltaA = None
        dz = None
        if ctx.has_deltaA:
            ddeltaA = rest[rest_idx] if rest_idx < len(rest) else None
            rest_idx += 1
        if ctx.has_z:
            dz = rest[rest_idx] if rest_idx < len(rest) else None
        if ctx.squeeze_B:
            dB = dB.squeeze(1)
        if ctx.squeeze_C:
            dC = dC.squeeze(1)
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            ddeltaA if deltaA is not None else None,
            None,  # delta_softplus
            None,  # return_last_state
            None,  # discretization
        )
def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    deltaA=None,
    delta_softplus=False,
    return_last_state=False,
    discretization="mamba",
):
    """
    Run the CUDA selective scan with autograd support.

    Thin wrapper around ``SelectiveScanFn.apply``.

    Args:
        u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)``.
        delta (torch.Tensor): Delta tensor of shape ``(batch, dim, seqlen)``.
        A (torch.Tensor): State matrix A of shape ``(dim, dstate)``.
        B (torch.Tensor): Input matrix B of shape ``(dim, dstate)`` or
            ``(batch, dstate, seqlen)``; a 4-D grouped layout is passed through
            to the kernel unchanged.
        C (torch.Tensor): Output matrix C; same shape conventions as ``B``.
        D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)``. Defaults to None.
        z (torch.Tensor, optional): Gating tensor of shape ``(batch, dim, seqlen)``. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta of shape ``(dim,)``. Defaults to None.
        deltaA (torch.Tensor, optional): Separate delta used for ``A`` of shape
            ``(batch, dim, seqlen)``. Defaults to None.
        delta_softplus (bool, optional): Whether to apply softplus to delta. Defaults to False.
        return_last_state (bool, optional): If True, returns ``(out, last_state)``
            where ``last_state`` has shape ``(batch, dim, dstate)``. No gradient
            flows back through ``last_state``. Defaults to False.
        discretization (str, optional): Discretization method. Defaults to "mamba".

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The output tensor,
        and optionally the last state.
    """
    # autograd.Function.apply takes positional arguments only, so this tuple
    # must follow SelectiveScanFn.forward's parameter order exactly.
    scan_inputs = (
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        deltaA,
        delta_softplus,
        return_last_state,
        discretization,
    )
    return SelectiveScanFn.apply(*scan_inputs)
def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    deltaA=None,
    delta_softplus=False,
    return_last_state=False,
    discretization="mamba",
):
    """
    Reference (pure PyTorch) implementation of the selective scan.

    Discretizes ``(A, B)`` for every time step, then runs the recurrence
    ``x_t = A_bar_t * x_{t-1} + B_bar_t * u_t`` with a Python loop over the
    sequence and reads out ``y_t = sum_n x_t[..., n] * C_t[..., n]``. The scan
    is computed in float32 (complex when ``A`` is complex) and the output is
    cast back to ``u.dtype``.

    Args:
        u (torch.Tensor): Input tensor of shape ``(batch, dim, seqlen)``.
        delta (torch.Tensor): Delta tensor of shape ``(batch, dim, seqlen)``.
        A (torch.Tensor): State matrix A of shape ``(dim, dstate)``.
        B (torch.Tensor): Matrix B. Input-independent with shape
            ``(dim, dstate)``, or input-dependent with shape
            ``(batch, dstate, seqlen)`` or ``(batch, groups, dstate, seqlen)``
            (each group shared by ``dim // groups`` consecutive channels).
            When ``A`` is complex, an input-dependent B stores real and
            imaginary parts interleaved along the last axis (``2 * seqlen``).
        C (torch.Tensor): Matrix C; same shape conventions as ``B``.
        D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)``. Defaults to None.
        z (torch.Tensor, optional): Gating tensor of shape ``(batch, dim, seqlen)``;
            the output is multiplied by ``silu(z)``. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta of shape ``(dim,)``. Defaults to None.
        deltaA (torch.Tensor, optional): Separate delta of shape
            ``(batch, dim, seqlen)`` used only to discretize ``A``;
            ``delta_bias`` and ``delta_softplus`` are not applied to it.
            Defaults to None (use ``delta``).
        delta_softplus (bool, optional): Whether to apply softplus to delta. Defaults to False.
        return_last_state (bool, optional): Whether to return the final state. Defaults to False.
        discretization (str, optional): One of ``"mamba"``, ``"zoh"``,
            ``"bilinear"`` or ``"dirac"``. Defaults to "mamba".

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The output tensor of
        shape ``(batch, dim, seqlen)``, and optionally the last state of shape
        ``(batch, dim, dstate)`` (the zero initial state when ``seqlen == 0``).

    Raises:
        ValueError: If ``discretization`` is not a supported method.
    """
    # Validate before doing any tensor work.
    if discretization not in ("mamba", "zoh", "bilinear", "dirac"):
        raise ValueError(f"Unknown discretization method: {discretization}")

    # 1. Prepare inputs.
    dtype_in = u.dtype
    u_f = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()  # (dim,) -> (dim, 1)
    if delta_softplus:
        delta = F.softplus(delta)

    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    seqlen = u.shape[2]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        # Input-dependent B/C hold interleaved (real, imag) pairs on the last axis.
        if is_variable_B:
            B = torch.view_as_complex(B.float().unflatten(-1, (-1, 2)))
        if is_variable_C:
            C = torch.view_as_complex(C.float().unflatten(-1, (-1, 2)))
    else:
        B = B.float()
        C = C.float()
    # Expand grouped (batch, groups, dstate, L) layouts to one slice per channel.
    if B.dim() == 4:
        B = B.repeat_interleave(dim // B.shape[1], dim=1)  # (b, d, n, L)
    if C.dim() == 4:
        C = C.repeat_interleave(dim // C.shape[1], dim=1)  # (b, d, n, L)

    # 2. Discretization; deltaA_disc and deltaB_u are (batch, dim, L, dstate).
    dt = delta.unsqueeze(-1)  # (b, d, L, 1)
    dt_A = deltaA.unsqueeze(-1) if deltaA is not None else dt  # (b, d, L, 1)
    u_bdl1 = u_f.unsqueeze(-1)  # (b, d, L, 1)
    A_1d1n = A.unsqueeze(0).unsqueeze(2)  # (1, d, 1, n)
    if not is_variable_B:
        B_bc = B.unsqueeze(0).unsqueeze(2)  # (1, d, 1, n)
    elif B.dim() == 3:
        B_bc = B.unsqueeze(1).transpose(-1, -2)  # (b, 1, L, n)
    else:
        B_bc = B.transpose(-1, -2)  # (b, d, L, n)

    if discretization == "mamba":
        # A_bar = exp(dt_A * A); B_bar = dt * B (Euler-style input term).
        deltaA_disc = torch.exp(dt_A * A_1d1n)
        deltaB_u = dt * B_bc * u_bdl1
    elif discretization == "zoh":
        # A_bar = exp(dt_A * A); B_bar = (exp(dt * A) - 1) / A * B.
        deltaA_disc = torch.exp(dt_A * A_1d1n)
        dB = torch.expm1(dt * A_1d1n) / A_1d1n
        deltaB_u = dB * u_bdl1 * B_bc
    elif discretization == "bilinear":
        # A_bar = (1 + dt_A*A/2) / (1 - dt_A*A/2); B_bar = dt / (1 - dt*A/2) * B.
        a_half_A = 0.5 * dt_A * A_1d1n
        deltaA_disc = (1.0 + a_half_A) / (1.0 - a_half_A)
        a_half_B = 0.5 * dt * A_1d1n
        dB = dt / (1.0 - a_half_B)
        deltaB_u = dB * u_bdl1 * B_bc
    else:  # "dirac"
        # A_bar = exp(dt_A * A); B_bar = B (no time-step scaling).
        deltaA_disc = torch.exp(dt_A * A_1d1n)
        deltaB_u = u_bdl1 * B_bc

    # 3. Sequential scan.
    x = A.new_zeros((batch, dim, dstate))  # hidden state (b, d, n)
    ys = []
    for i in range(seqlen):
        x = deltaA_disc[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        elif C.dim() == 3:
            y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
        else:
            y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    last_state = x
    if ys:
        y = torch.stack(ys, dim=2)  # (b, d, L)
    else:
        # Empty sequence: torch.stack([]) would raise; keep a (b, d, 0) output.
        y = u_f.new_zeros((batch, dim, 0))

    # 4. Skip connection, gating and dtype restore.
    out = y if D is None else y + u * D.unsqueeze(-1)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    """
    Normalize the rows of ``x`` with the Triton layer-norm forward kernel.

    Uses RMS normalization by default; pass ``is_rms_norm=False`` for a
    regular layer norm.

    Args:
        x (torch.Tensor): Input tensor of shape ``(batch * seqlen, dim)``.
        weight (torch.Tensor): Weight tensor of shape ``(dim,)``.
        bias (torch.Tensor | None): Bias tensor of shape ``(dim,)`` or None.
        eps (float, optional): Epsilon for numerical stability. Defaults to 1e-6.
        is_rms_norm (bool, optional): Whether to use RMS norm (vs Layer norm). Defaults to True.

    Returns:
        torch.Tensor: The first output of ``_layer_norm_fwd``, i.e. the
        normalized tensor of shape ``(batch * seqlen, dim)``.
    """
    # Only copy x when its feature axis is strided; weight/bias always go
    # through .contiguous() (a no-op when already contiguous).
    x_in = x if x.stride(-1) == 1 else x.contiguous()
    weight_in = weight.contiguous()
    bias_in = None if bias is None else bias.contiguous()
    fwd_outputs = _layer_norm_fwd(
        x_in,
        weight_in,
        bias_in,
        eps,
        None,
        residual_dtype=None,
        is_rms_norm=is_rms_norm,
    )
    normalized = fwd_outputs[0]  # (b l) d
    return normalized
class MambaInnerFn(torch.autograd.Function):
    """Autograd function for the fused Mamba inner loop.

    Chains causal conv1d -> ``x_proj`` -> ``delta_proj`` -> selective scan
    (gated by ``z``) -> ``out_proj`` in one autograd node. With
    ``checkpoint_lvl == 1`` the conv output and ``delta`` are not saved and
    are recomputed in ``backward``.
    """

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        Forward pass of the fused Mamba inner function.

        Args:
            ctx (Any): Autograd context.
            xz (torch.Tensor): ``x`` and ``z`` concatenated along dim 1, shape
                ``(batch, 2 * dim, seqlen)``; split in half with ``chunk``.
            conv1d_weight (torch.Tensor): Conv1d weights of shape ``(dim, 1, kernel_size)``.
            conv1d_bias (torch.Tensor | None): Conv1d biases of shape ``(dim,)``.
            x_proj_weight (torch.Tensor): Projection weights for delta, B, C.
                Shape ``(delta_rank + 2*dstate, dim)``.
            delta_proj_weight (torch.Tensor): Projection weights for delta. Shape ``(dim, delta_rank)``.
            out_proj_weight (torch.Tensor): Output projection weights. Shape ``(d_model, dim)``.
            out_proj_bias (torch.Tensor | None): Output projection biases. Shape ``(d_model,)``.
            A (torch.Tensor): State matrix A. Shape ``(dim, dstate)``.
            B (torch.Tensor, optional): Fixed B; if None, B is taken from ``x_proj``. Defaults to None.
            C (torch.Tensor, optional): Fixed C; if None, C is taken from ``x_proj``. Defaults to None.
            D (torch.Tensor, optional): Skip connection vector D. Defaults to None.
            delta_bias (torch.Tensor, optional): Bias for delta. Defaults to None.
            B_proj_bias (torch.Tensor, optional): Bias for the B projection. Defaults to None.
            C_proj_bias (torch.Tensor, optional): Bias for the C projection. Defaults to None.
            delta_softplus (bool, optional): Whether to apply softplus to delta. Defaults to True.
            checkpoint_lvl (int, optional): 0 saves conv output and delta; 1
                recomputes them in backward. Defaults to 1.
            b_rms_weight (torch.Tensor, optional): RMS norm weights for B. Defaults to None.
            c_rms_weight (torch.Tensor, optional): RMS norm weights for C. Defaults to None.
            dt_rms_weight (torch.Tensor, optional): RMS norm weights for dt. Defaults to None.
            b_c_dt_rms_eps (float, optional): RMS norm epsilon. Defaults to 1e-6.

        Returns:
            torch.Tensor: The projected output of shape ``(batch, seqlen, d_model)``.
        """
        assert (
            causal_conv1d_fwd_function is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        # A complex A uses interleaved (real, imag) columns in x_proj.
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_gpu_dtype()
            x_proj_weight = x_proj_weight.to(dtype=autocast_dtype)
            delta_proj_weight = delta_proj_weight.to(dtype=autocast_dtype)
            out_proj_weight = out_proj_weight.to(dtype=autocast_dtype)
            out_proj_bias = (
                out_proj_bias.to(dtype=autocast_dtype)
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = (
            conv1d_bias.contiguous() if conv1d_bias is not None else None
        )
        conv1d_out = causal_conv1d_fwd_function(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )  # (bl d)
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(),
            "d (b l) -> b d l",
            l=L,
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # input-dependent B
            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(
                    B, "(b l) dstate -> b 1 dstate l", l=L
                ).contiguous()
            else:
                B = rearrange(
                    B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # input-dependent C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(
                    C, "(b l) dstate -> b 1 dstate l", l=L
                ).contiguous()
            else:
                C = rearrange(
                    C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        # Optional RMS norms; these assume the input-dependent "b 1 dstate l"
        # layout produced above.
        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(
                B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(
                C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
        # Argument order must match this file's kernel signature (see
        # SelectiveScanFn): (..., delta_bias, deltaA, delta_softplus,
        # discretization). The fused path discretizes A with delta itself
        # and uses the same "mamba" default as selective_scan_fn.
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            None,  # deltaA
            delta_softplus,
            "mamba",  # discretization
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if checkpoint_lvl >= 1:
            # Will recompute conv1d_out and delta in the backward pass.
            conv1d_out, delta = None, None
        # NOTE: B and C are saved *after* their optional RMS norm.
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """
        Backward pass for the fused Mamba inner function.

        The optional RMS norms on B, C and delta are treated as identity here:
        their gradients are taken w.r.t. the normalized tensors and no
        gradient is returned for ``b_rms_weight``/``c_rms_weight``/``dt_rms_weight``.

        Args:
            ctx (Any): Autograd context.
            dout (torch.Tensor): Gradient of the output tensor. Shape ``(batch, seqlen, d_model)``.

        Returns:
            tuple: One gradient (or None) per forward input, in forward order.
        """
        # dout: (batch, seqlen, d_model)
        assert (
            causal_conv1d_fwd_function is not None
            and causal_conv1d_bwd_function is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # conv1d_out and delta were not saved; recompute them as in forward.
            conv1d_out = causal_conv1d_fwd_function(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                "d (b l) -> b d l",
                l=L,
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
        # B and C were saved after their RMS norm, so they are used as-is:
        # normalizing them again would feed the kernel different values than
        # the forward pass used.
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, 2 * dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
            selective_scan_cuda.bwd(
                conv1d_out,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                None,  # deltaA
                dout_y,
                scan_intermediates,
                out,
                dz,
                ctx.delta_softplus,
                True,  # option to recompute out_z (needed for dout_proj_weight)
                "mamba",  # discretization, same as forward
            )
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        # dout is (d_model, batch*seqlen): the bias gradient sums over tokens.
        dout_proj_bias = (
            dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        )
        dD = dD if D is not None else None
        # zeros (not empty): columns of x_dbl that feed a fixed B or C get no
        # gradient below and must not hold uninitialized memory.
        dx_dbl = torch.zeros_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum(
            "dB,Br->dr", ddelta, x_dbl[:, :delta_rank]
        )
        dx_dbl[:, :delta_rank] = torch.einsum(
            "dB,dr->Br", ddelta, delta_proj_weight
        )
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            dB_proj_bias,
            dC_proj_bias,
            # The 6 Nones are delta_softplus, checkpoint_lvl, b_rms_weight,
            # c_rms_weight, dt_rms_weight, b_c_dt_rms_eps.
            None,
            None,
            None,
            None,
            None,
            None,
        )
def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    """
    Run the fused Mamba inner function with autograd support.

    Thin wrapper around ``MambaInnerFn.apply``.

    Args:
        xz (torch.Tensor): ``x`` and ``z`` concatenated along dim 1, shape
            ``(batch, 2 * dim, seqlen)``.
        conv1d_weight (torch.Tensor): Conv1d weights of shape ``(dim, 1, kernel_size)``.
        conv1d_bias (torch.Tensor | None): Conv1d biases of shape ``(dim,)``.
        x_proj_weight (torch.Tensor): Projection weights for delta, B, C.
            Shape ``(delta_rank + 2*dstate, dim)``.
        delta_proj_weight (torch.Tensor): Projection weights for delta. Shape ``(dim, delta_rank)``.
        out_proj_weight (torch.Tensor): Output projection weights. Shape ``(d_model, dim)``.
        out_proj_bias (torch.Tensor | None): Output projection biases. Shape ``(d_model,)``.
        A (torch.Tensor): State matrix A. Shape ``(dim, dstate)``.
        B (torch.Tensor, optional): Fixed B; if None, B comes from ``x_proj``. Defaults to None.
        C (torch.Tensor, optional): Fixed C; if None, C comes from ``x_proj``. Defaults to None.
        D (torch.Tensor, optional): Skip connection vector D. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta. Defaults to None.
        B_proj_bias (torch.Tensor, optional): Bias for the B projection. Defaults to None.
        C_proj_bias (torch.Tensor, optional): Bias for the C projection. Defaults to None.
        delta_softplus (bool, optional): Whether to apply softplus to delta. Defaults to True.
        checkpoint_lvl (int, optional): Gradient checkpointing level (0 or 1). Defaults to 1.
        b_rms_weight (torch.Tensor, optional): RMS norm weights for B. Defaults to None.
        c_rms_weight (torch.Tensor, optional): RMS norm weights for C. Defaults to None.
        dt_rms_weight (torch.Tensor, optional): RMS norm weights for dt. Defaults to None.
        b_c_dt_rms_eps (float, optional): RMS norm epsilon. Defaults to 1e-6.

    Returns:
        torch.Tensor: The projected output tensor.
    """
    # autograd.Function.apply is positional-only; keep MambaInnerFn.forward's order.
    fused_inputs = (
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )
    return MambaInnerFn.apply(*fused_inputs)
def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
    """
    Unfused reference for ``mamba_inner_fn``.

    Composes ``causal_conv1d_fn``, ``F.linear`` and ``selective_scan_fn`` step
    by step instead of using the fused autograd function. It still relies on
    the causal-conv1d and selective-scan CUDA extensions. The RMS-norm
    options of ``mamba_inner_fn`` are not supported here.

    Args:
        xz (torch.Tensor): ``x`` and ``z`` concatenated along dim 1, shape
            ``(batch, 2 * dim, seqlen)``.
        conv1d_weight (torch.Tensor): Conv1d weights of shape ``(dim, 1, kernel_size)``.
        conv1d_bias (torch.Tensor | None): Conv1d biases of shape ``(dim,)``.
        x_proj_weight (torch.Tensor): Projection weights for delta, B, C.
            Shape ``(delta_rank + 2*dstate, dim)``.
        delta_proj_weight (torch.Tensor): Projection weights for delta. Shape ``(dim, delta_rank)``.
        out_proj_weight (torch.Tensor): Output projection weights. Shape ``(d_model, dim)``.
        out_proj_bias (torch.Tensor | None): Output projection biases. Shape ``(d_model,)``.
        A (torch.Tensor): State matrix A. Shape ``(dim, dstate)``.
        B (torch.Tensor, optional): Fixed B; if None, B comes from ``x_proj``. Defaults to None.
        C (torch.Tensor, optional): Fixed C; if None, C comes from ``x_proj``. Defaults to None.
        D (torch.Tensor, optional): Skip connection vector D. Defaults to None.
        delta_bias (torch.Tensor, optional): Bias for delta. Defaults to None.
        B_proj_bias (torch.Tensor, optional): Bias for the B projection. Defaults to None.
        C_proj_bias (torch.Tensor, optional): Bias for the C projection. Defaults to None.
        delta_softplus (bool, optional): Whether to apply softplus to delta. Defaults to True.

    Returns:
        torch.Tensor: The projected output of shape ``(batch, seqlen, d_model)``.
    """
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    # A complex A uses interleaved (real, imag) columns in x_proj.
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x,
        rearrange(conv1d_weight, "d 1 w -> d w"),
        conv1d_bias,
        activation="silu",
    )
    # Keep delta as (b, d, l) with L fastest-moving, the layout the scan kernel expects.
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # input-dependent B
        B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(
                B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    if C is None:  # input-dependent C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(
                C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    # Honour the caller's delta_softplus (previously hard-coded to True).
    y = selective_scan_fn(
        x,
        delta,
        A,
        B,
        C,
        D,
        z=z,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
    )
    return F.linear(
        rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias
    )