lrnnx.models.ltv.mamba module
- class Mamba
Bases: LTV_LRNN
Mamba: Selective State Space Model with optional event-based processing.
When integration_timesteps is passed to forward(), the model uses asymmetric discretization (separate dtA and dtB) for event-driven processing. Otherwise, it uses the standard Mamba discretization.
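The difference between the two modes can be sketched on a single diagonal SSM channel. This is a minimal illustration, assuming the common Mamba convention of an exponential update for the state term and a step-scaled input term; how lrnnx derives dtA and dtB internally is not stated here, so the hypothetical helpers `discretize` and `scan` simply accept both step sizes.

```python
import math

def discretize(a, b, dt_a, dt_b=None):
    """Discretize one diagonal SSM channel (hypothetical helper).

    a, b : continuous-time state and input coefficients (scalars here).
    dt_a : step size applied to the state transition.
    dt_b : step size applied to the input; falls back to dt_a when omitted,
           which corresponds to the symmetric (standard) case.
    """
    if dt_b is None:
        dt_b = dt_a
    a_bar = math.exp(dt_a * a)  # exponential update of the state term
    b_bar = dt_b * b            # step-scaled input term
    return a_bar, b_bar

def scan(xs, a, b, dts_a, dts_b=None):
    """Run the recurrence h_t = a_bar_t * h_{t-1} + b_bar_t * x_t."""
    h = 0.0
    out = []
    for t, x in enumerate(xs):
        dt_b = None if dts_b is None else dts_b[t]
        a_bar, b_bar = discretize(a, b, dts_a[t], dt_b)
        h = a_bar * h + b_bar * x
        out.append(h)
    return out

# Symmetric case: one step size per position.
standard = scan([1.0, 1.0], a=-1.0, b=1.0, dts_a=[0.5, 0.5])
# Asymmetric case: the state decays over dts_a, the input is scaled by dts_b.
asymmetric = scan([1.0, 1.0], a=-1.0, b=1.0, dts_a=[0.5, 0.5], dts_b=[0.1, 0.1])
```

In the asymmetric call the state decays at the same rate as in the symmetric one, but each input contributes less, which is the kind of decoupling that separate dtA and dtB allow.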
Example
>>> import torch
>>> from lrnnx.models.ltv.mamba import Mamba
>>> model = Mamba(d_model=64, d_state=16, d_conv=4)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- __init__(d_model, d_state=16, d_conv=4, expand=2, dt_rank='auto', dt_min=0.001, dt_max=0.1, dt_init='random', dt_scale=1.0, dt_init_floor=0.0001, conv_bias=True, bias=False, use_fast_path=True, layer_idx=None, device=None, dtype=None, discretization='mamba')
Initialize Mamba model.
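Several constructor arguments resolve to derived quantities. The sketch below shows the documented rule for dt_rank="auto", plus two assumptions borrowed from the reference Mamba implementation rather than confirmed for lrnnx: that the inner dimension equals expand * d_model, and that dt_init="random" samples log-uniformly between dt_min and dt_max before applying dt_init_floor. The helper names are hypothetical.

```python
import math
import random

def resolve_dt_rank(d_model, dt_rank="auto"):
    # Documented rule: "auto" means ceil(d_model / 16).
    return math.ceil(d_model / 16) if dt_rank == "auto" else int(dt_rank)

def inner_dim(d_model, expand=2):
    # Assumption: inner dimension = expand * d_model.
    return expand * d_model

def sample_dt(dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, rng=random):
    # Assumption: log-uniform sample in [dt_min, dt_max], then a lower floor.
    log_dt = rng.uniform(math.log(dt_min), math.log(dt_max))
    return max(math.exp(log_dt), dt_init_floor)

rank = resolve_dt_rank(64)      # ceil(64 / 16)
d_inner = inner_dim(64)         # 2 * 64
dt = sample_dt(rng=random.Random(0))
```

Sampling in log space spreads the initial step sizes evenly across orders of magnitude, so small and large time scales are both represented at initialization.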
- Parameters:
d_model (int) – Model dimension.
d_state (int, optional) – SSM state dimension (N). Defaults to 16.
d_conv (int, optional) – Convolution kernel size. Defaults to 4.
expand (int, optional) – Expansion factor for inner dimension. Defaults to 2.
dt_rank (Union[int, str], optional) – Rank for delta projection; "auto" = ceil(d_model / 16). Defaults to "auto".
dt_min (float, optional) – Minimum value for delta initialization. Defaults to 0.001.
dt_max (float, optional) – Maximum value for delta initialization. Defaults to 0.1.
dt_init (str, optional) – Initialization method ("random" or "constant"). Defaults to "random".
dt_scale (float, optional) – Scale factor for dt initialization. Defaults to 1.0.
dt_init_floor (float, optional) – Floor value for dt initialization. Defaults to 1e-4.
conv_bias (bool, optional) – Whether to use bias in convolution. Defaults to True.
bias (bool, optional) – Whether to use bias in linear projections. Defaults to False.
use_fast_path (bool, optional) – Whether to use fused CUDA kernels. Defaults to True.
layer_idx (int, optional) – Layer index for multi-layer caching. Defaults to None.
device (torch.device, optional) – Device for parameters. Defaults to None.
dtype (torch.dtype, optional) – Data type for parameters. Defaults to None.
discretization (str, optional) – Discretization type. Defaults to "mamba".
- forward(hidden_states, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None, inference_cache: Dict[str, Any] | None = None)
Forward pass through Mamba.
- Parameters:
hidden_states (torch.Tensor) – Input tensor, shape (B, L, D).
integration_timesteps (torch.Tensor, optional) – Time intervals between events, shape (B, L). When provided, uses asymmetric discretization with separate dtA and dtB for event-driven processing. Defaults to None.
lengths (torch.Tensor, optional) – Not used by Mamba currently. Defaults to None.
inference_cache (dict, optional) – Cache for autoregressive generation. If provided, contains "conv_state" and "lrnn_state" tensors. Defaults to None.
- Returns:
Output tensor, shape (B, L, D).
- Return type:
torch.Tensor
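Since integration_timesteps holds intervals between events rather than absolute times, raw event timestamps need converting first. The following is a sketch of one way to do that; the helper `intervals_from_timestamps` is hypothetical, and the choice of 0.0 as the interval for the first event is an assumption that lrnnx may not share.

```python
def intervals_from_timestamps(timestamps, first=0.0):
    """Convert per-sequence event times (B rows of length L) to intervals.

    Assumption: the first event in each row gets the interval `first`.
    """
    out = []
    for seq in timestamps:
        if not seq:
            out.append([])
            continue
        row = [first]
        for prev, cur in zip(seq, seq[1:]):
            if cur < prev:
                raise ValueError("timestamps must be non-decreasing")
            row.append(cur - prev)
        out.append(row)
    return out

timestamps = [[0.0, 0.5, 2.0, 2.25], [1.0, 1.0, 4.0, 7.5]]
dt = intervals_from_timestamps(timestamps)
# dt has shape (B, L) = (2, 4). After conversion with torch.tensor(dt), it can
# be passed as model(x, integration_timesteps=...) with x of shape (B, L, D).
```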
- step(x: torch.Tensor, inference_cache: Dict[str, Any], integration_timesteps: torch.Tensor | None = None) → Tuple[torch.Tensor, Dict[str, Any]]
Performs a single recurrent step of Mamba.
- Parameters:
x (torch.Tensor) – Input at current timestep, shape (B, 1, D).
inference_cache (Dict[str, Any]) – Cache dictionary containing:
- "conv_state": Convolution state, shape (B, D_inner, d_conv)
- "lrnn_state": SSM state, shape (B, D_inner, N)
- "seqlen_offset": Current position in sequence
integration_timesteps (torch.Tensor, optional) – Integration timestep, shape (B, 1) or (B,). When provided, uses event-based asymmetric discretization. Defaults to None.
- Returns:
- A tuple containing:
out : Output at current timestep, shape (B, 1, D).
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
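The role of "conv_state" in single-step decoding can be illustrated with a toy, single-channel causal convolution. This is not the lrnnx implementation: the real cache holds tensors of shape (B, D_inner, d_conv), and the rolling-window update shown here is an assumption about how such a buffer is typically maintained. The functions `causal_conv` and `conv_step` are hypothetical.

```python
def causal_conv(xs, kernel):
    """Full-sequence causal convolution with left zero padding."""
    k = len(kernel)
    padded = [0.0] * (k - 1) + list(xs)
    return [
        sum(w * v for w, v in zip(kernel, padded[t:t + k]))
        for t in range(len(xs))
    ]

def conv_step(x, cache, kernel):
    """One decoding step: drop the oldest input, append x, apply the kernel."""
    window = cache["conv_state"][1:] + [x]
    cache["conv_state"] = window
    cache["seqlen_offset"] += 1
    return sum(w * v for w, v in zip(kernel, window)), cache

kernel = [0.1, 0.2, 0.3, 0.4]  # d_conv = 4
xs = [1.0, 2.0, 3.0, 4.0, 5.0]

# Start from an all-zero window, matching the zero padding above.
cache = {"conv_state": [0.0] * len(kernel), "seqlen_offset": 0}
stepwise = []
for x in xs:
    y, cache = conv_step(x, cache, kernel)
    stepwise.append(y)
```

Feeding the sequence one element at a time reproduces the full-sequence result, which is the property that lets a step-wise method stand in for a full forward pass during autoregressive generation.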
- allocate_inference_cache(batch_size: int, max_seqlen: int, dtype: torch.dtype | None = None) → Dict[str, Any]
Allocates cache for Mamba autoregressive inference.
- Parameters:
batch_size (int) – The batch size for inference.
max_seqlen (int) – Maximum sequence length (not used by Mamba, but kept for interface consistency).
dtype (torch.dtype, optional) – Data type for allocated tensors. Defaults to None.
- Returns:
- Cache dictionary containing:
"conv_state": Convolution state, shape (B, D_inner, d_conv).
"lrnn_state": SSM state, shape (B, D_inner, N).
"seqlen_offset": Current position in the sequence (starts at 0).
- Return type:
Dict[str, Any]
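To make the cache layout concrete, the sketch below computes the shapes listed above for a given configuration. The helper `cache_shapes` is hypothetical, and it assumes D_inner = expand * d_model, which follows the reference Mamba design but is not confirmed here for lrnnx. Note that max_seqlen plays no part, consistent with the parameter description.

```python
def cache_shapes(batch_size, d_model, d_state=16, d_conv=4, expand=2):
    """Return the expected cache layout as plain shapes (no tensors)."""
    d_inner = expand * d_model  # assumption, see lead-in
    return {
        "conv_state": (batch_size, d_inner, d_conv),
        "lrnn_state": (batch_size, d_inner, d_state),
        "seqlen_offset": 0,
    }

shapes = cache_shapes(batch_size=2, d_model=64)
```

Both state tensors have a size that is independent of sequence length, so memory use during generation stays constant per batch element.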