Source code for lrnnx.utils.generation

"""
CUDA-graph-accelerated step-by-step inference for LRNN models.

Inspired by https://github.com/state-spaces/mamba/blob/main/mamba_ssm/utils/generation.py

Usage::

    cache = capture_graph(model, batch_size=4, H=64)
    y = generate(model, x0, num_steps=512, graph_cache=cache)  # CUDA-graph replay
    y = generate(model, x0, num_steps=512)                     # plain fallback
"""

from __future__ import annotations

import gc
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
from torch import Tensor

if TYPE_CHECKING:
    from lrnnx.models.lti.base import LTI_LRNN
    from lrnnx.models.ltv.base import LTV_LRNN

# Inference-cache keys whose tensor values are treated as recurrent state:
# ``capture_graph`` gathers the tensors stored under these keys, and the
# CUDA-graph generation path zeroes them at the start of every run.
_STATE_KEYS = ("lrnn_state", "conv_state")


def _squeeze_out(y: Tensor) -> Tensor:
    """Squeeze axis 1 of a 3-D tensor; return any other rank untouched.

    ``Tensor.squeeze(1)`` only removes axis 1 when its size is 1, so a 3-D
    input whose middle axis is longer than 1 keeps its shape.
    """
    if y.dim() != 3:
        return y
    return y.squeeze(1)


def _find_state_tensor(cache_dict: dict) -> Tensor:
    """Pick the first value of ``cache_dict`` (in iteration order) that is a Tensor.

    Raises:
        ValueError: If none of the dict's values is a Tensor.
    """
    first_tensor = next(
        (value for value in cache_dict.values() if isinstance(value, Tensor)),
        None,
    )
    if first_tensor is None:
        raise ValueError("No Tensor found in inference_cache dict")
    return first_tensor


@dataclass
class CUDAGraphStepCache:
    """
    Holds a captured CUDA graph and the fixed-address buffers it operates on.

    Create instances via ``capture_graph`` - not directly.

    :ivar graph: The captured CUDA graph.
    :vartype graph: torch.cuda.CUDAGraph
    :ivar x_buf: Fixed-address input buffer.
    :vartype x_buf: torch.Tensor
    :ivar y_buf: Fixed-address output buffer.
    :vartype y_buf: torch.Tensor
    :ivar state_buf: Fixed-address state buffer.
    :vartype state_buf: torch.Tensor
    :ivar mempool: CUDA memory pool identifier.
    :vartype mempool: int
    :ivar batch_size: Batch size the graph was captured for.
    :vartype batch_size: int
    :ivar dt_buf: Integration timestep buffer for event-driven models.
    :vartype dt_buf: torch.Tensor or None
    """

    graph: torch.cuda.CUDAGraph
    x_buf: Tensor
    y_buf: Tensor
    state_buf: Tensor
    mempool: int
    batch_size: int
    dt_buf: Optional[Tensor] = None
    # State tensors that the graph-replay generation path zeroes before a run.
    # default_factory gives every instance its own list.
    _state_bufs: list = field(default_factory=list)
    # The inference-cache dict the graph was captured against; left out of
    # the generated repr.
    _inference_cache: Optional[dict] = field(default=None, repr=False)
@torch.no_grad()
def capture_graph(
    model: LTI_LRNN | LTV_LRNN,
    batch_size: int,
    H: int,
    max_seqlen: int = 1,
    event_mode: bool = False,
    device: torch.device | str | None = None,
    n_warmups: int = 3,
) -> CUDAGraphStepCache:
    """
    Record one ``model.step`` call as a CUDA graph over fixed buffers.

    Intended to be called **once**, outside the hot loop; hand the returned
    ``CUDAGraphStepCache`` to ``generate`` so each step is a graph replay.

    Args:
        model (LTI_LRNN | LTV_LRNN): An lrnnx model on CUDA in eval mode.
        batch_size (int): Batch size to capture for. Every later generate
            call that uses the returned cache must use the same batch size.
        H (int): Model input/output dimension; the input buffer has shape
            ``(batch_size, H)`` and dtype float32.
        max_seqlen (int, optional): Maximum sequence length, forwarded to
            ``model.allocate_inference_cache``. Defaults to 1.
        event_mode (bool, optional): If True, a ``(batch_size, 1)`` buffer of
            ones is allocated and passed to every step as
            ``integration_timesteps``, so timesteps can be written into it
            before replay. Defaults to False.
        device (torch.device | str | None, optional): CUDA device. Taken from
            the model's first parameter when None. Defaults to None.
        n_warmups (int, optional): Number of eager warm-up steps run on a
            side stream before capture. Defaults to 3.

    Returns:
        CUDAGraphStepCache: Opaque handle - pass it as graph_cache to
        ``generate``.
    """
    if device is None:
        device = next(model.parameters()).device

    # Release memory held by stale graphs before the new buffers are created.
    gc.collect()
    torch.cuda.empty_cache()

    inference_cache = model.allocate_inference_cache(batch_size, max_seqlen)
    inference_cache.update(
        {
            key: value.to(device)
            for key, value in inference_cache.items()
            if isinstance(value, Tensor)
        }
    )

    x_buf = torch.zeros(batch_size, H, device=device, dtype=torch.float32)
    dt_buf: Tensor | None = (
        torch.ones(batch_size, 1, device=device, dtype=torch.float32)
        if event_mode
        else None
    )

    from lrnnx.models.ltv.base import LTV_LRNN as _LTV

    is_ltv = isinstance(model, _LTV)

    # dt_buf never changes identity after this point, so the extra keyword
    # arguments for model.step are the same for every call.
    step_kwargs: Dict[str, Any] = {}
    if dt_buf is not None:
        step_kwargs["integration_timesteps"] = dt_buf

    def _step(x_t):
        # LTV models are fed a (B, 1, H) input; everything else gets (B, H).
        if is_ltv:
            x_in = x_t.unsqueeze(1)
        else:
            x_in = x_t
        y, new_cache = model.step(x_in, inference_cache, **step_kwargs)
        return _squeeze_out(y), new_cache

    # Warm-up on a side stream (required before capture).
    warmup_stream = torch.cuda.Stream(device=device)
    warmup_stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.stream(warmup_stream):
        for _ in range(n_warmups):
            _step(x_buf)
    torch.cuda.current_stream(device).wait_stream(warmup_stream)

    # Capture a single step; y_buf is the tensor produced inside the capture.
    pool_handle = torch.cuda.graphs.graph_pool_handle()
    captured = torch.cuda.CUDAGraph()
    with torch.cuda.graph(captured, pool=pool_handle):
        y_buf, _ = _step(x_buf)

    # Gather the state tensors that generation zeroes before each run.
    state_bufs = []
    for key in _STATE_KEYS:
        if key in inference_cache and isinstance(inference_cache[key], Tensor):
            state_bufs.append(inference_cache[key])

    # With no recognised state tensor, state_buf falls back to the input buffer.
    primary_state = state_bufs[0] if state_bufs else x_buf

    return CUDAGraphStepCache(
        graph=captured,
        x_buf=x_buf,
        y_buf=y_buf,
        state_buf=primary_state,
        mempool=pool_handle,
        batch_size=batch_size,
        dt_buf=dt_buf,
        _state_bufs=state_bufs,
        _inference_cache=inference_cache,
    )
def generate(
    model: LTI_LRNN | LTV_LRNN,
    x: Tensor,
    num_steps: int,
    graph_cache: CUDAGraphStepCache | None = None,
    integration_timesteps: Tensor | None = None,
) -> Tensor:
    """
    Autoregressive generation: feed each output back as the next input.

    When ``graph_cache`` is provided the pre-captured CUDA graph is
    **replayed** for every timestep - no re-capture, no extra overhead.
    When None, falls back to a plain Python loop.

    Args:
        model (LTI_LRNN | LTV_LRNN): An lrnnx model on CUDA in eval mode.
            Only used by the plain-loop fallback; the graph path replays
            whatever was captured into ``graph_cache``.
        x (torch.Tensor): Seed input, shape ``(batch, H)``.
        num_steps (int): Number of autoregressive steps to generate.
        graph_cache (CUDAGraphStepCache | None, optional): Pre-captured graph
            from ``capture_graph``. Defaults to None.
        integration_timesteps (torch.Tensor | None, optional): Integration
            timestep shape ``(batch, 1)`` for event-driven models, reused at
            every generated step. Requires that graph_cache was captured with
            event_mode=True when using the CUDA-graph path. Defaults to None.

    Returns:
        torch.Tensor: Generated output sequence, shape
        ``(batch, num_steps, H)``.

    Raises:
        ValueError: If ``x`` is not 2-dimensional.
    """
    if x.dim() != 2:
        raise ValueError(f"Expected x of shape (B, H), got {x.shape}")

    if graph_cache is not None:
        return _generate_with_cuda_graph(
            graph_cache, x, num_steps, integration_timesteps
        )
    return _generate_with_for_loop(model, x, num_steps, integration_timesteps)
def _generate_with_cuda_graph(
    cache: CUDAGraphStepCache,
    x: Tensor,
    num_steps: int,
    integration_timesteps: Tensor | None = None,
) -> Tensor:
    """Replay a captured CUDA graph for each autoregressive step.

    Args:
        cache: Handle returned by ``capture_graph``.
        x: Seed input, shape ``(batch, H)``; copied into ``cache.x_buf``.
        num_steps: Number of replays / generated steps.
        integration_timesteps: Optional timesteps copied into ``cache.dt_buf``
            once before the loop. When omitted, ``cache.dt_buf`` (if any) is
            left holding whatever it already contains.

    Returns:
        float32 tensor of shape ``(batch, num_steps, H)`` on ``x.device``.

    Raises:
        ValueError: If the batch size differs from the captured one, or if
            ``integration_timesteps`` is given but the graph was captured
            without a timestep buffer (``event_mode=False``).
    """
    batch, H = x.shape
    if batch != cache.batch_size:
        raise ValueError(
            f"Batch size {batch} != captured {cache.batch_size}. "
            f"Re-capture with capture_graph(model, batch_size={batch})."
        )
    # Without a dt buffer the captured graph has no timestep input, so the
    # argument could only be dropped silently; fail loudly instead, and do it
    # before any buffer has been mutated.
    if integration_timesteps is not None and cache.dt_buf is None:
        raise ValueError(
            "integration_timesteps was provided but the graph was captured "
            "without event_mode=True. Re-capture with "
            "capture_graph(..., event_mode=True)."
        )

    # Reset all recurrent state tensors
    for buf in cache._state_bufs:
        buf.zero_()
    if (
        cache._inference_cache is not None
        and "seqlen_offset" in cache._inference_cache
    ):
        cache._inference_cache["seqlen_offset"] = 0

    outputs = torch.empty(
        batch, num_steps, H, device=x.device, dtype=torch.float32
    )

    cache.x_buf.copy_(x)
    if integration_timesteps is not None:
        cache.dt_buf.copy_(integration_timesteps)

    with torch.inference_mode():
        for t in range(num_steps):
            cache.graph.replay()
            outputs[:, t, :] = cache.y_buf
            # Feed the output back as the next step's input.
            cache.x_buf.copy_(cache.y_buf)

    return outputs


def _generate_with_for_loop(
    model: LTI_LRNN | LTV_LRNN,
    x: Tensor,
    num_steps: int,
    integration_timesteps: Tensor | None = None,
) -> Tensor:
    """Plain Python loop fallback (no CUDA graph).

    Allocates a fresh inference cache via ``model.allocate_inference_cache``,
    moves its tensor values to ``x.device``, then calls ``model.step``
    ``num_steps`` times, feeding each output back as the next input.

    Args:
        model: Model providing ``allocate_inference_cache`` and ``step``.
        x: Seed input, shape ``(batch, H)``.
        num_steps: Number of steps to generate.
        integration_timesteps: Optional tensor forwarded unchanged to every
            ``model.step`` call as ``integration_timesteps``.

    Returns:
        float32 tensor of shape ``(batch, num_steps, H)`` on ``x.device``.
    """
    batch, H = x.shape
    device = x.device

    inference_cache = model.allocate_inference_cache(batch, num_steps)
    for k, v in inference_cache.items():
        if isinstance(v, Tensor):
            inference_cache[k] = v.to(device)

    from lrnnx.models.ltv.base import LTV_LRNN as _LTV

    is_ltv = isinstance(model, _LTV)

    # The extra step arguments do not depend on the step index; build once.
    kwargs: Dict[str, Any] = {}
    if integration_timesteps is not None:
        kwargs["integration_timesteps"] = integration_timesteps

    outputs = torch.empty(
        batch, num_steps, H, device=device, dtype=torch.float32
    )

    x_t = x  # (B, H)
    with torch.inference_mode():
        for t in range(num_steps):
            # LTV models are fed a (B, 1, H) input.
            x_in = x_t.unsqueeze(1) if is_ltv else x_t
            y, inference_cache = model.step(x_in, inference_cache, **kwargs)
            y_flat = _squeeze_out(y)
            outputs[:, t, :] = y_flat
            x_t = y_flat

    return outputs