lrnnx.utils.generation module
CUDA-graph-accelerated step-by-step inference for LRNN models.
Inspired by https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/generation.py
Usage:
cache = capture_graph(model, batch_size=4, H=64)
y = generate(model, x0, num_steps=512, graph_cache=cache) # CUDA-graph replay
y = generate(model, x0, num_steps=512) # plain fallback
- class CUDAGraphStepCache(graph: torch.cuda.CUDAGraph, x_buf: torch.Tensor, y_buf: torch.Tensor, state_buf: torch.Tensor, mempool: int, batch_size: int, dt_buf: torch.Tensor | None = None, _state_bufs: list = <factory>, _inference_cache: dict | None = None)
Bases: object
Holds a captured CUDA graph and the fixed-address buffers it operates on.
Create instances via capture_graph, not directly.
- Variables:
graph (torch.cuda.CUDAGraph) – The captured CUDA graph.
x_buf (torch.Tensor) – Fixed-address input buffer.
y_buf (torch.Tensor) – Fixed-address output buffer.
state_buf (torch.Tensor) – Fixed-address state buffer.
mempool (int) – CUDA memory pool identifier.
batch_size (int) – Batch size the graph was captured for.
dt_buf (torch.Tensor or None) – Integration timestep buffer for event-driven models.
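The "fixed-address" requirement can be illustrated without a GPU. The sketch below is only a CPU analogy, not the library's implementation: `capture_sketch` and `replay` are hypothetical names, plain lists stand in for `x_buf` and `y_buf`, and the doubling loop is a placeholder for the model step. The point it shows is that new inputs are copied into the existing buffer in place, because the captured work is bound to that buffer and not to whatever object the caller passes in.

```python
from typing import Callable, List, Tuple


def capture_sketch(batch_size: int) -> Tuple[Callable[[], None], List[float], List[float]]:
    """Return a zero-argument 'graph' plus the buffers it is permanently bound to."""
    x_buf = [0.0] * batch_size  # allocated once; never replaced afterwards
    y_buf = [0.0] * batch_size

    def graph() -> None:
        # The closure is tied to these exact list objects, much as a captured
        # CUDA graph is tied to the device addresses seen during capture.
        for i in range(batch_size):
            y_buf[i] = 2.0 * x_buf[i]  # placeholder for the real model step

    return graph, x_buf, y_buf


def replay(
    graph: Callable[[], None], x_buf: List[float], y_buf: List[float], x: List[float]
) -> List[float]:
    """Copy x into the fixed input buffer, run the captured work, copy the result out."""
    if len(x) != len(x_buf):
        raise ValueError(f"captured for batch_size={len(x_buf)}, got {len(x)}")
    x_buf[:] = x  # in-place copy keeps the buffer identity intact
    graph()
    return list(y_buf)  # copy out, since the next replay overwrites y_buf


graph, x_buf, y_buf = capture_sketch(batch_size=3)
first = replay(graph, x_buf, y_buf, [1.0, 2.0, 3.0])
second = replay(graph, x_buf, y_buf, [0.5, 0.5, 0.5])
print(first, second)  # [2.0, 4.0, 6.0] [1.0, 1.0, 1.0]
```

The length check mirrors the documented constraint that a cache captured for one batch size cannot be reused with another.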
- capture_graph(model: LTI_LRNN | LTV_LRNN, batch_size: int, H: int, max_seqlen: int = 1, event_mode: bool = False, device: torch.device | str | None = None, n_warmups: int = 3) -> CUDAGraphStepCache
Capture the model’s single-step recurrence as a CUDA graph.
Call this once, outside the hot loop, and pass the returned CUDAGraphStepCache to generate so that each step is a graph replay rather than a fresh sequence of kernel launches.
- Parameters:
model (LTI_LRNN | LTV_LRNN) – An lrnnx model on CUDA in eval mode.
batch_size (int) – Batch size to capture for. Every subsequent generate call must use the same batch size.
H (int) – Model input/output dimension.
max_seqlen (int, optional) – Maximum sequence length (passed to allocate_inference_cache). Defaults to 1.
event_mode (bool, optional) – If True, capture with an integration_timesteps input buffer so that event-driven timesteps can be supplied at replay time. Defaults to False.
device (torch.device | str | None, optional) – CUDA device. Inferred from model parameters if None. Defaults to None.
n_warmups (int, optional) – Number of warm-up iterations before capture. Defaults to 3.
- Returns:
Opaque handle; pass it as graph_cache to generate.
- Return type:
CUDAGraphStepCache
- generate(model: LTI_LRNN | LTV_LRNN, x: Tensor, num_steps: int, graph_cache: CUDAGraphStepCache | None = None, integration_timesteps: Tensor | None = None) -> Tensor
Autoregressive generation: feed each output back as the next input.
When graph_cache is provided, the pre-captured CUDA graph is replayed for every timestep, with no re-capture between steps. When graph_cache is None, generation falls back to a plain Python loop.
- Parameters:
model (LTI_LRNN | LTV_LRNN) – An lrnnx model on CUDA in eval mode.
x (torch.Tensor) – Seed input, shape (batch, H).
num_steps (int) – Number of autoregressive steps to generate.
graph_cache (CUDAGraphStepCache | None, optional) – Pre-captured graph from capture_graph. Defaults to None.
integration_timesteps (torch.Tensor | None, optional) – Integration timestep, shape (batch, 1), for event-driven models; the same tensor is reused at every generated step. On the CUDA-graph path, graph_cache must have been captured with event_mode=True. Defaults to None.
- Returns:
Generated output sequence, shape (batch, num_steps, H).
- Return type:
torch.Tensor
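The feedback loop and the output layout described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: `generate_sketch` and `halve` are hypothetical names, nested lists stand in for tensors so the example runs without a GPU, and the sketch assumes the returned sequence holds the `num_steps` step outputs and excludes the seed.

```python
from typing import Callable, List

Vector = List[float]
Batch = List[Vector]  # stands in for a (batch, H) tensor


def generate_sketch(
    step: Callable[[Batch], Batch], x: Batch, num_steps: int
) -> List[List[Vector]]:
    """Feed each output back as the next input; result is laid out as (batch, num_steps, H)."""
    batch_size = len(x)
    outputs: List[List[Vector]] = [[] for _ in range(batch_size)]
    current = x
    for _ in range(num_steps):
        current = step(current)  # this step's output becomes the next input
        for b in range(batch_size):
            outputs[b].append(list(current[b]))
    return outputs


def halve(batch: Batch) -> Batch:
    """Hypothetical stand-in for one model step."""
    return [[0.5 * v for v in row] for row in batch]


y = generate_sketch(halve, [[8.0, 4.0], [2.0, 1.0]], num_steps=3)
print(len(y), len(y[0]), len(y[0][0]))  # 2 3 2
```

With a captured graph, the `step(current)` call would correspond to copying `current` into the fixed input buffer and replaying the graph; the loop structure stays the same.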