lrnnx.ops.triton.simplified_state_update module
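Triton kernel for the single-step state update of S5-style SSMs. It plays the same role for simplified_scan that selective_state_update plays for selective_scan: instead of scanning a whole sequence, it advances the recurrent state by one timestep.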
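For orientation, the recurrence that a single-step update advances can be written as below. This is a sketch under stated assumptions: A is diagonal with eigenvalues λ_p, Δ_p is the per-state timestep from dt, and (as in S5 and in selective_state_update) the output is read from the updated state. Only the bilinear and ZOH rules from the S5 paper are shown; the exact form of 'dirac' and of the deltaA split is not documented on this page.

```latex
\begin{aligned}
h_k &= \bar{A} \odot h_{k-1} + \bar{B}\, x_k, \\
y_k &= \alpha\, \operatorname{Re}\!\left(C h_k\right) + D x_k,
      \qquad \alpha = \begin{cases} 2 & \text{conj\_sym = True} \\ 1 & \text{otherwise} \end{cases} \\[6pt]
\text{bilinear:}\quad \bar{A}_p &= \frac{1 + \Delta_p \lambda_p / 2}{1 - \Delta_p \lambda_p / 2},
      \qquad \bar{B}_{p,:} = \frac{\Delta_p}{1 - \Delta_p \lambda_p / 2}\, B_{p,:}, \\[6pt]
\text{ZOH:}\quad \bar{A}_p &= e^{\Delta_p \lambda_p},
      \qquad \bar{B}_{p,:} = \frac{e^{\Delta_p \lambda_p} - 1}{\lambda_p}\, B_{p,:}.
\end{aligned}
```

The D x_k term is present only when D is given.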
- simplified_state_update(state: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor | None = None, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) → torch.Tensor
Triton-accelerated single-step state update for S5-style (simplified) SSMs.
- Parameters:
  - state (torch.Tensor) – Complex hidden state of shape (batch, P), dtype complex64. Modified in-place.
  - x (torch.Tensor) – Real input at the current timestep, of shape (batch, H), dtype float32.
  - dt (torch.Tensor) – Real timestep of shape (batch, P) or (P,), dtype float32.
  - A (torch.Tensor) – Complex eigenvalues of shape (P,), dtype complex64.
  - B (torch.Tensor) – Complex input projection matrix of shape (P, H), dtype complex64.
  - C (torch.Tensor) – Complex output projection matrix of shape (H, P), dtype complex64.
  - D (torch.Tensor, optional) – Real skip-connection matrix of shape (H, H), dtype float32. Defaults to None.
  - deltaA (torch.Tensor, optional) – Optional separate timestep for discretizing A, of shape (batch, P), dtype float32. If None, dt is used for both A and B. Defaults to None.
  - discretization (str, optional) – Discretization method ('bilinear', 'zoh', or 'dirac'). Defaults to 'bilinear'.
  - conj_sym (bool, optional) – If True, the output is 2 * Re(…); otherwise it is Re(…). Defaults to True.
- Returns:
  Real output tensor of shape (batch, H), dtype float32.
- Return type:
  torch.Tensor
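The conj_sym flag relies on a property of real-valued systems: when the complex eigenvalues come in conjugate pairs (with the rows of B and the columns of C paired the same way) and the input is real, the two members of each pair contribute complex-conjugate terms, so keeping one mode per pair and doubling the real part reproduces the full output. The toy check below illustrates this. It uses an assumed discretization (Ā = exp(Δλ), B̄ = ΔB) purely for brevity; the identity holds for any rule that commutes with complex conjugation, which includes the bilinear and ZOH formulas.

```python
import torch

torch.manual_seed(0)
batch, P, H, steps = 2, 3, 4, 5
dt = 0.1

# Half system: one eigenvalue per conjugate pair (stable, non-real eigenvalues)
A_half = torch.complex(-torch.rand(P), torch.rand(P) + 0.5)
B_half = torch.randn(P, H, dtype=torch.complex64)
C_half = torch.randn(H, P, dtype=torch.complex64)

# Full 2P-state system: append the complex-conjugate partner of every mode
A_full = torch.cat([A_half, torch.conj_physical(A_half)])
B_full = torch.cat([B_half, torch.conj_physical(B_half)], dim=0)
C_full = torch.cat([C_half, torch.conj_physical(C_half)], dim=1)

h_half = torch.zeros(batch, P, dtype=torch.complex64)
h_full = torch.zeros(batch, 2 * P, dtype=torch.complex64)
for _ in range(steps):
    # Real-valued input, cast to complex only so it can multiply complex B
    x = torch.randn(batch, H).to(torch.complex64)
    # Toy discretization (assumption): A_bar = exp(dt * A), B_bar = dt * B
    h_half = torch.exp(dt * A_half) * h_half + dt * (x @ B_half.T)
    h_full = torch.exp(dt * A_full) * h_full + dt * (x @ B_full.T)

y_half = 2 * (h_half @ C_half.T).real  # conj_sym=True style readout
y_full = h_full @ C_full.T             # imaginary parts of each pair cancel
```

The half system needs P complex states instead of 2P, which is why conj_sym defaults to True.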
- simplified_state_update_ref(state: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor | None = None, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) → torch.Tensor
Pure-PyTorch reference for a single-step S5 state update.
- Parameters:
  - state (torch.Tensor) – Complex hidden state of shape (batch, P), dtype complex64. Modified in-place.
  - x (torch.Tensor) – Real input at the current timestep, of shape (batch, H), dtype float32.
  - dt (torch.Tensor) – Real timestep of shape (batch, P) or (P,), dtype float32.
  - A (torch.Tensor) – Complex eigenvalues of shape (P,), dtype complex64.
  - B (torch.Tensor) – Complex input projection matrix of shape (P, H), dtype complex64.
  - C (torch.Tensor) – Complex output projection matrix of shape (H, P), dtype complex64.
  - D (torch.Tensor, optional) – Real skip-connection matrix of shape (H, H), dtype float32. Defaults to None.
  - deltaA (torch.Tensor, optional) – Optional separate timestep for discretizing A, of shape (batch, P), dtype float32. Defaults to None.
  - discretization (str, optional) – Discretization method ('bilinear', 'zoh', or 'dirac'). Defaults to 'bilinear'.
  - conj_sym (bool, optional) – If True, the output is 2 * Re(…); otherwise it is Re(…). Defaults to True.
- Returns:
  Real output tensor of shape (batch, H), dtype float32.
- Return type:
  torch.Tensor
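The body of simplified_state_update_ref is not shown on this page. The following is a minimal sketch of what a pure-PyTorch single-step update could compute, written against the documented shapes and dtypes; it is not the library's implementation. The name simplified_state_update_sketch is hypothetical, and several details are assumptions: the output is read from the updated state, D is applied as x @ D.T, deltaA replaces dt only when discretizing A, and 'dirac' is omitted because its exact rule is not documented here.

```python
import torch


def simplified_state_update_sketch(state, x, dt, A, B, C, D=None, deltaA=None,
                                   discretization="bilinear", conj_sym=True):
    """Hypothetical single-step S5 update: mutates `state` in place, returns y."""
    # Assumption: deltaA (if given) replaces dt when discretizing A only.
    dtA = dt if deltaA is None else deltaA
    if discretization == "bilinear":
        A_bar = (1 + dtA * A / 2) / (1 - dtA * A / 2)
        B_scale = dt / (1 - dt * A / 2)          # per-state factor multiplying B
    elif discretization == "zoh":
        A_bar = torch.exp(dtA * A)
        B_scale = (torch.exp(dt * A) - 1) / A    # assumes no eigenvalue is exactly 0
    else:
        # 'dirac' is left out: its exact rule is not documented here.
        raise ValueError(f"unsupported discretization in sketch: {discretization!r}")

    Bx = x.to(B.dtype) @ B.T                     # (batch, H) @ (H, P) -> (batch, P)
    state.mul_(A_bar).add_(B_scale * Bx)         # h_k = A_bar * h_{k-1} + B_bar x_k
    y = (state @ C.T).real                       # (batch, P) @ (P, H) -> (batch, H)
    if conj_sym:
        y = 2 * y                                # conjugate partners supply the other half
    if D is not None:
        y = y + x @ D.T                          # assumption: D acts as a dense (H, H) matrix
    return y
```

Because both documented functions modify state in place, clone the state before running the Triton kernel and the reference side by side on the same inputs.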