lrnnx.utils.generation module

CUDA-graph-accelerated step-by-step inference for LRNN models.

Inspired by https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/generation.py

Usage:

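from lrnnx.utils.generation import capture_graph, generate

# model: an lrnnx model on CUDA in eval mode; x0: seed input of shape (4, 64)
cache = capture_graph(model, batch_size=4, H=64)
y = generate(model, x0, num_steps=512, graph_cache=cache)  # CUDA-graph replay
y = generate(model, x0, num_steps=512)                     # plain Python fallback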

class CUDAGraphStepCache(graph: torch.cuda.CUDAGraph, x_buf: torch.Tensor, y_buf: torch.Tensor, state_buf: torch.Tensor, mempool: int, batch_size: int, dt_buf: torch.Tensor | None = None, _state_bufs: list = <factory>, _inference_cache: dict | None = None)

Bases: object

Holds a captured CUDA graph and the fixed-address buffers it operates on.

Create instances via capture_graph rather than directly.

Variables:
  • graph (torch.cuda.CUDAGraph) – The captured CUDA graph.

  • x_buf (torch.Tensor) – Fixed-address input buffer.

  • y_buf (torch.Tensor) – Fixed-address output buffer.

  • state_buf (torch.Tensor) – Fixed-address state buffer.

  • mempool (int) – CUDA memory pool identifier.

  • batch_size (int) – Batch size the graph was captured for.

  • dt_buf (torch.Tensor or None) – Fixed-address integration-timestep buffer for event-driven models; None unless the graph was captured with event_mode=True.
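Because a CUDA graph replays against the memory addresses it recorded at capture time, each new input has to be copied into x_buf in place; binding a fresh tensor to the name would leave the graph reading the old memory. The pure-Python toy below mimics that behaviour with lists standing in for tensors. RecordedStep and toy_step are hypothetical illustrations, not part of lrnnx or PyTorch.

```python
from typing import Callable, List, Tuple

Vector = List[float]  # one value per batch element
Step = Callable[[Vector, Vector], Tuple[Vector, Vector]]


class RecordedStep:
    """Hypothetical stand-in for a captured graph: it keeps the exact buffer
    objects it was given at 'capture' time and only ever touches those."""

    def __init__(self, step_fn: Step, x_buf: Vector,
                 state_buf: Vector, y_buf: Vector) -> None:
        self._step_fn = step_fn
        self._x, self._s, self._y = x_buf, state_buf, y_buf

    def replay(self) -> None:
        new_state, y = self._step_fn(self._x, self._s)
        self._s[:] = new_state  # write results back in place (same "address")
        self._y[:] = y


def toy_step(x: Vector, s: Vector) -> Tuple[Vector, Vector]:
    # Stand-in recurrence per batch element: s' = 0.5 * s + x, y = 2 * s'.
    new_s = [0.5 * si + xi for si, xi in zip(s, x)]
    return new_s, [2.0 * si for si in new_s]


x_buf, state_buf, y_buf = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]
graph = RecordedStep(toy_step, x_buf, state_buf, y_buf)

# Correct: copy the new input INTO the captured buffer, then replay.
x_buf[:] = [1.0, 3.0]
graph.replay()
print(y_buf)  # [2.0, 6.0]

# Wrong: rebinding the name creates a new list the "graph" never sees.
x_buf = [100.0, 100.0]
graph.replay()
print(y_buf)  # [3.0, 9.0]: the replay still read the old input [1.0, 3.0]
```

In PyTorch the corresponding idiom is an in-place copy such as static_input.copy_(new_input) before calling graph.replay().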

capture_graph(model: LTI_LRNN | LTV_LRNN, batch_size: int, H: int, max_seqlen: int = 1, event_mode: bool = False, device: torch.device | str | None = None, n_warmups: int = 3) → CUDAGraphStepCache

Capture the model’s single-step recurrence as a CUDA graph.

Call this once, outside the hot loop, and pass the returned CUDAGraphStepCache to generate; each generated step then replays the captured graph instead of launching the step's kernels one by one from Python.

Parameters:
  • model (LTI_LRNN | LTV_LRNN) – An lrnnx model on CUDA in eval mode.

  • batch_size (int) – Batch size to capture for. Every generate call that uses the returned cache must use this same batch size.

  • H (int) – Model input/output dimension.

  • max_seqlen (int, optional) – Maximum sequence length (passed to allocate_inference_cache). Defaults to 1.

  • event_mode (bool, optional) – If True, capture with an integration_timesteps input buffer so that event-driven timesteps can be supplied at replay time. Defaults to False.

  • device (torch.device | str | None, optional) – CUDA device. Inferred from model parameters if None. Defaults to None.

  • n_warmups (int, optional) – Number of warm-up iterations before capture. Defaults to 3.
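The batch-size restriction follows from the fixed-address buffers: they are allocated once, at capture time, from batch_size and H, and cannot be resized at replay. The plain-Python sketch below models only that contract. toy_capture, load_input, and the buffer shapes ((batch_size, H) for input and output, (batch_size, 1) for the timestep buffer, mirroring the shapes documented for generate) are assumptions for illustration; the real capture_graph also runs n_warmups warm-up steps and records the graph, which this toy omits.

```python
from dataclasses import dataclass
from typing import List, Optional

Matrix = List[List[float]]


def zeros(rows: int, cols: int) -> Matrix:
    return [[0.0] * cols for _ in range(rows)]


@dataclass
class ToyStepCache:
    # Hypothetical mirror of a few documented fields; shapes are fixed here.
    x_buf: Matrix
    y_buf: Matrix
    dt_buf: Optional[Matrix]
    batch_size: int


def toy_capture(batch_size: int, H: int, event_mode: bool = False) -> ToyStepCache:
    return ToyStepCache(
        x_buf=zeros(batch_size, H),
        y_buf=zeros(batch_size, H),
        # Assumed: only event-mode captures get a (batch_size, 1) timestep buffer.
        dt_buf=zeros(batch_size, 1) if event_mode else None,
        batch_size=batch_size,
    )


def load_input(cache: ToyStepCache, x: Matrix) -> None:
    # The buffers cannot grow or shrink after capture, so the batch must match.
    if len(x) != cache.batch_size:
        raise ValueError(f"cache captured for batch {cache.batch_size}, got {len(x)}")
    for buf_row, row in zip(cache.x_buf, x):
        buf_row[:] = row  # copy in place, keeping the buffer's identity


cache = toy_capture(batch_size=4, H=64)
load_input(cache, [[1.0] * 64 for _ in range(4)])    # OK
# load_input(cache, [[1.0] * 64 for _ in range(2)])  # raises ValueError
```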

Returns:

Opaque handle; pass it as graph_cache to generate.

Return type:

CUDAGraphStepCache

generate(model: LTI_LRNN | LTV_LRNN, x: Tensor, num_steps: int, graph_cache: CUDAGraphStepCache | None = None, integration_timesteps: Tensor | None = None) → Tensor

Autoregressive generation: feed each output back as the next input.

When graph_cache is provided, the pre-captured CUDA graph is replayed at every timestep, so nothing is re-captured and each step avoids per-kernel launch overhead. When graph_cache is None, generation falls back to a plain Python loop.
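The feedback structure can be sketched in plain Python: each step's output becomes the next step's input, and the per-step outputs of shape (batch, H) are stacked into the documented (batch, num_steps, H) layout. toy_generate and toy_step are hypothetical, and the recurrence s' = 0.5 * s + x, y = s' is an arbitrary stand-in for an lrnnx model's step.

```python
from typing import Callable, List, Tuple

Matrix = List[List[float]]  # (batch, H)
Step = Callable[[Matrix, Matrix], Tuple[Matrix, Matrix]]


def toy_step(x: Matrix, state: Matrix) -> Tuple[Matrix, Matrix]:
    # Stand-in recurrence: s' = 0.5 * s + x, and the output y is s' itself.
    new_state = [[0.5 * s + v for s, v in zip(s_row, x_row)]
                 for s_row, x_row in zip(state, x)]
    return new_state, new_state  # (y, new state)


def toy_generate(step: Step, x: Matrix, num_steps: int) -> List[Matrix]:
    batch, H = len(x), len(x[0])
    state = [[0.0] * H for _ in range(batch)]
    per_step: List[Matrix] = []  # num_steps entries, each (batch, H)
    for _ in range(num_steps):
        y, state = step(x, state)
        per_step.append(y)
        x = y  # feed the output back as the next input
    # Reorder (num_steps, batch, H) -> (batch, num_steps, H).
    return [[per_step[t][b] for t in range(num_steps)] for b in range(batch)]


y = toy_generate(toy_step, [[1.0], [2.0]], num_steps=3)
print(y)  # [[[1.0], [1.5], [2.25]], [[2.0], [3.0], [4.5]]]
```

On the CUDA-graph path, the loop body would presumably copy x into the cache's input buffer, replay the graph, and read the result from its output buffer; that mapping is an assumption based on the buffer fields listed for CUDAGraphStepCache.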

Parameters:
  • model (LTI_LRNN | LTV_LRNN) – An lrnnx model on CUDA in eval mode.

  • x (torch.Tensor) – Seed input, shape (batch, H).

  • num_steps (int) – Number of autoregressive steps to generate.

  • graph_cache (CUDAGraphStepCache | None, optional) – Pre-captured graph from capture_graph. Defaults to None.

  • integration_timesteps (torch.Tensor | None, optional) – Integration timestep of shape (batch, 1) for event-driven models, reused at every generated step. When the CUDA-graph path is used, graph_cache must have been captured with event_mode=True. Defaults to None.
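The reuse rule for integration_timesteps can be illustrated with a toy: one (batch, 1) value per sequence, broadcast across all H channels and applied unchanged at every generated step. The exponential decay s' = exp(-dt) * s + x is a hypothetical discretization chosen for illustration, not necessarily what lrnnx models use, and the default dt of 1.0 when no timesteps are passed is likewise an assumption of this sketch.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def toy_event_generate(x: Matrix, num_steps: int,
                       integration_timesteps: Optional[Matrix] = None) -> List[Matrix]:
    """Hypothetical event-driven recurrence s' = exp(-dt) * s + x, with y = s'."""
    batch, H = len(x), len(x[0])
    if integration_timesteps is None:
        integration_timesteps = [[1.0] for _ in range(batch)]  # toy default
    state = [[0.0] * H for _ in range(batch)]
    out: List[Matrix] = [[] for _ in range(batch)]  # ends up (batch, num_steps, H)
    for _ in range(num_steps):
        for b in range(batch):
            # The same per-sequence dt is reused at every step and broadcast over H.
            decay = math.exp(-integration_timesteps[b][0])
            state[b] = [decay * s + v for s, v in zip(state[b], x[b])]
            out[b].append(list(state[b]))
        x = [list(row) for row in state]  # feed outputs back as inputs
    return out


dt = [[0.0], [math.log(2.0)]]  # decay 1.0 for sequence 0, about 0.5 for sequence 1
y = toy_event_generate([[1.0, 2.0], [1.0, 2.0]], num_steps=3, integration_timesteps=dt)
print(y[0])  # [[1.0, 2.0], [2.0, 4.0], [4.0, 8.0]]
```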

Returns:

Generated output sequence, shape (batch, num_steps, H).

Return type:

torch.Tensor