lrnnx.ops.triton.selective_state_update module
This module is intended for use with triton==2.1.0, triton==2.2.0, or triton==2.3.0.
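The version pin can be checked before the Triton kernels are used. A minimal sketch, assuming exact matches against the three listed versions are what is wanted; `SUPPORTED_TRITON` and `triton_version_supported` are hypothetical names, not part of lrnnx.

```python
from importlib.metadata import PackageNotFoundError, version

# Hypothetical constant: the exact versions named in this module's note.
# Exact pins, so patch releases such as 2.3.1 are NOT accepted by this sketch.
SUPPORTED_TRITON = ("2.1.0", "2.2.0", "2.3.0")


def triton_version_supported(ver=None):
    """Return True if `ver` (or the installed triton, if `ver` is None) is pinned."""
    if ver is None:
        try:
            ver = version("triton")  # distribution version string, e.g. "2.3.0"
        except PackageNotFoundError:
            return False  # triton is not installed at all
    return ver in SUPPORTED_TRITON
```

Checking exact strings mirrors the `==` pins literally; a looser range check would need a version parser such as the one in the `packaging` library.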
- selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, deltaA=None, state_batch_indices=None, discretization='mamba')
Triton-accelerated single-step state update for selective state space models.
- Parameters:
state (torch.Tensor) – Hidden state of shape (batch, dim, dstate) or (batch, nheads, dim, dstate).
x (torch.Tensor) – Input tensor of shape (batch, dim) or (batch, nheads, dim).
dt (torch.Tensor) – Timestep tensor of shape (batch, dim) or (batch, nheads, dim).
A (torch.Tensor) – State transition matrix of shape (dim, dstate) or (nheads, dim, dstate).
B (torch.Tensor) – Input projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).
C (torch.Tensor) – Output projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).
D (torch.Tensor, optional) – Skip connection vector of shape (dim,) or (nheads, dim). Defaults to None.
z (torch.Tensor, optional) – Gating tensor of shape (batch, dim) or (batch, nheads, dim). Defaults to None.
dt_bias (torch.Tensor, optional) – Bias for dt of shape (dim,) or (nheads, dim). Defaults to None.
dt_softplus (bool, optional) – Whether to apply softplus to dt. Defaults to False.
deltaA (torch.Tensor, optional) – Timestep for A discretization (dtA) in asymmetric mode, of shape (batch, dim) or (batch, nheads, dim). Defaults to None.
state_batch_indices (torch.Tensor, optional) – Indices to select states for the batch, of shape (batch,). Defaults to None.
discretization (str, optional) – Discretization method ('zoh', 'bilinear', 'dirac', 'mamba', 'rglru', 's7'). Defaults to 'mamba'.
- Returns:
The output tensor of shape (batch, dim) or (batch, nheads, dim).
- Return type:
torch.Tensor
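For a diagonal A, the discretization modes differ in how each entry a and step dt map to the recurrence h ← ā·h + b̄·x. The sketch below uses the textbook zero-order-hold and bilinear formulas and the simplified rule used by the Mamba reference code (ā = exp(dt·a), b̄ = dt·b). It is an assumption that this module's 'zoh', 'bilinear' and 'mamba' modes use exactly these formulas; 'dirac', 'rglru' and 's7' are not sketched because their definitions are not given here. `discretize_diag` is a hypothetical helper.

```python
import math


def discretize_diag(a, dt, method="mamba"):
    """Discretize one diagonal entry `a` of A with step `dt` (illustrative sketch).

    Returns (a_bar, b_scale) so the recurrence reads
        h <- a_bar * h + b_scale * b * x
    Only 'zoh', 'bilinear' and 'mamba' are covered.
    """
    if method == "zoh":
        a_bar = math.exp(dt * a)
        # (exp(dt*a) - 1) / a, using expm1 for accuracy; the a -> 0 limit is dt
        b_scale = math.expm1(dt * a) / a if a != 0 else dt
    elif method == "bilinear":
        # Tustin transform: (1 - dt*a/2)^-1 (1 + dt*a/2) and (1 - dt*a/2)^-1 dt
        denom = 1.0 - dt * a / 2.0
        a_bar = (1.0 + dt * a / 2.0) / denom
        b_scale = dt / denom
    elif method == "mamba":
        # simplified ZOH: exact exponential for A, Euler step for B
        a_bar = math.exp(dt * a)
        b_scale = dt
    else:
        raise ValueError(f"method {method!r} is not covered by this sketch")
    return a_bar, b_scale
```

For small dt·|a| all three rules agree to first order in dt, which is why the cheaper 'mamba' rule is a reasonable approximation of zero-order hold.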
- selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, deltaA=None, discretization='mamba')
Reference (pure PyTorch) implementation of the single-step selective state update.
- Parameters:
state (torch.Tensor) – Hidden state of shape (batch, dim, dstate) or (batch, nheads, dim, dstate).
x (torch.Tensor) – Input tensor of shape (batch, dim) or (batch, nheads, dim).
dt (torch.Tensor) – Timestep tensor of shape (batch, dim) or (batch, nheads, dim).
A (torch.Tensor) – State transition matrix of shape (dim, dstate) or (nheads, dim, dstate).
B (torch.Tensor) – Input projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).
C (torch.Tensor) – Output projection matrix of shape (batch, dstate) or (batch, ngroups, dstate).
D (torch.Tensor, optional) – Skip connection vector of shape (dim,) or (nheads, dim). Defaults to None.
z (torch.Tensor, optional) – Gating tensor of shape (batch, dim) or (batch, nheads, dim). Defaults to None.
dt_bias (torch.Tensor, optional) – Bias for dt of shape (dim,) or (nheads, dim). Defaults to None.
dt_softplus (bool, optional) – Whether to apply softplus to dt. Defaults to False.
deltaA (torch.Tensor, optional) – Timestep for A discretization (dtA) in asymmetric mode, of shape (batch, dim) or (batch, nheads, dim). Defaults to None.
discretization (str, optional) – Discretization method ('zoh', 'bilinear', 'dirac', 'mamba', 'rglru', 's7'). Defaults to 'mamba'.
- Returns:
The output tensor of shape (batch, dim) or (batch, nheads, dim).
- Return type:
torch.Tensor
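The single-step update these functions perform can be illustrated with a small NumPy sketch of the multi-head layout. It assumes that the 'mamba' mode matches the Mamba reference code (A discretized as exp(dt·A), B as dt·B), that dt_bias is added before the optional softplus, that the output gate is multiplication by SiLU(z), and that heads are assigned to the ngroups groups of B and C in contiguous blocks. `step_mamba_sketch` is hypothetical; NumPy is used instead of PyTorch only so the sketch runs without a GPU, and it returns the new state rather than modifying `state`.

```python
import numpy as np


def silu(v):
    # SiLU / swish: v * sigmoid(v)
    return v / (1.0 + np.exp(-v))


def step_mamba_sketch(state, x, dt, A, B, C, D=None, z=None,
                      dt_bias=None, dt_softplus=False):
    """One recurrent step with 'mamba'-style discretization (illustrative sketch).

    Shapes (multi-head layout):
      state: (batch, nheads, dim, dstate)   x, dt, z: (batch, nheads, dim)
      A: (nheads, dim, dstate)              B, C: (batch, ngroups, dstate)
      D, dt_bias: (nheads, dim)
    Returns (new_state, out), with out of shape (batch, nheads, dim).
    """
    nheads, ngroups = state.shape[1], B.shape[1]
    assert nheads % ngroups == 0, "nheads must be a multiple of ngroups"
    # Assumption: each block of nheads // ngroups consecutive heads shares a group.
    B = np.repeat(B, nheads // ngroups, axis=1)    # (batch, nheads, dstate)
    C = np.repeat(C, nheads // ngroups, axis=1)    # (batch, nheads, dstate)
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = np.logaddexp(0.0, dt)                 # numerically stable softplus
    dA = np.exp(dt[..., None] * A)                 # A_bar = exp(dt * A)
    dBx = (dt * x)[..., None] * B[:, :, None, :]   # B_bar * x = dt * B * x
    new_state = state * dA + dBx                   # h <- A_bar * h + B_bar * x
    out = np.einsum("bhdn,bhn->bhd", new_state, C) # y = sum over dstate of C * h
    if D is not None:
        out = out + x * D                          # skip connection
    if z is not None:
        out = out * silu(z)                        # output gate
    return new_state, out
```

Returning the state explicitly keeps the sketch side-effect free; whether the documented functions update `state` in place is not stated in this reference.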