lrnnx.ops.s7_scan module
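S7 scan operation. Similar to Mamba, this module exposes the scan at two levels: the standalone scan (S7ScanFn, s7_scan_fn, s7_scan_ref) and a fused inner function that also applies the input, x, and gate projections (S7InnerFn, s7_inner_fn, s7_inner_ref).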
- class S7ScanFn
Bases: torch.autograd.Function
Autograd function for the S7 scan with time-varying A, B, and C.
- static forward(ctx, u, A, B, C, bias=None, return_last_state=False)
Forward pass for the S7 scan CUDA kernel.
- Parameters:
ctx (Any) – Autograd context.
u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen) in float32.
A (torch.Tensor) – Time-varying state transition tensor of shape (batch, dstate, seqlen) in float32.
B (torch.Tensor) – Time-varying input projection tensor of shape (batch, dstate, dim, seqlen) in float32.
C (torch.Tensor) – Time-varying output projection tensor of shape (batch, dim, dstate, seqlen) in float32.
bias (torch.Tensor | None, optional) – Optional LTV bias tensor of shape (batch, dstate, seqlen) in float32. Defaults to None.
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
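- Returns:
The output tensor of shape (batch, dim, seqlen) and, if return_last_state=True, the last hidden state of shape (batch, dstate).
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]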
- static backward(ctx, dout)
Backward pass for the S7 scan CUDA kernel.
- Parameters:
ctx (Any) – Autograd context.
dout (torch.Tensor) – Gradient of the loss with respect to the output tensor.
- Returns:
Gradients with respect to the forward inputs, (du, dA, dB, dC, dbias, None). The final None fills the slot of the non-tensor return_last_state argument.
- Return type:
tuple
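A torch.autograd.Function must return one gradient from backward for every input passed to forward, which is why the S7ScanFn gradient tuple ends in None for the boolean flag. The toy function below (ScaledMulFn is hypothetical and unrelated to the S7 kernel) shows that convention in isolation and checks its hand-written backward with torch.autograd.gradcheck.

```python
import torch


class ScaledMulFn(torch.autograd.Function):
    """Toy autograd function computing sign * u * a (hypothetical, not part of lrnnx)."""

    @staticmethod
    def forward(ctx, u, a, negate=False):
        # Tensors needed for the gradients are saved on the context.
        ctx.save_for_backward(u, a)
        # Plain Python values can be stored directly as ctx attributes.
        ctx.sign = -1.0 if negate else 1.0
        return ctx.sign * u * a

    @staticmethod
    def backward(ctx, dout):
        u, a = ctx.saved_tensors
        # One entry per forward input (excluding ctx): du, da, and None
        # for the non-differentiable bool flag.
        return ctx.sign * dout * a, ctx.sign * dout * u, None


u = torch.randn(5, dtype=torch.float64, requires_grad=True)
a = torch.randn(5, dtype=torch.float64, requires_grad=True)
# gradcheck compares the analytic backward against finite differences
# (it needs float64 inputs to be reliable).
print(torch.autograd.gradcheck(lambda u, a: ScaledMulFn.apply(u, a, True), (u, a)))
```

Note that apply takes the flag positionally; the same pattern lets a kernel wrapper accept configuration arguments without them taking part in autograd.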
- s7_scan_fn(u: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, bias: torch.Tensor | None = None, return_last_state: bool = False) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
S7 scan using the CUDA kernel.
- Parameters:
u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen) in float32.
A (torch.Tensor) – Time-varying state transition tensor of shape (batch, dstate, seqlen) in float32.
B (torch.Tensor) – Time-varying input projection tensor of shape (batch, dstate, dim, seqlen) in float32.
C (torch.Tensor) – Time-varying output projection tensor of shape (batch, dim, dstate, seqlen) in float32.
bias (torch.Tensor | None, optional) – Optional LTV bias tensor of shape (batch, dstate, seqlen) in float32. Defaults to None.
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
- Returns:
out: Output tensor of shape (batch, dim, seqlen) in float32.
last_state: Returned only if return_last_state=True; the last hidden state, of shape (batch, dstate) in float32.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
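The shape contract above ties all inputs together through batch, dim, dstate, and seqlen. As a minimal sketch, the hypothetical helper check_s7_scan_shapes (not part of lrnnx) reads dstate off A and validates every tensor against the documented shapes and float32 dtype before anything is handed to the kernel.

```python
import torch


def check_s7_scan_shapes(u, A, B, C, bias=None):
    """Check s7_scan_fn-style inputs against the documented shapes and dtype.

    Hypothetical helper, not part of lrnnx. Returns (batch, dim, dstate, seqlen).
    """
    if u.dim() != 3 or A.dim() != 3:
        raise ValueError("u and A must both be 3-D tensors")
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]  # dstate is read off A; every other shape must agree
    expected = {
        "u": (u, (batch, dim, seqlen)),
        "A": (A, (batch, dstate, seqlen)),
        "B": (B, (batch, dstate, dim, seqlen)),
        "C": (C, (batch, dim, dstate, seqlen)),
    }
    if bias is not None:
        expected["bias"] = (bias, (batch, dstate, seqlen))
    for name, (tensor, shape) in expected.items():
        if tuple(tensor.shape) != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {tuple(tensor.shape)}")
        if tensor.dtype != torch.float32:
            raise TypeError(f"{name}: expected float32, got {tensor.dtype}")
    return batch, dim, dstate, seqlen


batch, dim, dstate, seqlen = 2, 4, 8, 16
u = torch.randn(batch, dim, seqlen)
A = torch.rand(batch, dstate, seqlen)
B = torch.randn(batch, dstate, dim, seqlen)
C = torch.randn(batch, dim, dstate, seqlen)
print(check_s7_scan_shapes(u, A, B, C))  # (2, 4, 8, 16)
```

A common mistake the check catches is swapping the dim and dstate axes of C, since B and C have the same rank but transposed middle axes.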
- s7_scan_ref(u: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, bias: torch.Tensor | None = None, return_last_state: bool = False) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
Reference implementation of the S7 scan in pure PyTorch.
- Parameters:
u (torch.Tensor) – Input tensor of shape (batch, dim, seqlen) in float32.
A (torch.Tensor) – Time-varying state transition tensor of shape (batch, dstate, seqlen) in float32.
B (torch.Tensor) – Time-varying input projection tensor of shape (batch, dstate, dim, seqlen) in float32.
C (torch.Tensor) – Time-varying output projection tensor of shape (batch, dim, dstate, seqlen) in float32.
bias (torch.Tensor | None, optional) – Optional LTV bias tensor of shape (batch, dstate, seqlen) in float32. Defaults to None.
return_last_state (bool, optional) – Whether to return the last hidden state. Defaults to False.
- Returns:
out: Output tensor of shape (batch, dim, seqlen) in float32.
last_state: Returned only if return_last_state=True; the last hidden state, of shape (batch, dstate) in float32.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
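The recurrence itself is not spelled out here. The loop below is a hedged sketch that assumes a diagonal state update x_t = A_t * x_{t-1} + B_t u_t + bias_t with x_0 = 0 and a readout y_t = C_t x_t, which is consistent with the documented shapes. The name s7_scan_sketch is hypothetical, and lrnnx's actual s7_scan_ref may differ, for example in where the bias enters.

```python
import torch


def s7_scan_sketch(u, A, B, C, bias=None, return_last_state=False):
    """Sequential scan under an ASSUMED S7-style recurrence (not lrnnx's code):
        x_t = A_t * x_{t-1} + B_t @ u_t (+ bias_t),   y_t = C_t @ x_t,   x_0 = 0
    """
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = u.new_zeros((batch, dstate))
    ys = []
    for t in range(seqlen):
        # B[..., t]: (batch, dstate, dim) applied to u[..., t]: (batch, dim) -> (batch, dstate)
        x = A[..., t] * x + torch.einsum("bnd,bd->bn", B[..., t], u[..., t])
        if bias is not None:
            x = x + bias[..., t]  # assumed: bias is added to the state
        # C[..., t]: (batch, dim, dstate) applied to x: (batch, dstate) -> (batch, dim)
        ys.append(torch.einsum("bdn,bn->bd", C[..., t], x))
    out = torch.stack(ys, dim=-1)  # (batch, dim, seqlen)
    return (out, x) if return_last_state else out


# Scalar case: A = 0.5, B = C = u = 1 gives x = 1, 1.5, 1.75.
u = torch.ones(1, 1, 3)
A = torch.full((1, 1, 3), 0.5)
B = torch.ones(1, 1, 1, 3)
C = torch.ones(1, 1, 1, 3)
out, last_state = s7_scan_sketch(u, A, B, C, return_last_state=True)
print(out)  # tensor([[[1.0000, 1.5000, 1.7500]]])
```

A loop like this costs O(seqlen) sequential steps, which is the motivation for a dedicated CUDA kernel; a reference of this kind is mainly useful for checking correctness on small inputs.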
- class S7InnerFn
Bases: torch.autograd.Function
Fused S7 inner function with custom forward and backward passes.
- static forward(ctx, hidden_states, in_proj_weight, x_proj_weight, gate_proj_weight, d_state, base_params)
Forward pass for the fused S7 inner function.
- Parameters:
ctx (Any) – Autograd context.
hidden_states (torch.Tensor) – Input hidden states of shape (batch, seqlen, d_model).
in_proj_weight (torch.Tensor) – Input projection weight matrix.
x_proj_weight (torch.Tensor) – X projection weight matrix.
gate_proj_weight (torch.Tensor) – Gate projection weight matrix.
d_state (int) – State dimension.
base_params (torch.Tensor) – Base HiPPO initialization parameters.
- Returns:
The projected and gated output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
- static backward(ctx, dout)
Backward pass for the fused S7 inner function.
- Parameters:
ctx (Any) – Autograd context.
dout (torch.Tensor) – Gradient of the loss with respect to the output tensor.
- Returns:
Gradients with respect to the forward inputs, one entry per input.
- Return type:
tuple
- s7_inner_fn(hidden_states: torch.Tensor, in_proj_weight: torch.Tensor, x_proj_weight: torch.Tensor, gate_proj_weight: torch.Tensor, d_state: int, base_params: torch.Tensor) -> torch.Tensor
Fused S7 inner function using the CUDA kernel.
- Parameters:
hidden_states (torch.Tensor) – Input hidden states of shape (batch, seqlen, d_model).
in_proj_weight (torch.Tensor) – Input projection weight matrix.
x_proj_weight (torch.Tensor) – X projection weight matrix.
gate_proj_weight (torch.Tensor) – Gate projection weight matrix.
d_state (int) – State dimension.
base_params (torch.Tensor) – Base HiPPO initialization parameters.
- Returns:
Output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
- s7_inner_ref(hidden_states: torch.Tensor, in_proj_weight: torch.Tensor, x_proj_weight: torch.Tensor, gate_proj_weight: torch.Tensor, d_state: int, base_params: torch.Tensor) -> torch.Tensor
Reference S7 inner function in pure PyTorch.
- Parameters:
hidden_states (torch.Tensor) – Input hidden states of shape (batch, seqlen, d_model).
in_proj_weight (torch.Tensor) – Input projection weight matrix.
x_proj_weight (torch.Tensor) – X projection weight matrix.
gate_proj_weight (torch.Tensor) – Gate projection weight matrix.
d_state (int) – State dimension.
base_params (torch.Tensor) – Base HiPPO initialization parameters.
- Returns:
Output tensor of shape (batch, seqlen, d_model).
- Return type:
torch.Tensor
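Each fused or CUDA entry point in this module has a pure-PyTorch counterpart (s7_scan_fn and s7_scan_ref, s7_inner_fn and s7_inner_ref), which makes it natural to test one against the other. The hypothetical helper assert_matches_reference below (not part of lrnnx) is a minimal sketch of such a check: it runs both functions on independent copies of the same arguments, compares outputs with torch.testing.assert_close, then backpropagates one shared upstream gradient and compares the input gradients. Non-tensor arguments such as d_state are passed through unchanged.

```python
import torch


def assert_matches_reference(fast_fn, ref_fn, *args, rtol=1e-4, atol=1e-5):
    """Assert that fast_fn and ref_fn agree on outputs and input gradients.

    Hypothetical helper, not part of lrnnx. Non-tensor arguments are passed through.
    """

    def fresh_copies():
        # Independent leaf tensors so the two backward passes do not share .grad.
        return [
            a.detach().clone().requires_grad_(a.requires_grad)
            if isinstance(a, torch.Tensor) else a
            for a in args
        ]

    fast_args, ref_args = fresh_copies(), fresh_copies()
    out_fast, out_ref = fast_fn(*fast_args), ref_fn(*ref_args)
    torch.testing.assert_close(out_fast, out_ref, rtol=rtol, atol=atol)

    if out_ref.requires_grad:
        # Push the same random upstream gradient through both implementations.
        dout = torch.randn_like(out_ref)
        out_fast.backward(dout)
        out_ref.backward(dout)
        for a_fast, a_ref in zip(fast_args, ref_args):
            if isinstance(a_fast, torch.Tensor) and a_fast.requires_grad:
                torch.testing.assert_close(a_fast.grad, a_ref.grad, rtol=rtol, atol=atol)


# Toy usage: two equivalent ways of computing a matrix product.
x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(4, 2, requires_grad=True)
assert_matches_reference(lambda x, w: x @ w,
                         lambda x, w: torch.einsum("ij,jk->ik", x, w), x, w)
```

On a CUDA device, the same helper could be called as assert_matches_reference(s7_inner_fn, s7_inner_ref, hidden_states, in_proj_weight, x_proj_weight, gate_proj_weight, d_state, base_params), following the positional order of the documented signatures. The tolerances are illustrative and may need loosening for float32 kernels over long sequences.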