lrnnx.models.ltv.s7 module

S7: Selective and Simplified State Space Layers for Sequence Modeling https://arxiv.org/abs/2410.03464

class S7[source]

Bases: LTV_LRNN

S7: Selective and Simplified State Space Layers for Sequence Modeling.

Example

>>> import torch
>>> from lrnnx.models.ltv.s7 import S7
>>> model = S7(d_model=64, d_state=64)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
__init__(d_model: int, d_state: int, J: int = 1, use_fast_path: bool = True, layer_idx: int | None = None, device=None, dtype=None)[source]

Initialize the S7 model.

Parameters:
  • d_model (int) – Model dimension.

  • d_state (int) – State dimension. Must be divisible by J.

  • J (int, optional) – Number of blocks for initialization. Defaults to 1.

  • use_fast_path (bool, optional) – Whether to use the CUDA fast path if available. Defaults to True.

  • layer_idx (int, optional) – Layer index for multi-layer models, used for caching. Defaults to None.

  • device (torch.device, optional) – Device for the model parameters. Defaults to None.

  • dtype (torch.dtype, optional) – Data type for the model parameters. Defaults to None.
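The divisibility constraint on d_state can be checked before constructing the layer. The following is a minimal sketch, not part of lrnnx: `check_s7_dims` is a hypothetical helper, and reading `d_state // J` as the per-block state size is an assumption based on the parameter descriptions above.

```python
def check_s7_dims(d_state: int, J: int = 1) -> int:
    """Validate the documented constraint that d_state is divisible by J.

    Returns d_state // J, which this sketch ASSUMES is the state size of
    each initialization block (the source only states the divisibility rule).
    """
    if J < 1:
        raise ValueError(f"J must be a positive integer, got {J}")
    if d_state % J != 0:
        raise ValueError(f"d_state={d_state} is not divisible by J={J}")
    return d_state // J
```

For example, `check_s7_dims(64, 4)` returns 16, while `check_s7_dims(64, 3)` raises a `ValueError`.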

forward(hidden_states: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None, inference_cache: Dict[str, Any] | None = None) → torch.Tensor[source]

Forward pass through the S7 layer.

Parameters:
  • hidden_states (torch.Tensor) – Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H is the model dimension (d_model).

  • integration_timesteps (torch.Tensor, optional) – Timesteps for async/event-driven discretization. Defaults to None.

  • lengths (torch.Tensor, optional) – Length of each sequence in the batch. Required when the batch contains variable-length sequences. Defaults to None.

  • inference_cache (Dict[str, Any], optional) – Cache for autoregressive generation. Defaults to None.

Returns:

Output tensor of shape (B, L, H).

Return type:

torch.Tensor
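When sequences in a batch differ in length, the inputs need to be brought to a common length L and the true lengths passed via `lengths`. Below is a rough sketch of that preparation in plain Python. `pad_sequences` is a hypothetical helper, and right-padding with zeros is an assumption: the source does not state which padding layout S7 expects.

```python
from typing import List, Sequence, Tuple


def pad_sequences(
    seqs: Sequence[Sequence[Sequence[float]]], pad_value: float = 0.0
) -> Tuple[List[List[List[float]]], List[int]]:
    """Right-pad a list of (L_i, H) sequences to shape (B, max L_i, H).

    Returns the padded nested list and the original length of each sequence.
    ASSUMPTION: padding goes at the end of each sequence.
    """
    if not seqs or any(len(s) == 0 for s in seqs):
        raise ValueError("expected a non-empty batch of non-empty sequences")
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    feat_dim = len(seqs[0][0])
    padded = []
    for s in seqs:
        rows = [list(step) for step in s]
        # Append all-pad timesteps until the sequence reaches max_len.
        rows.extend([pad_value] * feat_dim for _ in range(max_len - len(s)))
        padded.append(rows)
    return padded, lengths
```

The results could then be converted with `torch.tensor(padded)` and `torch.tensor(lengths)` and passed as `model(x, lengths=lengths)`; the dtype and device the layer expects for `lengths` are not documented here, so check them against the implementation.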

step(hidden_states: torch.Tensor, inference_cache: Dict[str, Any]) → Tuple[torch.Tensor, Dict[str, Any]][source]

Performs a single recurrent step of S7 for autoregressive inference.

Parameters:
  • hidden_states (torch.Tensor) – Input at current timestep, shape (B, 1, H).

  • inference_cache (Dict[str, Any]) – Cache dictionary containing the model state.

Returns:

A tuple containing:
  • out : Output tensor at the current timestep, shape (B, 1, H).

  • inference_cache : Updated cache dictionary.

Return type:

Tuple[torch.Tensor, Dict[str, Any]]
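A typical way to use step is to allocate a cache once with allocate_inference_cache and then feed one timestep at a time, carrying the returned cache forward. The sketch below relies only on the two method signatures documented on this page; `run_stepwise` is a hypothetical helper, and `model` is any object exposing those methods.

```python
from typing import Any, Dict, Iterable, List, Tuple


def run_stepwise(
    model: Any, steps: Iterable[Any], batch_size: int, max_seqlen: int
) -> Tuple[List[Any], Dict[str, Any]]:
    """Drive a recurrent layer one timestep at a time.

    `steps` yields per-timestep inputs (shape (B, 1, H) for S7).
    Returns the list of per-step outputs and the final cache.
    """
    cache = model.allocate_inference_cache(
        batch_size=batch_size, max_seqlen=max_seqlen
    )
    outputs = []
    for x_t in steps:
        # step returns the output and the updated cache; keep the new cache.
        out, cache = model.step(x_t, cache)
        outputs.append(out)
    return outputs, cache
```

With a tensor `x` of shape (B, L, H), the per-step inputs can be produced as `(x[:, t:t+1, :] for t in range(x.shape[1]))`, and the outputs joined along the sequence axis with `torch.cat(outputs, dim=1)`.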

allocate_inference_cache(batch_size: int, max_seqlen: int, dtype: torch.dtype | None = None) → Dict[str, Any][source]

Allocates cache for S7 autoregressive inference.

Parameters:
  • batch_size (int) – The batch size for inference.

  • max_seqlen (int) – Maximum sequence length (unused, kept for interface consistency).

  • dtype (torch.dtype, optional) – Data type for allocated tensors. Defaults to None.

Returns:

Cache dictionary containing “lrnn_state” and “seqlen_offset”.

Return type:

Dict[str, Any]
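Code that passes caches between components may want to confirm that a dictionary has the documented entries before handing it to step. A small sketch follows; `missing_cache_keys` is a hypothetical helper, and it checks only for key presence because the types and shapes of the values are not specified here.

```python
from typing import Any, Dict, List

# Keys the documentation above says the cache contains.
EXPECTED_CACHE_KEYS = ("lrnn_state", "seqlen_offset")


def missing_cache_keys(cache: Dict[str, Any]) -> List[str]:
    """Return the documented cache keys that are absent from `cache`."""
    return [key for key in EXPECTED_CACHE_KEYS if key not in cache]
```

An empty result means both documented keys are present; for instance, `missing_cache_keys({"lrnn_state": None})` returns `["seqlen_offset"]`.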