lrnnx.ops.triton.selective_state_update module

This module is written for triton==2.1.0, triton==2.2.0, or triton==2.3.0.
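One way to confirm that the installed Triton is one of these releases is to read the package metadata. The sketch below is a hypothetical helper and is not part of lrnnx. It uses only the standard library's importlib.metadata, treats the three listed versions as the complete supported set, and compares version strings exactly.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions named in the module note. This helper is hypothetical, not part of lrnnx.
SUPPORTED_TRITON_VERSIONS = {"2.1.0", "2.2.0", "2.3.0"}


def is_supported_triton_version(ver: str) -> bool:
    """Return True if `ver` exactly matches one of the listed Triton releases."""
    return ver in SUPPORTED_TRITON_VERSIONS


def installed_triton_version():
    """Return the installed Triton version string, or None if Triton is absent."""
    try:
        return version("triton")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    installed = installed_triton_version()
    if installed is None:
        print("triton is not installed")
    elif not is_supported_triton_version(installed):
        print(f"triton {installed} is not one of {sorted(SUPPORTED_TRITON_VERSIONS)}")
    else:
        print(f"triton {installed} is supported")
```

An exact match rejects patch releases such as 2.3.1. Relax the comparison only if those releases are known to work.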

selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, deltaA=None, state_batch_indices=None, discretization='mamba')

Triton-accelerated single-step state update for selective state space models.

Parameters:
  • state (torch.Tensor) – Hidden state of shape (batch, dim, dstate) or (batch, nheads, dim, dstate).

  • x (torch.Tensor) – Input tensor of shape (batch, dim) or (batch, nheads, dim).

  • dt (torch.Tensor) – Timestep tensor of shape (batch, dim) or (batch, nheads, dim).

  • A (torch.Tensor) – State transition matrix of shape (dim, dstate) or (nheads, dim, dstate).

  • B (torch.Tensor) – Input projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).

  • C (torch.Tensor) – Output projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).

  • D (torch.Tensor, optional) – Skip connection vector of shape (dim,) or (nheads, dim). Defaults to None.

  • z (torch.Tensor, optional) – Gating tensor of shape (batch, dim) or (batch, nheads, dim). Defaults to None.

  • dt_bias (torch.Tensor, optional) – Bias for dt of shape (dim,) or (nheads, dim). Defaults to None.

  • dt_softplus (bool, optional) – Whether to apply softplus to dt. Defaults to False.

  • deltaA (torch.Tensor, optional) – Separate timestep used to discretize A (dtA) in asymmetric mode, of shape (batch, dim) or (batch, nheads, dim). Defaults to None.

  • state_batch_indices (torch.Tensor, optional) – Indices selecting which entry of state each batch element uses, of shape (batch,). Defaults to None.

  • discretization (str, optional) – Discretization method (‘zoh’, ‘bilinear’, ‘dirac’, ‘mamba’, ‘rglru’, ‘s7’). Defaults to ‘mamba’.
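The shape rules above imply a mapping from heads to B/C groups in the 4-D layout. The sketch below assumes the mamba_ssm convention, which this documentation does not state explicitly. Under that convention, nheads must be divisible by ngroups, each group is shared by a contiguous block of nheads // ngroups heads, and the 3-D layout behaves like one head with one group. The helpers infer_dims and group_of_head are hypothetical. C is assumed to follow the same rule as B. Because the batch size is read from state, the sketch also assumes state_batch_indices is None.

```python
def infer_dims(state_shape, B_shape):
    """Return (batch, nheads, dim, dstate, ngroups) for the two documented layouts.

    Assumptions (mamba_ssm convention, hypothetical helper): the 3-D state layout
    is one head and one group, and nheads must be a multiple of ngroups.
    """
    state_shape, B_shape = tuple(state_shape), tuple(B_shape)
    if len(state_shape) == 3:
        batch, dim, dstate = state_shape
        nheads, ngroups = 1, 1
        if B_shape != (batch, dstate):
            raise ValueError(f"expected B of shape {(batch, dstate)}, got {B_shape}")
    elif len(state_shape) == 4:
        batch, nheads, dim, dstate = state_shape
        if len(B_shape) != 3 or B_shape[0] != batch or B_shape[2] != dstate:
            raise ValueError(f"expected B of shape (batch, ngroups, dstate), got {B_shape}")
        ngroups = B_shape[1]
        if nheads % ngroups != 0:
            raise ValueError(f"nheads={nheads} is not divisible by ngroups={ngroups}")
    else:
        raise ValueError(f"state must be 3-D or 4-D, got {len(state_shape)}-D")
    return batch, nheads, dim, dstate, ngroups


def group_of_head(head, nheads, ngroups):
    """Index of the B/C group read by `head`; groups cover contiguous blocks of heads."""
    return head // (nheads // ngroups)


# Example: 8 heads sharing 2 groups, so heads 0-3 use group 0 and heads 4-7 use group 1.
print(infer_dims((2, 8, 64, 16), (2, 2, 16)))      # (2, 8, 64, 16, 2)
print([group_of_head(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```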

Returns:

The output tensor of shape (batch, dim) or (batch, nheads, dim).

Return type:

torch.Tensor
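The state_batch_indices parameter is described only briefly above. One plausible reading comes from some inference engines' variants of the mamba_ssm kernel and is not confirmed for lrnnx. In that reading, state holds a pool of cached states, and batch element i reads and updates the row state[state_batch_indices[i]] in place. The hypothetical helper below shows that gather, update, and write-back pattern using a caller-supplied step function. It assumes the indices are unique, because two batch elements writing the same row would conflict in a parallel kernel.

```python
def update_selected_states(state_pool, batch_indices, step_fn, inputs):
    """Hypothetical illustration of the assumed state_batch_indices semantics.

    Batch element i reads state_pool[batch_indices[i]], computes a new state and an
    output with step_fn(old_state, input), and writes the new state back to the
    same slot. Slots not named in batch_indices are left untouched.
    """
    if len(batch_indices) != len(inputs):
        raise ValueError("need exactly one state index per batch element")
    if len(set(batch_indices)) != len(batch_indices):
        raise ValueError("state indices must be unique within a batch")
    outputs = []
    for idx, inp in zip(batch_indices, inputs):
        new_state, out = step_fn(state_pool[idx], inp)
        state_pool[idx] = new_state  # in-place write-back to the selected slot
        outputs.append(out)
    return outputs


# Toy scalar recurrence h' = 0.5 * h + x whose output is the new state.
pool = [0.0, 10.0, 20.0]
outs = update_selected_states(pool, [2, 0], lambda h, x: (0.5 * h + x,) * 2, [1.0, 2.0])
print(outs, pool)  # [11.0, 2.0] [2.0, 10.0, 11.0]
```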

selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, deltaA=None, discretization='mamba')

Reference (pure PyTorch) implementation of the single-step selective state update.
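For a single (batch, head, dim) channel, the 'mamba' update reduces to a few scalar operations over the dstate entries. The pure-Python sketch below assumes the lrnnx 'mamba' path follows the mamba_ssm reference. Under that assumption, dt_bias is added before the optional softplus, the decay is exp(dt * A), and the input weight is dt * B. The output is then C applied to the updated state, plus the D * x skip term, multiplied by silu(z). The name mamba_step is hypothetical. The sketch illustrates the math only and is not the PyTorch reference itself.

```python
import math


def softplus(v):
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def silu(v):
    """Numerically stable v * sigmoid(v)."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def mamba_step(h, x, dt, a, b, c, d=None, z=None, dt_bias=None, dt_softplus=False):
    """One 'mamba'-style step for one channel (hypothetical sketch).

    h, a, b, c are lists of length dstate; x, dt, d, z, dt_bias are scalars.
    Returns (new_h, y).
    """
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = softplus(dt)
    # h_n <- exp(dt * a_n) * h_n + dt * b_n * x   (decay, then inject input)
    new_h = [math.exp(dt * an) * hn + dt * bn * x for hn, an, bn in zip(h, a, b)]
    # y = sum_n c_n * h_n, read from the updated state
    y = sum(cn * hn for cn, hn in zip(c, new_h))
    if d is not None:
        y += d * x  # skip connection
    if z is not None:
        y *= silu(z)  # output gate
    return new_h, y


# With a = 0 there is no decay, so the state simply accumulates dt * b * x.
new_h, y = mamba_step([1.0, 2.0], x=3.0, dt=0.5, a=[0.0, 0.0], b=[1.0, 2.0], c=[1.0, 1.0], d=1.0)
print(new_h, y)  # [2.5, 5.0] 10.5
```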

Parameters:
  • state (torch.Tensor) – Hidden state of shape (batch, dim, dstate) or (batch, nheads, dim, dstate).

  • x (torch.Tensor) – Input tensor of shape (batch, dim) or (batch, nheads, dim).

  • dt (torch.Tensor) – Timestep tensor of shape (batch, dim) or (batch, nheads, dim).

  • A (torch.Tensor) – State transition matrix of shape (dim, dstate) or (nheads, dim, dstate).

  • B (torch.Tensor) – Input projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).

  • C (torch.Tensor) – Output projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).

  • D (torch.Tensor, optional) – Skip connection vector of shape (dim,) or (nheads, dim). Defaults to None.

  • z (torch.Tensor, optional) – Gating tensor of shape (batch, dim) or (batch, nheads, dim). Defaults to None.

  • dt_bias (torch.Tensor, optional) – Bias for dt of shape (dim,) or (nheads, dim). Defaults to None.

  • dt_softplus (bool, optional) – Whether to apply softplus to dt. Defaults to False.

  • deltaA (torch.Tensor, optional) – Separate timestep used to discretize A (dtA) in asymmetric mode, of shape (batch, dim) or (batch, nheads, dim). Defaults to None.

  • discretization (str, optional) – Discretization method (‘zoh’, ‘bilinear’, ‘dirac’, ‘mamba’, ‘rglru’, ‘s7’). Defaults to ‘mamba’.

Returns:

The output tensor of shape (batch, dim) or (batch, nheads, dim).

Return type:

torch.Tensor
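The discretization argument selects how the continuous parameters A and B become per-step coefficients. For a single diagonal entry a < 0, the textbook zero-order-hold and bilinear rules and the simplified 'mamba' rule can be written as below. This is a comparison sketch only. This documentation does not describe the exact formulas lrnnx uses for each option, or the 'dirac', 'rglru', and 's7' variants, so they are not reproduced here.

```python
import math


def discretize(a, b, dt, method="mamba"):
    """Return (a_bar, b_bar) for one diagonal entry a (nonzero, typically < 0).

    Textbook formulas for comparison; not necessarily lrnnx's exact definitions.
    """
    if method == "zoh":
        a_bar = math.exp(dt * a)
        b_bar = (a_bar - 1.0) / a * b  # exact for input held constant over dt
    elif method == "bilinear":
        denom = 1.0 - dt * a / 2.0
        a_bar = (1.0 + dt * a / 2.0) / denom  # Tustin / trapezoidal rule
        b_bar = dt * b / denom
    elif method == "mamba":
        a_bar = math.exp(dt * a)  # same decay as ZOH
        b_bar = dt * b  # first-order approximation of the ZOH input term
    else:
        raise ValueError(f"method {method!r} is not covered by this sketch")
    return a_bar, b_bar


for m in ("zoh", "bilinear", "mamba"):
    print(m, discretize(-1.0, 1.0, 0.1, m))
```

ZOH and 'mamba' share the decay term and differ only in the input weight, and those weights agree to first order in dt. Bilinear approximates the decay with a rational function instead of an exponential.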