lrnnx.ops.selective_scan module
The original Mamba SSM selective scan operation, adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
- class SelectiveScanFn
Bases: torch.autograd.Function
Autograd function for the Mamba selective scan CUDA kernel.
- static forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, deltaA=None, delta_softplus=False, return_last_state=False, discretization=None)
Forward pass of the selective scan.
- Parameters:
ctx (Any) – Autograd context.
u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).
delta (torch.Tensor) – Delta tensor of shape (batch, dim, seqlen).
A (torch.Tensor) – State matrix A of shape (dim, dstate).
B (torch.Tensor) – Input matrix B of shape (batch, dstate, seqlen) or (dim, dstate).
C (torch.Tensor) – Output matrix C of shape (batch, dstate, seqlen) or (dim, dstate).
D (torch.Tensor, optional) – Skip connection vector of shape (dim,). Defaults to None.
z (torch.Tensor, optional) – Gating tensor of shape (batch, dim, seqlen). Defaults to None.
delta_bias (torch.Tensor, optional) – Bias for delta, of shape (dim,). Defaults to None.
deltaA (torch.Tensor, optional) – Asymmetric delta for A, of shape (batch, dim, seqlen). Defaults to None.
delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to False.
return_last_state (bool, optional) – Whether to return the final state. Defaults to False.
discretization (str, optional) – Discretization method to pass to the kernel. Defaults to None.
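- Returns:
The output tensor of shape (batch, dim, seqlen) and, if return_last_state is True, the last state of shape (batch, dim, dstate).
- Return type:
torch.Tensor or tuple[torch.Tensor, torch.Tensor]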
- static backward(ctx, dout, *args)
Backward pass for the selective scan.
- Parameters:
ctx (Any) – Autograd context.
dout (torch.Tensor) – Gradient of the output tensor.
*args – Gradients of any additional outputs (e.g., last_state). These are ignored.
- Returns:
Gradients with respect to the inputs.
- Return type:
tuple of torch.Tensor or None (one entry per forward input)
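Autograd calls `backward` with one gradient per output of `forward`, so a Function that returns `(out, last_state)` must accept the extra gradient even when it discards it. The sketch below shows that pattern on a hypothetical toy Function (`ScaleWithState`, not part of lrnnx); it assumes `SelectiveScanFn` handles its last state in a similar way, as the `*args` description suggests.

```python
import torch


class ScaleWithState(torch.autograd.Function):
    """Hypothetical toy Function: out = u * w, plus a non-differentiable 'last state'."""

    @staticmethod
    def forward(ctx, u, w):
        # u: (batch, dim, seqlen), w: (dim,)
        ctx.save_for_backward(u, w)
        out = u * w[:, None]
        # Clone so the extra output is not a view of `out`.
        last_state = out[..., -1].clone()
        # Autograd will not track gradients through this output.
        ctx.mark_non_differentiable(last_state)
        return out, last_state

    @staticmethod
    def backward(ctx, dout, *args):
        # args[0] is the gradient for last_state; it is ignored here.
        u, w = ctx.saved_tensors
        du = dout * w[:, None]
        dw = (dout * u).sum(dim=(0, 2))
        # Return one gradient per forward input (u, w).
        return du, dw
```

Only `out` takes part in gradient computation; `last_state` comes back with `requires_grad=False`.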
- selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, deltaA=None, delta_softplus=False, return_last_state=False, discretization='mamba')
Apply the CUDA selective scan function.
- Parameters:
u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).
delta (torch.Tensor) – Delta tensor of shape (batch, dim, seqlen).
A (torch.Tensor) – State matrix A of shape (dim, dstate).
B (torch.Tensor) – Input matrix B of shape (batch, dstate, seqlen) or (dim, dstate).
C (torch.Tensor) – Output matrix C of shape (batch, dstate, seqlen) or (dim, dstate).
D (torch.Tensor, optional) – Skip connection vector of shape (dim,). Defaults to None.
z (torch.Tensor, optional) – Gating tensor of shape (batch, dim, seqlen). Defaults to None.
delta_bias (torch.Tensor, optional) – Bias for delta, of shape (dim,). Defaults to None.
deltaA (torch.Tensor, optional) – Asymmetric delta for A, of shape (batch, dim, seqlen). Defaults to None.
delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to False.
return_last_state (bool, optional) – If True, returns (out, last_state), where last_state has shape (batch, dim, dstate). The gradient with respect to last_state is not propagated in the backward pass. Defaults to False.
discretization (str, optional) – Discretization method. Defaults to “mamba”.
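- Returns:
The output tensor of shape (batch, dim, seqlen) and, if return_last_state is True, the last state of shape (batch, dim, dstate).
- Return type:
torch.Tensor or tuple[torch.Tensor, torch.Tensor]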
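A minimal sketch of the delta preprocessing controlled by `delta_bias` and `delta_softplus`. It assumes the same order as the upstream Mamba reference implementation (per-channel bias added first, then softplus); `preprocess_delta` is a hypothetical helper, not an lrnnx function.

```python
import torch
import torch.nn.functional as F


def preprocess_delta(delta, delta_bias=None, delta_softplus=False):
    """Hypothetical helper sketching how delta_bias and delta_softplus act on delta.

    delta: (batch, dim, seqlen); delta_bias: (dim,)
    """
    delta = delta.float()
    if delta_bias is not None:
        # Broadcast the per-channel bias over batch and sequence positions.
        delta = delta + delta_bias.float()[:, None]
    if delta_softplus:
        # softplus maps the step size to a strictly positive value.
        delta = F.softplus(delta)
    return delta
```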
- selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, deltaA=None, delta_softplus=False, return_last_state=False, discretization='mamba')
Reference (pure PyTorch) implementation of the selective scan.
- Parameters:
u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).
delta (torch.Tensor) – Delta tensor of shape (batch, dim, seqlen).
A (torch.Tensor) – State matrix A of shape (dim, dstate).
B (torch.Tensor) – Input matrix B, of shape (dim, dstate), (batch, dstate, seqlen), or (batch, groups, dstate, seqlen).
C (torch.Tensor) – Output matrix C, of shape (dim, dstate), (batch, dstate, seqlen), or (batch, groups, dstate, seqlen).
D (torch.Tensor, optional) – Skip connection vector of shape (dim,). Defaults to None.
z (torch.Tensor, optional) – Gating tensor of shape (batch, dim, seqlen). Defaults to None.
delta_bias (torch.Tensor, optional) – Bias for delta, of shape (dim,). Defaults to None.
deltaA (torch.Tensor, optional) – Asymmetric delta for A, of shape (batch, dim, seqlen). Defaults to None.
delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to False.
return_last_state (bool, optional) – Whether to return the final state. Defaults to False.
discretization (str, optional) – Discretization method to use. Defaults to “mamba”.
- Returns:
The output tensor of shape (batch, dim, seqlen) and, if return_last_state is True, the last state of shape (batch, dim, dstate).
- Return type:
torch.Tensor or tuple[torch.Tensor, torch.Tensor]
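The recurrence a reference scan computes can be sketched as a plain sequential loop. This assumes the "mamba" discretization matches upstream Mamba (state transition exp(delta * A), input term delta * B * u) and covers only real A, 3-D B and C of shape (batch, dstate, seqlen), and deltaA=None. `naive_selective_scan` is a hypothetical name; delta is assumed to be already biased and softplus-ed.

```python
import torch
import torch.nn.functional as F


def naive_selective_scan(u, delta, A, B, C, D=None, z=None, return_last_state=False):
    """Hypothetical sequential selective scan (assumed 'mamba' discretization).

    u, delta, z: (batch, dim, seqlen); A: (dim, dstate)
    B, C: (batch, dstate, seqlen); D: (dim,)
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    # Discretized transition exp(delta * A): (batch, dim, seqlen, dstate)
    dA = torch.exp(delta.unsqueeze(-1) * A[None, :, None, :])
    # Discretized input delta * B * u: (batch, dim, seqlen, dstate)
    dBu = (delta * u).unsqueeze(-1) * B.transpose(1, 2).unsqueeze(1)
    h = u.new_zeros(batch, dim, dstate)
    ys = []
    for t in range(seqlen):
        # State update: h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t
        h = dA[:, :, t] * h + dBu[:, :, t]
        # Readout: y_t = <h_t, C_t> over dstate, giving (batch, dim)
        ys.append((h * C[:, None, :, t]).sum(-1))
    y = torch.stack(ys, dim=2)  # (batch, dim, seqlen)
    if D is not None:
        y = y + D[:, None] * u  # skip connection
    if z is not None:
        y = y * F.silu(z)  # gating
    return (y, h) if return_last_state else y
```

With A = 0, delta = 1, and B = C = 1 the state simply accumulates u, so the output reduces to a running sum over the sequence.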
- rms_norm_forward(x, weight, bias, eps=1e-06, is_rms_norm=True)
Forward pass for RMS normalization (or layer normalization when is_rms_norm is False).
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch * seqlen, dim).
weight (torch.Tensor) – Weight tensor of shape (dim,).
bias (torch.Tensor | None) – Bias tensor of shape (dim,), or None.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.
is_rms_norm (bool, optional) – If True, use RMS normalization; otherwise use layer normalization. Defaults to True.
- Returns:
Normalized output tensor of shape (batch * seqlen, dim).
- Return type:
torch.Tensor
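A plain-PyTorch sketch of what the two modes compute over the last dimension. `norm_ref` is a hypothetical helper; computing in float32 and casting back is an assumption, not necessarily what the lrnnx kernel does.

```python
import torch


def norm_ref(x, weight, bias=None, eps=1e-6, is_rms_norm=True):
    """Hypothetical reference for RMSNorm (default) or LayerNorm over the last dim.

    x: (batch * seqlen, dim); weight, bias: (dim,)
    """
    x_f = x.float()
    if is_rms_norm:
        # RMSNorm: scale by 1 / sqrt(mean(x^2) + eps); no mean subtraction.
        out = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    else:
        # LayerNorm: subtract the mean, then divide by the (biased) standard deviation.
        mean = x_f.mean(dim=-1, keepdim=True)
        var = (x_f - mean).pow(2).mean(dim=-1, keepdim=True)
        out = (x_f - mean) * torch.rsqrt(var + eps)
    out = out * weight.float()
    if bias is not None:
        out = out + bias.float()
    return out.to(x.dtype)
```

The LayerNorm branch should agree with `torch.nn.functional.layer_norm` using the same weight, bias, and eps.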
- class MambaInnerFn
Bases: torch.autograd.Function
Autograd function for the fused Mamba inner loop.
- static forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-06)
Forward pass of the fused Mamba inner function.
- Parameters:
ctx (Any) – Autograd context.
xz (torch.Tensor) – Input tensor of shape (batch, 2 * dim, seqlen), holding x and z concatenated along the channel dimension.
conv1d_weight (torch.Tensor) – Conv1d weights of shape (dim, 1, kernel_size).
conv1d_bias (torch.Tensor | None) – Conv1d biases of shape (dim,).
x_proj_weight (torch.Tensor) – Projection weights for delta, B, and C, of shape (delta_rank + 2*dstate, dim).
delta_proj_weight (torch.Tensor) – Projection weights for delta, of shape (dim, delta_rank).
out_proj_weight (torch.Tensor) – Output projection weights, of shape (d_model, dim).
out_proj_bias (torch.Tensor | None) – Output projection biases, of shape (d_model,).
A (torch.Tensor) – State matrix A, of shape (dim, dstate).
B (torch.Tensor, optional) – Input matrix B. If None, B is computed from the input through x_proj_weight. Defaults to None.
C (torch.Tensor, optional) – Output matrix C. If None, C is computed from the input through x_proj_weight. Defaults to None.
D (torch.Tensor, optional) – Skip connection vector D, of shape (dim,). Defaults to None.
delta_bias (torch.Tensor, optional) – Bias for delta, of shape (dim,). Defaults to None.
B_proj_bias (torch.Tensor, optional) – Bias for the B projection. Defaults to None.
C_proj_bias (torch.Tensor, optional) – Bias for the C projection. Defaults to None.
delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to True.
checkpoint_lvl (int, optional) – Gradient checkpointing level (0 or 1). Level 1 recomputes the conv1d output and delta during the backward pass instead of storing them. Defaults to 1.
b_rms_weight (torch.Tensor, optional) – RMS norm weights for B. Defaults to None.
c_rms_weight (torch.Tensor, optional) – RMS norm weights for C. Defaults to None.
dt_rms_weight (torch.Tensor, optional) – RMS norm weights for delta (dt). Defaults to None.
b_c_dt_rms_eps (float, optional) – Epsilon for the B, C, and dt RMS norms. Defaults to 1e-6.
- Returns:
The projected output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
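The shapes of x_proj_weight and delta_proj_weight imply how one projection yields delta, B, and C. A minimal sketch, assuming the upstream Mamba row layout [delta_rank | dstate | dstate] and ignoring B_proj_bias, C_proj_bias, and the optional RMS norms; `project_delta_b_c` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def project_delta_b_c(x, x_proj_weight, delta_proj_weight, dstate):
    """Hypothetical split of the x_proj output into delta, B and C.

    x: (batch, dim, seqlen) conv output
    x_proj_weight: (delta_rank + 2*dstate, dim); delta_proj_weight: (dim, delta_rank)
    """
    delta_rank = delta_proj_weight.shape[1]
    # Project every position: (batch, seqlen, delta_rank + 2*dstate)
    x_dbl = F.linear(x.transpose(1, 2), x_proj_weight)
    dt_low, B, C = torch.split(x_dbl, [delta_rank, dstate, dstate], dim=-1)
    # Expand the low-rank delta back to dim channels: (batch, dim, seqlen)
    delta = F.linear(dt_low, delta_proj_weight).transpose(1, 2)
    # B and C as (batch, dstate, seqlen), the layout the selective scan expects
    return delta, B.transpose(1, 2), C.transpose(1, 2)
```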
- static backward(ctx, dout)
Backward pass for the fused Mamba inner function.
- Parameters:
ctx (Any) – Autograd context.
dout (torch.Tensor) – Gradient of the output tensor. Shape
(batch, seqlen, d_model).
- Returns:
Gradients with respect to inputs.
- Return type:
tuple of torch.Tensor or None (one entry per forward input)
- mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-06)
Apply the fused Mamba inner function.
- Parameters:
xz (torch.Tensor) – Input tensor of shape (batch, 2 * dim, seqlen), holding x and z concatenated along the channel dimension.
conv1d_weight (torch.Tensor) – Conv1d weights of shape (dim, 1, kernel_size).
conv1d_bias (torch.Tensor | None) – Conv1d biases of shape (dim,).
x_proj_weight (torch.Tensor) – Projection weights for delta, B, and C, of shape (delta_rank + 2*dstate, dim).
delta_proj_weight (torch.Tensor) – Projection weights for delta, of shape (dim, delta_rank).
out_proj_weight (torch.Tensor) – Output projection weights, of shape (d_model, dim).
out_proj_bias (torch.Tensor | None) – Output projection biases, of shape (d_model,).
A (torch.Tensor) – State matrix A, of shape (dim, dstate).
B (torch.Tensor, optional) – Input matrix B. If None, B is computed from the input through x_proj_weight. Defaults to None.
C (torch.Tensor, optional) – Output matrix C. If None, C is computed from the input through x_proj_weight. Defaults to None.
D (torch.Tensor, optional) – Skip connection vector D, of shape (dim,). Defaults to None.
delta_bias (torch.Tensor, optional) – Bias for delta, of shape (dim,). Defaults to None.
B_proj_bias (torch.Tensor, optional) – Bias for the B projection. Defaults to None.
C_proj_bias (torch.Tensor, optional) – Bias for the C projection. Defaults to None.
delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to True.
checkpoint_lvl (int, optional) – Gradient checkpointing level (0 or 1). Level 1 recomputes the conv1d output and delta during the backward pass instead of storing them. Defaults to 1.
b_rms_weight (torch.Tensor, optional) – RMS norm weights for B. Defaults to None.
c_rms_weight (torch.Tensor, optional) – RMS norm weights for C. Defaults to None.
dt_rms_weight (torch.Tensor, optional) – RMS norm weights for delta (dt). Defaults to None.
b_c_dt_rms_eps (float, optional) – Epsilon for the B, C, and dt RMS norms. Defaults to 1e-6.
- Returns:
The projected output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
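The conv1d_weight shape (dim, 1, kernel_size) describes a depthwise convolution applied to the x half of xz. A minimal sketch, assuming xz is split into x and z along channels as in upstream Mamba and that the convolution is causal; `causal_depthwise_conv` is a hypothetical helper, not the fused kernel.

```python
import torch
import torch.nn.functional as F


def causal_depthwise_conv(xz, conv1d_weight, conv1d_bias=None):
    """Hypothetical sketch: split xz, then apply a causal depthwise conv + SiLU to x.

    xz: (batch, 2*dim, seqlen); conv1d_weight: (dim, 1, kernel_size); conv1d_bias: (dim,)
    """
    x, z = xz.chunk(2, dim=1)  # each (batch, dim, seqlen)
    dim, _, kernel_size = conv1d_weight.shape
    seqlen = x.shape[-1]
    # groups=dim gives one filter per channel. Padding by kernel_size - 1 and keeping
    # the first seqlen outputs means position t only sees inputs at t and earlier.
    out = F.conv1d(x, conv1d_weight, conv1d_bias, padding=kernel_size - 1, groups=dim)
    return F.silu(out[..., :seqlen]), z
```

Changing the input at the last position leaves every earlier output unchanged, which is the causality property the truncation provides.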
- mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True)
Reference (pure PyTorch) implementation of the Mamba inner function.
- Parameters:
xz (torch.Tensor) – Input tensor of shape (batch, 2 * dim, seqlen), holding x and z concatenated along the channel dimension.
conv1d_weight (torch.Tensor) – Conv1d weights of shape (dim, 1, kernel_size).
conv1d_bias (torch.Tensor | None) – Conv1d biases of shape (dim,).
x_proj_weight (torch.Tensor) – Projection weights for delta, B, and C, of shape (delta_rank + 2*dstate, dim).
delta_proj_weight (torch.Tensor) – Projection weights for delta, of shape (dim, delta_rank).
out_proj_weight (torch.Tensor) – Output projection weights, of shape (d_model, dim).
out_proj_bias (torch.Tensor | None) – Output projection biases, of shape (d_model,).
A (torch.Tensor) – State matrix A, of shape (dim, dstate).
B (torch.Tensor, optional) – Input matrix B. If None, B is computed from the input through x_proj_weight. Defaults to None.
C (torch.Tensor, optional) – Output matrix C. If None, C is computed from the input through x_proj_weight. Defaults to None.
D (torch.Tensor, optional) – Skip connection vector D, of shape (dim,). Defaults to None.
delta_bias (torch.Tensor, optional) – Bias for delta, of shape (dim,). Defaults to None.
B_proj_bias (torch.Tensor, optional) – Bias for the B projection. Defaults to None.
C_proj_bias (torch.Tensor, optional) – Bias for the C projection. Defaults to None.
delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to True.
- Returns:
The projected output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
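The output layout follows from the final step: the scan result, gated by z, is moved to channels-last and projected to d_model. A minimal sketch, assuming y is the scan output before gating and that gating uses SiLU as in upstream Mamba; `gate_and_project` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def gate_and_project(y, z, out_proj_weight, out_proj_bias=None):
    """Hypothetical final step: SiLU(z) gating, then the output projection.

    y, z: (batch, dim, seqlen); out_proj_weight: (d_model, dim); out_proj_bias: (d_model,)
    """
    gated = y * F.silu(z)
    # F.linear acts on the last dim, so move channels last: (batch, seqlen, dim)
    return F.linear(gated.transpose(1, 2), out_proj_weight, out_proj_bias)  # (batch, seqlen, d_model)
```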