lrnnx.ops.triton.simplified_state_update module

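Triton kernel for the single-step state update of S5-style SSMs. This module is to simplified_scan what selective_state_update is to selective_scan: instead of processing a whole input sequence, it advances the hidden state by one timestep, as needed for step-by-step (autoregressive) inference.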
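As a hedged illustration of that split, the toy below runs a prefix of the input through a sequence-level function, keeps its final state, and then advances that state by one timestep. The result matches processing the full sequence at once. `scan` and `step` are hypothetical stand-ins, not this library's functions. They assume a single input channel, precomputed discretized coefficients `a_bar` and `b_bar`, and a conj_sym-style readout 2 * Re(c · h). `scan` is written as a plain loop for readability only.

```python
import torch


def step(h, x_t, a_bar, b_bar, c):
    """Hypothetical single-step update: advance the state in place, then read out."""
    h.mul_(a_bar).add_(b_bar * x_t)         # h_t = a_bar * h_{t-1} + b_bar * x_t
    return 2 * (h * c).sum(-1).real        # y_t = 2 Re(c . h_t), conj_sym-style readout


def scan(x, a_bar, b_bar, c):
    """Hypothetical sequence-level counterpart: consume all T inputs, return outputs and final state."""
    h = torch.zeros_like(a_bar)
    ys = [step(h, x_t, a_bar, b_bar, c) for x_t in x]
    return torch.stack(ys), h


torch.manual_seed(0)
P, T = 4, 6
a_bar = torch.exp(torch.complex(-torch.rand(P), torch.randn(P)))  # decaying modes, |a_bar| <= 1
b_bar = torch.randn(P, dtype=torch.complex64)
c = torch.randn(P, dtype=torch.complex64)
x = torch.randn(T)

y_full, h_full = scan(x, a_bar, b_bar, c)   # whole sequence at once
_, h = scan(x[:-1], a_bar, b_bar, c)        # "prefill" on the first T - 1 inputs
y_last = step(h, x[-1], a_bar, b_bar, c)    # one more timestep, state updated in place
```

The final state and the last output from the prefill-then-step path agree with the full-sequence pass, which is the property that lets a step function continue where a scan left off.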

simplified_state_update(state: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor | None = None, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) → torch.Tensor

Triton-accelerated single-step state update for S5-style (simplified) SSMs.
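In equations, one call plausibly computes the following for each batch row. This is a sketch using the standard S5 discretization formulas, assuming a diagonal A and Δ = diag(dt). The argument elided as Re(…) in the conj_sym description is assumed to be C h_t, and D is assumed to enter as D x_t. The 'dirac' rule and the separate deltaA timestep are left out because this page does not define them.

```latex
\begin{aligned}
\text{bilinear:}\quad \bar{A} &= \left(I - \tfrac{\Delta}{2}A\right)^{-1}\left(I + \tfrac{\Delta}{2}A\right),
  \qquad \bar{B} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\Delta B \\
\text{zoh:}\quad \bar{A} &= \exp(\Delta A),
  \qquad \bar{B} = A^{-1}\left(\bar{A} - I\right)B \\
h_t &= \bar{A}\,h_{t-1} + \bar{B}\,x_t \\
y_t &= \kappa\,\mathrm{Re}\!\left(C h_t\right) + D x_t,
  \qquad \kappa = \begin{cases} 2 & \texttt{conj\_sym=True} \\ 1 & \texttt{conj\_sym=False} \end{cases}
\end{aligned}
```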

Parameters:
  • state (torch.Tensor) – Complex hidden state of shape (batch, P), dtype complex64. Modified in-place.

  • x (torch.Tensor) – Real input at current timestep of shape (batch, H), dtype float32.

  • dt (torch.Tensor) – Real timestep of shape (batch, P) or (P,), dtype float32.

  • A (torch.Tensor) – Complex eigenvalues of shape (P,), dtype complex64.

  • B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype complex64.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype complex64.

  • D (torch.Tensor, optional) – Real skip connection matrix of shape (H, H), dtype float32. Defaults to None.

  • deltaA (torch.Tensor, optional) – Separate timestep used to discretize A, of shape (batch, P), dtype float32. If None, dt is used for both A and B. Defaults to None.

  • discretization (str, optional) – Discretization method: 'bilinear', 'zoh', or 'dirac'. Defaults to 'bilinear'.

  • conj_sym (bool, optional) – If True, the output is 2 * Re(…); otherwise it is Re(…). Use True when only one eigenvalue of each complex-conjugate pair is stored. Defaults to True.

Returns:

Real output tensor of shape (batch, H), dtype float32.

Return type:

torch.Tensor
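The computation these parameters describe can be sketched in pure PyTorch. `s5_step_sketch` is hypothetical and is neither the Triton kernel nor the library's reference function. It assumes a diagonal A, the standard bilinear and ZOH formulas, that the elided argument of Re(…) is C h_t, and that D is applied as D x_t. It omits deltaA and does not implement 'dirac', whose formula this page does not give.

```python
import torch


def s5_step_sketch(state, x, dt, A, B, C, D=None, discretization="bilinear", conj_sym=True):
    """Hypothetical single-step S5 update (not the library's code).

    Shapes follow the parameter list: state (batch, P) complex64, x (batch, H) float32,
    dt (batch, P) or (P,) float32, A (P,), B (P, H), C (H, P) complex64, D (H, H) float32.
    """
    dtc = dt.to(torch.complex64)                  # promote so complex arithmetic broadcasts cleanly
    if discretization == "bilinear":
        denom = 1 - 0.5 * dtc * A                 # diagonal of (I - dt/2 A)
        A_bar = (1 + 0.5 * dtc * A) / denom       # (I - dt/2 A)^{-1} (I + dt/2 A)
        B_scale = dtc / denom                     # (I - dt/2 A)^{-1} dt, applied row-wise to B
    elif discretization == "zoh":
        A_bar = torch.exp(dtc * A)                # exp(dt A)
        B_scale = (A_bar - 1) / A                 # A^{-1} (exp(dt A) - I), applied row-wise to B
    else:
        # 'dirac' is not specified on this page, so the sketch does not guess its formula.
        raise NotImplementedError(discretization)
    Bx = x.to(torch.complex64) @ B.T              # (batch, P): B x_t for every batch row
    state.mul_(A_bar).add_(B_scale * Bx)          # h_t = A_bar h_{t-1} + B_bar x_t, in place
    y = (state @ C.T).real                        # (batch, H): Re(C h_t)
    if conj_sym:
        y = 2 * y                                 # each conjugate pair is stored once
    if D is not None:
        y = y + x @ D.T                           # real skip connection D x_t
    return y
```

Because a `dt` of shape (P,) broadcasts against (batch, P), the same code covers both documented `dt` shapes. Updating `state` with in-place operations mirrors the documented "Modified in-place" behaviour.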

simplified_state_update_ref(state: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor | None = None, deltaA: torch.Tensor | None = None, discretization: str = 'bilinear', conj_sym: bool = True) → torch.Tensor

Pure-PyTorch reference for a single-step S5 state update.

Parameters:
  • state (torch.Tensor) – Complex hidden state of shape (batch, P), dtype complex64. Modified in-place.

  • x (torch.Tensor) – Real input at current timestep of shape (batch, H), dtype float32.

  • dt (torch.Tensor) – Real timestep of shape (batch, P) or (P,), dtype float32.

  • A (torch.Tensor) – Complex eigenvalues of shape (P,), dtype complex64.

  • B (torch.Tensor) – Complex projection matrix of shape (P, H), dtype complex64.

  • C (torch.Tensor) – Complex projection matrix of shape (H, P), dtype complex64.

  • D (torch.Tensor, optional) – Real skip connection matrix of shape (H, H), dtype float32. Defaults to None.

  • deltaA (torch.Tensor, optional) – Separate timestep used to discretize A, of shape (batch, P), dtype float32. If None, dt is used for both A and B. Defaults to None.

  • discretization (str, optional) – Discretization method: 'bilinear', 'zoh', or 'dirac'. Defaults to 'bilinear'.

  • conj_sym (bool, optional) – If True, the output is 2 * Re(…); otherwise it is Re(…). Use True when only one eigenvalue of each complex-conjugate pair is stored. Defaults to True.

Returns:

Real output tensor of shape (batch, H), dtype float32.

Return type:

torch.Tensor
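A common use of a pure-PyTorch reference is checking the accelerated kernel against it. Both functions modify `state` in place, so each call needs its own copy of the starting state. Otherwise the second call would begin from a state the first call already advanced. The helper below, `compare_step_fns`, is a hypothetical sketch of that pattern, not part of this module.

```python
import torch


def compare_step_fns(fn_a, fn_b, state, *args, atol=1e-4, rtol=1e-4, **kwargs):
    """Hypothetical helper: run two single-step update functions on identical inputs.

    Each function receives its own clone of `state`, because both mutate it in place.
    Raises AssertionError if either the outputs or the updated states disagree.
    """
    state_a, state_b = state.clone(), state.clone()
    y_a = fn_a(state_a, *args, **kwargs)
    y_b = fn_b(state_b, *args, **kwargs)
    torch.testing.assert_close(y_a, y_b, atol=atol, rtol=rtol)          # outputs agree
    torch.testing.assert_close(state_a, state_b, atol=atol, rtol=rtol)  # updated states agree
    return y_a, state_a
```

With the library installed and tensors on a CUDA device, a call might look like `compare_step_fns(simplified_state_update, simplified_state_update_ref, state, x, dt, A, B, C, D, discretization='zoh')`. The default tolerances here are illustrative and are not taken from the library's own tests.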