lrnnx.ops.s7_scan module

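S7 scan operation. Similar to Mamba, this module exposes the scan at two levels: the time-varying scan itself (S7ScanFn and s7_scan_fn, with the pure-PyTorch reference s7_scan_ref) and a fused inner function that also applies the input, X, and gate projections (S7InnerFn and s7_inner_fn, with the reference s7_inner_ref).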

class S7ScanFn

Bases: Function

Autograd function for S7 scan with time-varying A, B, C.

static forward(ctx, u, A, B, C, bias=None, return_last_state=False)

Forward pass for the S7 Scan CUDA kernel.

Parameters:
  • ctx (Any) – Autograd context.

  • u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen) in float32.

  • A (torch.Tensor) – Time-varying state transition tensor of shape (batch, dstate, seqlen) in float32.

  • B (torch.Tensor) – Time-varying input projection tensor of shape (batch, dstate, dim, seqlen) in float32.

  • C (torch.Tensor) – Time-varying output projection tensor of shape (batch, dim, dstate, seqlen) in float32.

  • bias (torch.Tensor | None, optional) – Optional linear time-varying (LTV) bias tensor of shape (batch, dstate, seqlen) in float32. Defaults to None.

  • return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.

Returns:

The output tensor of shape (batch, dim, seqlen) or, if return_last_state=True, a tuple (out, last_state) where last_state has shape (batch, dstate).

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
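The documented shapes are consistent with the per-step recurrence below. This is a reading of the shapes under stated assumptions, not the kernel's definition: A is assumed to act elementwise on the state, the bias (when given) is assumed to be added to the state, the initial state is assumed to be zero, and no feedthrough (D) term is assumed. For one batch element, with time step t = 1, ..., L (L = seqlen), state index n, and channel d:

```latex
\begin{aligned}
x_t[n] &= A_t[n]\,x_{t-1}[n] + \sum_{d'=1}^{\mathrm{dim}} B_t[n, d']\,u_t[d'] + \mathrm{bias}_t[n], \qquad x_0 = 0,\\
y_t[d] &= \sum_{n'=1}^{\mathrm{dstate}} C_t[d, n']\,x_t[n'].
\end{aligned}
```

Under this reading, the t-th time step of out holds y_t, last_state holds x_L, and the bias term is simply omitted when bias=None.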

static backward(ctx, dout)

Backward pass for the S7 Scan CUDA kernel.

Parameters:
  • ctx (Any) – Autograd context.

  • dout (torch.Tensor) – Gradient of the output tensor.

Returns:

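Gradients with respect to the inputs (du, dA, dB, dC, dbias, None), one per argument of forward after ctx; the final None corresponds to the non-differentiable return_last_state flag.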

Return type:

tuple
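A custom torch.autograd.Function must return one gradient per forward input (excluding ctx), which is why a non-tensor argument such as return_last_state receives None. The sketch below illustrates that convention, together with a way to check a hand-written backward using torch.autograd.gradcheck, on a hypothetical scalar-state toy scan x_t = a_t x_{t-1} + u_t. ToyScanFn and check_toy_backward are illustrative names, not part of lrnnx, and the toy is far simpler than the S7 kernel.

```python
import torch


class ToyScanFn(torch.autograd.Function):
    """Hypothetical scalar-state scan x_t = a_t * x_{t-1} + u_t (not the S7 kernel)."""

    @staticmethod
    def forward(ctx, u, a, init):
        # u, a: (batch, seqlen) tensors; init: plain Python float initial state
        x = torch.full_like(u[:, 0], init)
        outs = []
        for t in range(u.shape[1]):
            x = a[:, t] * x + u[:, t]
            outs.append(x)
        out = torch.stack(outs, dim=1)
        ctx.save_for_backward(a, out)
        ctx.init = init
        return out

    @staticmethod
    def backward(ctx, dout):
        a, out = ctx.saved_tensors
        du = torch.zeros_like(dout)
        da = torch.zeros_like(a)
        g = torch.zeros_like(dout[:, 0])  # gradient flowing back into x_t
        for t in reversed(range(out.shape[1])):
            g = dout[:, t] + g  # direct term plus contribution from x_{t+1}
            du[:, t] = g
            prev = out[:, t - 1] if t > 0 else torch.full_like(g, ctx.init)
            da[:, t] = g * prev
            g = g * a[:, t]  # d x_t / d x_{t-1} = a_t
        # One entry per forward input: the non-tensor `init` gets None
        return du, da, None


def check_toy_backward() -> bool:
    torch.manual_seed(0)
    u = torch.randn(2, 5, dtype=torch.float64, requires_grad=True)
    a = (0.9 * torch.rand(2, 5, dtype=torch.float64)).requires_grad_()
    # gradcheck compares the analytic backward with finite differences
    return torch.autograd.gradcheck(lambda u, a: ToyScanFn.apply(u, a, 0.5), (u, a))


print(check_toy_backward())  # True
```

gradcheck needs float64 inputs to be reliable, so a check like this is typically run on a reference path with small double-precision tensors rather than on the float32 kernel directly.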

s7_scan_fn(u: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, bias: torch.Tensor | None = None, return_last_state: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor]

S7 scan using the CUDA kernel.

Parameters:
  • u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen) in float32.

  • A (torch.Tensor) – Time-varying state transition tensor of shape (batch, dstate, seqlen) in float32.

  • B (torch.Tensor) – Time-varying input projection tensor of shape (batch, dstate, dim, seqlen) in float32.

  • C (torch.Tensor) – Time-varying output projection tensor of shape (batch, dim, dstate, seqlen) in float32.

  • bias (torch.Tensor | None, optional) – Optional LTV bias tensor of shape (batch, dstate, seqlen) in float32. Defaults to None.

  • return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.

Returns:

  • out : Output tensor of shape (batch, dim, seqlen) in float32.

  • last_state : If return_last_state=True, returns the last state of shape (batch, dstate) in float32.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
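The shape and dtype contract above can be checked before calling the kernel. check_s7_scan_shapes is a hypothetical helper written for this page, not an lrnnx function; it encodes only the shapes and float32 dtype stated in the parameter list and returns the inferred sizes.

```python
import torch


def check_s7_scan_shapes(u, A, B, C, bias=None):
    """Hypothetical helper: validate documented S7 scan input shapes/dtypes.

    Returns (batch, dim, dstate, seqlen) inferred from u and A.
    """
    if u.dim() != 3 or A.dim() != 3:
        raise ValueError("u and A must both be 3-D")
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    expected = [
        ("u", u, (batch, dim, seqlen)),
        ("A", A, (batch, dstate, seqlen)),
        ("B", B, (batch, dstate, dim, seqlen)),
        ("C", C, (batch, dim, dstate, seqlen)),
    ]
    if bias is not None:
        expected.append(("bias", bias, (batch, dstate, seqlen)))
    for name, tensor, shape in expected:
        if tuple(tensor.shape) != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {tuple(tensor.shape)}")
        if tensor.dtype != torch.float32:
            raise TypeError(f"{name}: expected float32, got {tensor.dtype}")
    return batch, dim, dstate, seqlen


batch, dim, dstate, seqlen = 2, 4, 8, 16
u = torch.randn(batch, dim, seqlen)
A = torch.rand(batch, dstate, seqlen)
B = torch.randn(batch, dstate, dim, seqlen)
C = torch.randn(batch, dim, dstate, seqlen)
print(check_s7_scan_shapes(u, A, B, C))  # (2, 4, 8, 16)
```

Running a check like this on the inputs before s7_scan_fn turns a shape mismatch into an explicit Python error that names the offending argument.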

s7_scan_ref(u: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, bias: torch.Tensor | None = None, return_last_state: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor]

Reference implementation of S7 scan (pure PyTorch).

Parameters:
  • u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen) in float32.

  • A (torch.Tensor) – Time-varying state transition tensor of shape (batch, dstate, seqlen) in float32.

  • B (torch.Tensor) – Time-varying input projection tensor of shape (batch, dstate, dim, seqlen) in float32.

  • C (torch.Tensor) – Time-varying output projection tensor of shape (batch, dim, dstate, seqlen) in float32.

  • bias (torch.Tensor | None, optional) – Optional LTV bias tensor of shape (batch, dstate, seqlen) in float32. Defaults to None.

  • return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.

Returns:

  • out : Output tensor of shape (batch, dim, seqlen) in float32.

  • last_state : If return_last_state=True, returns the last state of shape (batch, dstate) in float32.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
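A minimal sketch of what a pure-PyTorch reference scan can look like, assuming the recurrence x_t = A_t ⊙ x_{t-1} + B_t u_t (+ bias_t), y_t = C_t x_t with a zero initial state. Those semantics are inferred from the documented shapes rather than taken from the source of s7_scan_ref, and s7_scan_sketch is a hypothetical name.

```python
import torch


def s7_scan_sketch(u, A, B, C, bias=None, return_last_state=False):
    """Hypothetical sequential scan (assumed semantics, see lead-in).

    u: (batch, dim, seqlen), A: (batch, dstate, seqlen),
    B: (batch, dstate, dim, seqlen), C: (batch, dim, dstate, seqlen),
    bias: optional (batch, dstate, seqlen).
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = u.new_zeros(batch, dstate)  # hidden state, assumed to start at zero
    ys = []
    for t in range(seqlen):
        # (batch, dstate, dim) @ (batch, dim, 1) -> (batch, dstate)
        Bu = torch.bmm(B[..., t], u[:, :, t].unsqueeze(-1)).squeeze(-1)
        x = A[:, :, t] * x + Bu  # A assumed to act elementwise on the state
        if bias is not None:
            x = x + bias[:, :, t]  # bias assumed to enter the state update
        # (batch, dim, dstate) @ (batch, dstate, 1) -> (batch, dim)
        ys.append(torch.bmm(C[..., t], x.unsqueeze(-1)).squeeze(-1))
    out = torch.stack(ys, dim=-1)  # (batch, dim, seqlen)
    return (out, x) if return_last_state else out


# Scalar example: A = 0.5, B = 1, C = 2, u = 1 gives states 1, 1.5, 1.75
u = torch.ones(1, 1, 3)
A = torch.full((1, 1, 3), 0.5)
B = torch.ones(1, 1, 1, 3)
C = torch.full((1, 1, 1, 3), 2.0)
out, last = s7_scan_sketch(u, A, B, C, return_last_state=True)
print(out)   # tensor([[[2.0000, 3.0000, 3.5000]]])
print(last)  # tensor([[1.7500]])
```

Because the loop runs sequentially over seqlen, a reference like this suits small correctness checks, for example comparing against s7_scan_fn on a CUDA device with torch.testing.assert_close, rather than training.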

class S7InnerFn

Bases: Function

Fused S7 inner function with custom forward and backward.

static forward(ctx, hidden_states, in_proj_weight, x_proj_weight, gate_proj_weight, d_state, base_params)

Forward pass for the fused S7 inner function.

Parameters:
  • ctx (Any) – Autograd context.

  • hidden_states (torch.Tensor) – Input hidden states of shape (batch, seqlen, d_model).

  • in_proj_weight (torch.Tensor) – Input projection weight matrix.

  • x_proj_weight (torch.Tensor) – X projection weight matrix.

  • gate_proj_weight (torch.Tensor) – Gate projection weight matrix.

  • d_state (int) – State dimension.

  • base_params (torch.Tensor) – Base HiPPO initialization parameters.

Returns:

The projected and gated output tensor of shape (batch, seqlen, d_model).

Return type:

torch.Tensor

static backward(ctx, dout)

Backward pass for the fused S7 inner function.

Parameters:
  • ctx (Any) – Autograd context.

  • dout (torch.Tensor) – Gradient of the output tensor.

Returns:

Gradients with respect to the inputs, one per argument of forward after ctx; the entry for the integer d_state is None.

Return type:

tuple

s7_inner_fn(hidden_states: torch.Tensor, in_proj_weight: torch.Tensor, x_proj_weight: torch.Tensor, gate_proj_weight: torch.Tensor, d_state: int, base_params: torch.Tensor) → torch.Tensor

Fused S7 inner function using the CUDA kernel.

Parameters:
  • hidden_states (torch.Tensor) – Input hidden states of shape (batch, seqlen, d_model).

  • in_proj_weight (torch.Tensor) – Input projection weight matrix.

  • x_proj_weight (torch.Tensor) – X projection weight matrix.

  • gate_proj_weight (torch.Tensor) – Gate projection weight matrix.

  • d_state (int) – State dimension.

  • base_params (torch.Tensor) – Base HiPPO initialization parameters.

Returns:

Output tensor of shape (batch, seqlen, d_model).

Return type:

torch.Tensor

s7_inner_ref(hidden_states: torch.Tensor, in_proj_weight: torch.Tensor, x_proj_weight: torch.Tensor, gate_proj_weight: torch.Tensor, d_state: int, base_params: torch.Tensor) → torch.Tensor

Reference S7 inner function (pure PyTorch).

Parameters:
  • hidden_states (torch.Tensor) – Input hidden states of shape (batch, seqlen, d_model).

  • in_proj_weight (torch.Tensor) – Input projection weight matrix.

  • x_proj_weight (torch.Tensor) – X projection weight matrix.

  • gate_proj_weight (torch.Tensor) – Gate projection weight matrix.

  • d_state (int) – State dimension.

  • base_params (torch.Tensor) – Base HiPPO initialization parameters.

Returns:

Output tensor of shape (batch, seqlen, d_model).

Return type:

torch.Tensor
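base_params is described above only as "Base HiPPO initialization parameters"; these docs do not say how it is derived. For orientation, the standard HiPPO-LegS (A, B) pair can be built as below. How (or whether) S7 reduces this matrix to base_params, for instance through a diagonalized variant as diagonal SSMs commonly do, is not stated here, so hippo_legs is a hypothetical helper and not part of lrnnx.

```python
import torch


def hippo_legs(n: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Standard HiPPO-LegS pair of size n.

    A[i, j] = -sqrt(2i+1) * sqrt(2j+1)  if i > j
            = -(i + 1)                  if i == j
            = 0                         if i < j
    B[i]    = sqrt(2i+1)
    """
    q = torch.arange(n, dtype=torch.float64)
    r = torch.sqrt(2 * q + 1)
    # Strictly lower-triangular part (i > j), negated
    A = -torch.tril(r[:, None] * r[None, :], diagonal=-1)
    # Diagonal entries -(i + 1)
    A = A - torch.diag(q + 1)
    return A, r.clone()


A, B = hippo_legs(4)
print(torch.diag(A))  # tensor([-1., -2., -3., -4.], dtype=torch.float64)
```

Since A is lower triangular, its eigenvalues are its diagonal entries -1, ..., -n, which is why the continuous-time system is stable.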