lrnnx.ops.selective_scan module

The original Mamba SSM scan operation, adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py

class SelectiveScanFn

Bases: Function

Autograd function for the Mamba Selective Scan CUDA kernel.

static forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, deltaA=None, delta_softplus=False, return_last_state=False, discretization=None)

Forward pass of the selective scan.

Parameters:
  • ctx (Any) – Autograd context.

  • u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).

  • delta (torch.Tensor) – Delta tensor of shape (batch, dim, seqlen).

  • A (torch.Tensor) – State matrix A of shape (dim, dstate).

  • B (torch.Tensor) – Input matrix B of shape (batch, dstate, seqlen) or (dim, dstate).

  • C (torch.Tensor) – Output matrix C of shape (batch, dstate, seqlen) or (dim, dstate).

  • D (torch.Tensor, optional) – Skip connection vector of shape (dim,). Defaults to None.

  • z (torch.Tensor, optional) – Gating tensor of shape (batch, dim, seqlen). Defaults to None.

  • delta_bias (torch.Tensor, optional) – Bias for delta of shape (dim,). Defaults to None.

  • deltaA (torch.Tensor, optional) – Asymmetric delta for A of shape (batch, dim, seqlen). Defaults to None.

  • delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to False.

  • return_last_state (bool, optional) – Whether to return the final state. Defaults to False.

  • discretization (str, optional) – Discretization method to pass to the kernel. Defaults to None.

Returns:

The output tensor, or a tuple (out, last_state) when return_last_state is True.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
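The shape constraints listed above can be checked before the kernel is called. The helper below is a hypothetical sketch and not part of lrnnx: the name `check_scan_shapes` is made up for illustration, it works on plain shape tuples (for example `tuple(t.shape)`), and it covers only the two B/C layouts documented for this function.

```python
def check_scan_shapes(u, delta, A, B, C):
    """Validate shape tuples against the documented layouts.

    Hypothetical helper, not an lrnnx API. Returns
    (batch, dim, seqlen, dstate) or raises ValueError.
    """
    u, delta, A = tuple(u), tuple(delta), tuple(A)
    if len(u) != 3:
        raise ValueError(f"u must be (batch, dim, seqlen), got {u}")
    batch, dim, seqlen = u
    if delta != u:
        raise ValueError(f"delta must match u {u}, got {delta}")
    if len(A) != 2 or A[0] != dim:
        raise ValueError(f"A must be (dim={dim}, dstate), got {A}")
    dstate = A[1]
    # B and C are either input-dependent (per batch and time step)
    # or fixed per channel.
    allowed = {(batch, dstate, seqlen), (dim, dstate)}
    for name, shape in (("B", B), ("C", C)):
        if tuple(shape) not in allowed:
            raise ValueError(f"{name} has shape {tuple(shape)}, expected one of {sorted(allowed)}")
    return batch, dim, seqlen, dstate
```

With tensors, one would pass `tuple(u.shape)` and so on; D, z, delta_bias and deltaA are left out to keep the sketch short.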

static backward(ctx, dout)

Backward pass for the selective scan.

Parameters:
  • ctx (Any) – Autograd context.

  • dout (torch.Tensor) – Gradient of the output tensor.

  • *args – Additional gradients (e.g., for last_state), which are ignored.

Returns:

Gradients with respect to the inputs.

Return type:

tuple

selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, deltaA=None, delta_softplus=False, return_last_state=False, discretization='mamba')

Apply the CUDA selective scan function.

Parameters:
  • u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).

  • delta (torch.Tensor) – Delta tensor of shape (batch, dim, seqlen).

  • A (torch.Tensor) – State matrix A of shape (dim, dstate).

  • B (torch.Tensor) – Input matrix B of shape (batch, dstate, seqlen) or (dim, dstate).

  • C (torch.Tensor) – Output matrix C of shape (batch, dstate, seqlen) or (dim, dstate).

  • D (torch.Tensor, optional) – Skip connection vector of shape (dim,). Defaults to None.

  • z (torch.Tensor, optional) – Gating tensor of shape (batch, dim, seqlen). Defaults to None.

  • delta_bias (torch.Tensor, optional) – Bias for delta of shape (dim,). Defaults to None.

  • deltaA (torch.Tensor, optional) – Asymmetric delta for A of shape (batch, dim, seqlen). Defaults to None.

  • delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to False.

  • return_last_state (bool, optional) – If True, returns (out, last_state). The last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. Defaults to False.

  • discretization (str, optional) – Discretization method. Defaults to "mamba".

Returns:

The output tensor, or a tuple (out, last_state) when return_last_state is True.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
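The delta_bias, delta_softplus and z arguments act before and after the scan itself. The sketch below shows those two steps on plain Python floats for a single channel. It assumes the behavior of the upstream Mamba reference (bias added first, then softplus; output multiplied by SiLU of z); the helper names are hypothetical, and whether lrnnx applies the steps in exactly this way is an assumption.

```python
import math

def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def silu(v):
    # SiLU / swish: v * sigmoid(v).
    return v / (1.0 + math.exp(-v))

def prepare_delta(delta, delta_bias=None, delta_softplus=False):
    """Apply the optional per-channel bias and softplus to one channel of delta."""
    out = []
    for d in delta:
        if delta_bias is not None:
            d = d + delta_bias
        if delta_softplus:
            d = softplus(d)
        out.append(d)
    return out

def gate(y, z=None):
    """Gate the scan output y with z (assumed: y * silu(z)); no-op if z is None."""
    if z is None:
        return list(y)
    return [yi * silu(zi) for yi, zi in zip(y, z)]
```

Softplus keeps every step size positive, which is why it is commonly enabled when delta comes straight from a linear projection.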

selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, deltaA=None, delta_softplus=False, return_last_state=False, discretization='mamba')

Reference (pure PyTorch) implementation of the selective scan.

Parameters:
  • u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).

  • delta (torch.Tensor) – Delta tensor of shape (batch, dim, seqlen).

  • A (torch.Tensor) – State matrix A of shape (dim, dstate).

  • B (torch.Tensor) – Input matrix B. Can be shape (dim, dstate), (batch, dstate, seqlen), or (batch, groups, dstate, seqlen).

  • C (torch.Tensor) – Output matrix C. Can be shape (dim, dstate), (batch, dstate, seqlen), or (batch, groups, dstate, seqlen).

  • D (torch.Tensor, optional) – Skip connection vector of shape (dim,). Defaults to None.

  • z (torch.Tensor, optional) – Gating tensor of shape (batch, dim, seqlen). Defaults to None.

  • delta_bias (torch.Tensor, optional) – Bias for delta of shape (dim,). Defaults to None.

  • deltaA (torch.Tensor, optional) – Asymmetric delta for A of shape (batch, dim, seqlen). Defaults to None.

  • delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to False.

  • return_last_state (bool, optional) – Whether to return the final state. Defaults to False.

  • discretization (str, optional) – Discretization method to use. Defaults to "mamba".

Returns:

The output tensor of shape (batch, dim, seqlen), and optionally the last state of shape (batch, dim, dstate).

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
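To make the recurrence concrete, here is a minimal sketch of the scan for one (batch, channel) pair, written with plain Python floats instead of tensors. It assumes the upstream Mamba discretization (state decay exp(delta * A), input term delta * B * u) and input-dependent B and C; how lrnnx's discretization and deltaA arguments change this is not covered. The function name is hypothetical.

```python
import math

def scan_one_channel(u, delta, A_row, B, C, D=None):
    """Sequential scan for a single batch element and channel.

    u, delta : lists of length seqlen
    A_row    : list of length dstate (one row of A)
    B, C     : dstate lists of length seqlen, i.e. (dstate, seqlen)
    D        : optional scalar skip weight for this channel
    Returns (ys, h): outputs per time step and the final state.
    """
    dstate = len(A_row)
    h = [0.0] * dstate  # hidden state starts at zero
    ys = []
    for t in range(len(u)):
        for n in range(dstate):
            decay = math.exp(delta[t] * A_row[n])
            h[n] = decay * h[n] + delta[t] * B[n][t] * u[t]
        y = sum(C[n][t] * h[n] for n in range(dstate))
        if D is not None:
            y += D * u[t]  # skip connection
        ys.append(y)
    return ys, h
```

The returned `h` plays the role of one (batch, dim) slice of the last state; a full implementation repeats this loop over every batch element and channel, or vectorizes it.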

rms_norm_forward(x, weight, bias, eps=1e-06, is_rms_norm=True)

Forward pass for RMS normalization.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch * seqlen, dim).

  • weight (torch.Tensor) – Weight tensor of shape (dim,).

  • bias (torch.Tensor | None) – Bias tensor of shape (dim,) or None.

  • eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.

  • is_rms_norm (bool, optional) – Whether to use RMS norm (vs Layer norm). Defaults to True.

Returns:

Normalized output tensor of shape (batch * seqlen, dim).

Return type:

torch.Tensor
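The difference between the two modes selected by is_rms_norm can be sketched on nested lists, using the standard definitions of RMS norm and layer norm. This is an illustration only: the function name is hypothetical, and it is an assumption that rms_norm_forward follows exactly these formulas.

```python
import math

def norm_rows(x, weight, bias=None, eps=1e-6, is_rms_norm=True):
    """Normalize each row of x, a (batch * seqlen) list of length-dim rows."""
    out = []
    for row in x:
        dim = len(row)
        # RMS norm skips mean subtraction; layer norm centers first.
        mean = 0.0 if is_rms_norm else sum(row) / dim
        var = sum((v - mean) ** 2 for v in row) / dim
        rstd = 1.0 / math.sqrt(var + eps)
        normed = [(v - mean) * rstd * w for v, w in zip(row, weight)]
        if bias is not None:
            normed = [n + b for n, b in zip(normed, bias)]
        out.append(normed)
    return out
```

For example, `norm_rows([[3.0, 4.0]], [1.0, 1.0])` rescales the row so that its mean square is approximately 1.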

class MambaInnerFn

Bases: Function

Autograd function for the fused Mamba inner loop.

static forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-06)

Forward pass of the fused Mamba inner function.

Parameters:
  • ctx (Any) – Autograd context.

  • xz (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).

  • conv1d_weight (torch.Tensor) – Conv1d weights of shape (dim, 1, kernel_size).

  • conv1d_bias (torch.Tensor | None) – Conv1d biases of shape (dim,).

  • x_proj_weight (torch.Tensor) – Projection weights for B, C, delta. Shape (delta_rank + 2*dstate, dim).

  • delta_proj_weight (torch.Tensor) – Projection weights for delta. Shape (dim, delta_rank).

  • out_proj_weight (torch.Tensor) – Output projection weights. Shape (d_model, dim).

  • out_proj_bias (torch.Tensor | None) – Output projection biases. Shape (d_model,).

  • A (torch.Tensor) – State matrix A. Shape (dim, dstate).

  • B (torch.Tensor, optional) – Input matrix B. Defaults to None.

  • C (torch.Tensor, optional) – Output matrix C. Defaults to None.

  • D (torch.Tensor, optional) – Skip connection vector D. Defaults to None.

  • delta_bias (torch.Tensor, optional) – Bias for delta. Defaults to None.

  • B_proj_bias (torch.Tensor, optional) – Bias for B projection. Defaults to None.

  • C_proj_bias (torch.Tensor, optional) – Bias for C projection. Defaults to None.

  • delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to True.

  • checkpoint_lvl (int, optional) – Gradient checkpointing level (0 or 1). Defaults to 1.

  • b_rms_weight (torch.Tensor, optional) – RMS norm weights for B. Defaults to None.

  • c_rms_weight (torch.Tensor, optional) – RMS norm weights for C. Defaults to None.

  • dt_rms_weight (torch.Tensor, optional) – RMS norm weights for dt. Defaults to None.

  • b_c_dt_rms_eps (float, optional) – Epsilon for the RMS normalization of B, C, and dt. Defaults to 1e-6.

Returns:

The projected output tensor.

Return type:

torch.Tensor

static backward(ctx, dout)

Backward pass for the fused Mamba inner function.

Parameters:
  • ctx (Any) – Autograd context.

  • dout (torch.Tensor) – Gradient of the output tensor. Shape (batch, seqlen, d_model).

Returns:

Gradients with respect to inputs.

Return type:

tuple

mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-06)

Apply the fused Mamba inner function.

Parameters:
  • xz (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).

  • conv1d_weight (torch.Tensor) – Conv1d weights of shape (dim, 1, kernel_size).

  • conv1d_bias (torch.Tensor | None) – Conv1d biases of shape (dim,).

  • x_proj_weight (torch.Tensor) – Projection weights for B, C, delta. Shape (delta_rank + 2*dstate, dim).

  • delta_proj_weight (torch.Tensor) – Projection weights for delta. Shape (dim, delta_rank).

  • out_proj_weight (torch.Tensor) – Output projection weights. Shape (d_model, dim).

  • out_proj_bias (torch.Tensor | None) – Output projection biases. Shape (d_model,).

  • A (torch.Tensor) – State matrix A. Shape (dim, dstate).

  • B (torch.Tensor, optional) – Input matrix B. Defaults to None.

  • C (torch.Tensor, optional) – Output matrix C. Defaults to None.

  • D (torch.Tensor, optional) – Skip connection vector D. Defaults to None.

  • delta_bias (torch.Tensor, optional) – Bias for delta. Defaults to None.

  • B_proj_bias (torch.Tensor, optional) – Bias for B projection. Defaults to None.

  • C_proj_bias (torch.Tensor, optional) – Bias for C projection. Defaults to None.

  • delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to True.

  • checkpoint_lvl (int, optional) – Gradient checkpointing level (0 or 1). Defaults to 1.

  • b_rms_weight (torch.Tensor, optional) – RMS norm weights for B. Defaults to None.

  • c_rms_weight (torch.Tensor, optional) – RMS norm weights for C. Defaults to None.

  • dt_rms_weight (torch.Tensor, optional) – RMS norm weights for dt. Defaults to None.

  • b_c_dt_rms_eps (float, optional) – Epsilon for the RMS normalization of B, C, and dt. Defaults to 1e-6.

Returns:

The projected output tensor.

Return type:

torch.Tensor
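The shape (delta_rank + 2*dstate, dim) of x_proj_weight implies that a single projection yields three blocks of features. The sketch below derives the slice for each block from the weight shapes alone. The ordering delta, B, C follows the upstream Mamba code and is an assumption for lrnnx; the helper name is hypothetical.

```python
def split_x_proj_rows(x_proj_shape, delta_proj_shape, A_shape):
    """Return slices into the x_proj output features for delta, B and C.

    x_proj_shape     : (delta_rank + 2 * dstate, dim)
    delta_proj_shape : (dim, delta_rank)
    A_shape          : (dim, dstate)
    """
    rows, _dim = x_proj_shape
    delta_rank = delta_proj_shape[1]
    dstate = A_shape[1]
    if rows != delta_rank + 2 * dstate:
        raise ValueError(
            f"x_proj_weight has {rows} rows, expected {delta_rank + 2 * dstate}"
        )
    return {
        "delta": slice(0, delta_rank),                # low-rank delta features
        "B": slice(delta_rank, delta_rank + dstate),  # input-dependent B
        "C": slice(delta_rank + dstate, rows),        # input-dependent C
    }
```

With dim = 64, delta_rank = 16 and dstate = 16, the projection has 48 output features, 16 per block.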

mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True)

Reference (pure PyTorch) implementation of the Mamba inner function.

Parameters:
  • xz (torch.Tensor) – Input tensor of shape (batch, dim, seqlen).

  • conv1d_weight (torch.Tensor) – Conv1d weights of shape (dim, 1, kernel_size).

  • conv1d_bias (torch.Tensor | None) – Conv1d biases of shape (dim,).

  • x_proj_weight (torch.Tensor) – Projection weights for B, C, delta. Shape (delta_rank + 2*dstate, dim).

  • delta_proj_weight (torch.Tensor) – Projection weights for delta. Shape (dim, delta_rank).

  • out_proj_weight (torch.Tensor) – Output projection weights. Shape (d_model, dim).

  • out_proj_bias (torch.Tensor | None) – Output projection biases. Shape (d_model,).

  • A (torch.Tensor) – State matrix A. Shape (dim, dstate).

  • B (torch.Tensor, optional) – Input matrix B. Defaults to None.

  • C (torch.Tensor, optional) – Output matrix C. Defaults to None.

  • D (torch.Tensor, optional) – Skip connection vector D. Defaults to None.

  • delta_bias (torch.Tensor, optional) – Bias for delta. Defaults to None.

  • B_proj_bias (torch.Tensor, optional) – Bias for B projection. Defaults to None.

  • C_proj_bias (torch.Tensor, optional) – Bias for C projection. Defaults to None.

  • delta_softplus (bool, optional) – Whether to apply softplus to delta. Defaults to True.

Returns:

The projected output tensor.

Return type:

torch.Tensor
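The conv1d_weight shape (dim, 1, kernel_size) describes a depthwise convolution, that is, one filter per channel. A minimal sketch of the causal form of that convolution for a single channel follows, on plain Python lists. It assumes zero left-padding and an optional SiLU activation, as in the upstream Mamba reference; whether mamba_inner_ref uses exactly this form is an assumption, and the function names are hypothetical.

```python
import math

def silu(v):
    # SiLU / swish: v * sigmoid(v).
    return v / (1.0 + math.exp(-v))

def causal_conv1d_channel(x, weight, bias=0.0, activation=True):
    """Causal 1-D convolution of one channel.

    x      : list of length seqlen
    weight : list of length kernel_size (one row of conv1d_weight)
    Output step t depends only on x[t - kernel_size + 1 .. t].
    """
    width = len(weight)
    # Left-pad with zeros so no output step sees future inputs.
    padded = [0.0] * (width - 1) + list(x)
    out = []
    for t in range(len(x)):
        acc = bias + sum(weight[k] * padded[t + k] for k in range(width))
        out.append(silu(acc) if activation else acc)
    return out
```

With `weight=[0.0, 1.0]` and no activation the input passes through unchanged, and with `weight=[1.0, 0.0]` it is delayed by one step, which shows that the last tap multiplies the current input.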