lrnnx.models.lti.centaurus module
Implementation of Centaurus, from "Let SSMs be Conv Nets" (https://openreview.net/forum?id=PkpNRmBZ32).
- class CentaurusBase
Common base for Centaurus mode variants (neck, pointwise, dws, full).
Example
>>> import torch
>>> from lrnnx.models.lti.centaurus import CentaurusNeck
>>> # Use via the subclasses (CentaurusNeck, CentaurusDWS, CentaurusFull,
>>> # CentaurusPWNeck) or through the Centaurus wrapper.
>>> model = CentaurusNeck(d_model=64, d_state=64, sub_state_dim=8, discretization="zoh")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
- __init__(d_model: int, d_state: int, sub_state_dim: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'async'] = 'zoh') -> None
Initialize CentaurusBase.
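The discretization argument selects how continuous-time state parameters are converted to discrete time. As a hedged NumPy sketch (not the library's code), the two textbook rules suggested by the names 'zoh' and 'bilinear' are shown below for a diagonal state matrix. The helper names are hypothetical, and the 'dirac' and 'async' options are not covered because their definitions are not documented here.

```python
import numpy as np


def discretize_zoh(A, B, delta):
    """Zero-order hold for a diagonal state matrix (hypothetical helper).

    A: complex continuous-time eigenvalues, B: input weights (same shape),
    delta: positive step size, broadcastable to A.
    """
    A_bar = np.exp(delta * A)
    B_bar = (A_bar - 1.0) / A * B  # integral of exp(s * A) over [0, delta], times B
    return A_bar, B_bar


def discretize_bilinear(A, B, delta):
    """Bilinear (Tustin) transform for a diagonal state matrix (hypothetical helper)."""
    denom = 1.0 - 0.5 * delta * A
    A_bar = (1.0 + 0.5 * delta * A) / denom
    B_bar = delta * B / denom
    return A_bar, B_bar


# Four stable complex modes (negative real parts) with unit input weights.
A = np.array([-0.5 + 1.0j, -1.0 + 2.0j, -0.1 + 0.5j, -2.0 + 0.0j])
B = np.ones_like(A)
A_zoh, B_zoh = discretize_zoh(A, B, delta=1e-3)
A_bil, B_bil = discretize_bilinear(A, B, delta=1e-3)
```

For small step sizes the two rules nearly coincide, and both map stable modes (negative real part) inside the unit circle.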
- compute_kernel() -> tuple[torch.Tensor, torch.Tensor]
Computes the discrete-time latent convolution kernel with intra-state mode mixing using the shared Centaurus formulation.
- Returns:
- A tuple containing:
k : Latent kernel of shape (N, L), where N is the number of state channels.
empty : Placeholder for compatibility with standard LTI interface expectations.
- Return type:
tuple[torch.Tensor, torch.Tensor]
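To make the (N, L) kernel shape concrete, here is a NumPy sketch of one common way to build such a kernel for a diagonal SSM. It assumes each of the N states owns M = sub_state_dim discrete-time modes that are collapsed into a single kernel row with per-mode weights; the function name and the weight array C are hypothetical, and the mode-mixing rule actually used by compute_kernel() may differ.

```python
import numpy as np


def latent_kernel(A_bar, B_bar, C, L):
    """Sketch of an (N, L) latent kernel from N states with M diagonal modes each.

    A_bar, B_bar, C: complex arrays of shape (N, M) holding the discrete-time
    modes, their input weights, and hypothetical per-mode output weights.
    Returns a real array of shape (N, L) with
        k[n, l] = Re(sum_m C[n, m] * A_bar[n, m] ** l * B_bar[n, m]).
    """
    powers = A_bar[..., None] ** np.arange(L)         # (N, M, L) Vandermonde terms
    k = np.einsum("nm,nml,nm->nl", C, powers, B_bar)  # collapse the M modes of each state
    return k.real


rng = np.random.default_rng(0)
N, M, L = 3, 8, 16
A_bar = 0.9 * np.exp(1j * rng.uniform(0.0, np.pi, size=(N, M)))  # |A_bar| = 0.9 < 1
B_bar = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))
C = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))
k = latent_kernel(A_bar, B_bar, C, L)
print(k.shape)  # (3, 16)
```

Each row k[n] is the impulse response of state n, so the kernel can be checked against an explicit recurrence driven by a unit impulse.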
- discretize() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
This method is intentionally not implemented for Centaurus variants.
- Raises:
NotImplementedError – Always raised, since Centaurus does not support explicit discretization via this method.
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) -> torch.Tensor
Forward pass through a Centaurus LTI mode variant.
- Parameters:
x (torch.Tensor) – Input sequence of shape (B, L, H_in).
integration_timesteps (torch.Tensor, optional) – Placeholder for async models. Not used in the current implementation. Defaults to None.
lengths (torch.Tensor, optional) – Placeholder for future bidirectional models. Not used in the current implementation. Defaults to None.
- Returns:
Output sequence of shape (B, L, H_out), where H_out is the output channel dimension.
- Return type:
torch.Tensor
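Because the layer is LTI, its forward pass can be evaluated as a causal convolution of the (latent) input with a kernel such as the one returned by compute_kernel(). A minimal NumPy sketch, assuming the FFT-based evaluation common to convolutional SSM layers; the helper name is hypothetical, and the projections applied before and after the convolution depend on the variant.

```python
import numpy as np


def causal_fft_conv(u, k):
    """Causal per-channel convolution (hypothetical helper).

    u: input of shape (B, L, C); k: kernel of shape (C, L).
    Returns y of shape (B, L, C) with y[b, t, c] = sum_{s <= t} k[c, s] * u[b, t - s, c].
    """
    L = u.shape[1]
    n = 2 * L  # zero-pad so the circular FFT convolution equals a linear one
    U = np.fft.rfft(u, n=n, axis=1)    # (B, n // 2 + 1, C)
    K = np.fft.rfft(k.T, n=n, axis=0)  # (n // 2 + 1, C)
    return np.fft.irfft(U * K, n=n, axis=1)[:, :L, :]


rng = np.random.default_rng(0)
u = rng.standard_normal((2, 128, 4))
k = rng.standard_normal((4, 128))
y = causal_fft_conv(u, k)
print(y.shape)  # (2, 128, 4)
```

The FFT route costs O(L log L) per channel instead of the O(L^2) of a direct sum over past timesteps.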
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1) -> Dict[str, Any]
Allocate initial streaming state and cache matrices.
- step(x: torch.Tensor, inference_cache: Dict[str, Any]) -> tuple[torch.Tensor, Dict[str, Any]]
Single-timestep streaming update for a Centaurus variant.
This method performs one recurrent update of the Centaurus block using the cached discrete-time parameters in the (B, N, M) layout.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H_in) for the current timestep.
inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache().
- Returns:
- A tuple containing:
y : Real-valued output tensor of shape (B, H_out).
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
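The streaming path and the convolutional path compute the same outputs for an LTI system. The sketch below assumes a diagonal recurrence h_t = A_bar * h_{t-1} + B_bar * u_t kept in a (B, N, M) state, with the M modes of each state mixed by hypothetical output weights C; the cache layout and helper names are illustrative, not the library's. Unrolled over a sequence, it reproduces a causal convolution with kernel k[n, l] = Re(sum_m C[n, m] * A_bar[n, m]**l * B_bar[n, m]).

```python
import numpy as np


def allocate_cache(batch_size, A_bar):
    """Hypothetical cache: one complex state per (batch, state, mode)."""
    N, M = A_bar.shape
    return {"h": np.zeros((batch_size, N, M), dtype=complex)}


def step(u_t, cache, A_bar, B_bar, C):
    """One recurrent update; u_t is the real latent input of shape (B, N)."""
    h = A_bar * cache["h"] + B_bar * u_t[..., None]  # (B, N, M)
    y_t = np.einsum("nm,bnm->bn", C, h).real         # mix the M modes of each state
    return y_t, {"h": h}


rng = np.random.default_rng(0)
batch, N, M, L = 2, 3, 4, 10
A_bar = 0.9 * np.exp(1j * rng.uniform(0.0, np.pi, size=(N, M)))
B_bar = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))
C = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))
u = rng.standard_normal((batch, L, N))

cache = allocate_cache(batch, A_bar)
outputs = []
for t in range(L):
    y_t, cache = step(u[:, t], cache, A_bar, B_bar, C)
    outputs.append(y_t)
y_rec = np.stack(outputs, axis=1)  # (B, L, N)
```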
- class CentaurusNeck
Bases: CentaurusBase
Bottleneck block with dense input and output projections.
Example
>>> model = CentaurusNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
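Shape-wise, a bottleneck block can be pictured as a dense projection from H channels to N latent states, one causal kernel per latent state, and a dense projection back to H channels. The NumPy sketch below assumes exactly that structure; the weights W_in and W_out and the helper names are hypothetical, and CentaurusNeck's internals may differ.

```python
import numpy as np


def causal_conv(u, k):
    """y[b, t, c] = sum_{s <= t} k[c, s] * u[b, t - s, c]; u: (B, L, C), k: (C, L)."""
    B, L, C = u.shape
    y = np.empty_like(u)
    for b in range(B):
        for c in range(C):
            y[b, :, c] = np.convolve(u[b, :, c], k[c])[:L]  # keep the causal part
    return y


def neck_forward(x, W_in, k, W_out):
    """x: (B, L, H) -> (B, L, H) via N latent states (hypothetical weights)."""
    u = x @ W_in.T         # dense input projection: (B, L, H) -> (B, L, N)
    z = causal_conv(u, k)  # one length-L kernel per latent state
    return z @ W_out.T     # dense output projection: (B, L, N) -> (B, L, H)


rng = np.random.default_rng(0)
B, L, H, N = 2, 6, 3, 4
x = rng.standard_normal((B, L, H))
W_in = rng.standard_normal((N, H))
k = rng.standard_normal((N, L))
W_out = rng.standard_normal((H, N))
y = neck_forward(x, W_in, k, W_out)
print(y.shape)  # (2, 6, 3)
```

Composing the three steps is equivalent to a full convolution with kernel K[i, o, s] = sum_n W_in[n, i] * k[n, s] * W_out[o, n], that is, a rank-N factorization of a full (H_in, H_out, L) kernel.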
- class CentaurusDWS
Bases: CentaurusBase
Depthwise-separable block with one state per channel.
Example
>>> model = CentaurusDWS(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
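By analogy with depthwise-separable convolutions, "one state per channel" can be sketched as an independent causal kernel per channel (no channel mixing inside the SSM) followed by a pointwise channel mix. The pointwise matrix W_pw and the helper names are assumptions for illustration; CentaurusDWS may place or parameterize its projections differently.

```python
import numpy as np


def causal_conv(u, k):
    """y[b, t, c] = sum_{s <= t} k[c, s] * u[b, t - s, c]; u: (B, L, C), k: (C, L)."""
    B, L, C = u.shape
    y = np.empty_like(u)
    for b in range(B):
        for c in range(C):
            y[b, :, c] = np.convolve(u[b, :, c], k[c])[:L]  # keep the causal part
    return y


def dws_forward(x, k, W_pw):
    """x: (B, L, H); k: (H, L), one kernel per channel; W_pw: (H_out, H) pointwise mix."""
    z = causal_conv(x, k)  # depthwise step: channel c is filtered by k[c] only
    return z @ W_pw.T      # pointwise step: 1x1 mixing across channels


rng = np.random.default_rng(0)
B, L, H = 2, 6, 3
x = rng.standard_normal((B, L, H))
k = rng.standard_normal((H, L))
W_pw = rng.standard_normal((H, H))
y = dws_forward(x, k, W_pw)
print(y.shape)  # (2, 6, 3)
```

The equivalent full kernel is K[i, o, s] = k[i, s] * W_pw[o, i]: H length-L kernels plus an H-by-H matrix, instead of H * H length-L kernels.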
- class CentaurusFull
Bases: CentaurusBase
Fully connected block with a state per (in, out) pair.
Example
>>> model = CentaurusFull(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
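With a state per (input, output) channel pair, the layer behaves like a full (dense) convolution whose kernel has shape (H_in, H_out, L). A NumPy sketch of that contraction, assuming FFT evaluation; the helper name is hypothetical.

```python
import numpy as np


def full_causal_conv(x, k):
    """x: (B, L, H_in), k: (H_in, H_out, L) -> y: (B, L, H_out).

    y[b, t, o] = sum_i sum_{s <= t} k[i, o, s] * x[b, t - s, i]
    """
    L = x.shape[1]
    n = 2 * L  # zero-pad so the FFT product gives a linear (not circular) convolution
    X = np.fft.rfft(x, n=n, axis=1)      # (B, F, H_in)
    K = np.fft.rfft(k, n=n, axis=-1)     # (H_in, H_out, F)
    Y = np.einsum("bfi,iof->bfo", X, K)  # contract over input channels per frequency
    return np.fft.irfft(Y, n=n, axis=1)[:, :L, :]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 3))
k = rng.standard_normal((3, 5, 8))
y = full_causal_conv(x, k)
print(y.shape)  # (2, 8, 5)
```

This variant stores H_in * H_out kernels of length L, compared with N kernels for a bottleneck with N latent states.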
- class CentaurusPWNeck
Bases: CentaurusBase
Pointwise bottleneck (called s5 in public implementations) that flattens the state layout (N, M) -> (N*M).
This variant removes E-mixing and repeats delta over M sub-states per state, yielding independent SISO lanes aggregated in a single flattened axis.
Example
>>> model = CentaurusPWNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
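The flattening can be pictured as follows: each state's step size delta is repeated across its M sub-states, and every (state, sub-state) pair becomes its own single-input single-output lane. The NumPy sketch assumes zero-order-hold discretization of diagonal modes; the helper name and the row-major lane order are illustrative assumptions.

```python
import numpy as np


def flatten_lanes(A, delta):
    """A: (N, M) continuous-time diagonal modes; delta: (N,) per-state step sizes.

    Returns discrete-time modes and step sizes, both of shape (N * M,).
    """
    N, M = A.shape
    A_flat = A.reshape(N * M)            # row-major: lane n * M + m holds A[n, m]
    delta_flat = np.repeat(delta, M)     # state n's delta shared by its M sub-states
    A_bar = np.exp(delta_flat * A_flat)  # zero-order hold, lane by lane
    return A_bar, delta_flat


rng = np.random.default_rng(0)
N, M = 4, 8
A = -rng.uniform(0.1, 1.0, size=(N, M)) + 1j * rng.uniform(0.0, np.pi, size=(N, M))
delta = rng.uniform(1e-3, 1e-1, size=N)
A_bar, delta_flat = flatten_lanes(A, delta)
print(A_bar.shape)  # (32,)
```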
- compute_kernel() -> tuple[torch.Tensor, torch.Tensor]
Computes the discrete-time latent convolution kernel with intra-state mode mixing using the shared Centaurus formulation.
- Returns:
- A tuple containing:
k : Latent kernel of shape (N, L), where N is the number of state channels.
empty : Placeholder for compatibility with standard LTI interface expectations.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) -> torch.Tensor
Forward pass through a Centaurus LTI mode variant.
- Parameters:
x (torch.Tensor) – Input sequence of shape (B, L, H_in).
integration_timesteps (torch.Tensor, optional) – Placeholder for async models. Not used in the current implementation. Defaults to None.
lengths (torch.Tensor, optional) – Placeholder for future bidirectional models. Not used in the current implementation. Defaults to None.
- Returns:
Output sequence of shape (B, L, H_out), where H_out is the output channel dimension.
- Return type:
torch.Tensor
- allocate_inference_cache(batch_size: int, max_seqlen: int = 1) -> Dict[str, Any]
Allocate initial streaming state and cache matrices.
- step(x: torch.Tensor, inference_cache: Dict[str, Any]) -> tuple[torch.Tensor, Dict[str, Any]]
Single-timestep streaming update for a Centaurus variant.
This method performs one recurrent update of the Centaurus block using the cached discrete-time parameters in the (B, N, M) layout.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, H_in) for the current timestep.
inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache().
- Returns:
- A tuple containing:
y : Real-valued output tensor of shape (B, H_out).
inference_cache : Updated cache dictionary.
- Return type:
tuple[torch.Tensor, Dict[str, Any]]
- class Centaurus(d_model: int, d_state: int, sub_state_dim: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'async'] = 'zoh', mode: Literal['neck', 'pointwise', 'pw', 's5', 'dws', 'full'] = 'neck')
Bases: object
Backwards-compatible wrapper that returns an instance of the mode-specific class selected by mode.
Example
>>> model = Centaurus(d_model=64, d_state=64, sub_state_dim=8, mode="neck")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
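A wrapper that returns an instance of a different class is commonly written by overriding __new__. The sketch below shows that pattern with stand-in stub classes; the stub names are hypothetical, and the alias table (mapping 'pointwise', 'pw' and 's5' to the pointwise bottleneck) is inferred from the mode names and the CentaurusPWNeck description rather than taken from the source.

```python
class NeckStub:
    """Hypothetical stand-in for a mode-specific implementation class."""

    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs


class PWNeckStub(NeckStub):
    pass


class DWSStub(NeckStub):
    pass


class FullStub(NeckStub):
    pass


_MODES = {
    "neck": NeckStub,
    "pointwise": PWNeckStub,
    "pw": PWNeckStub,
    "s5": PWNeckStub,
    "dws": DWSStub,
    "full": FullStub,
}


class Wrapper:
    """Factory-style wrapper: __new__ returns an instance of another class."""

    def __new__(cls, *args, mode="neck", **kwargs):
        try:
            impl = _MODES[mode]
        except KeyError:
            raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(_MODES)}") from None
        # The returned object is not a Wrapper, so Python skips Wrapper.__init__.
        return impl(*args, **kwargs)


model = Wrapper(d_model=64, d_state=64, sub_state_dim=8, mode="s5")
print(type(model).__name__)  # PWNeckStub
```

With this pattern, isinstance checks against the wrapper class itself return False, because the caller receives the concrete variant.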