lrnnx.models.lti.centaurus module

Implementation of Centaurus, from "Let SSMs be Conv Nets": https://openreview.net/forum?id=PkpNRmBZ32

class CentaurusBase

Bases: LTI_LRNN, ABC

Abstract base class shared by the Centaurus mode variants (neck, pointwise, dws, full).

Example

>>> import torch
>>> from lrnnx.models.lti.centaurus import CentaurusNeck
>>> # Use through the subclasses (CentaurusNeck, CentaurusDWS, CentaurusFull, CentaurusPWNeck)
>>> # or through the Centaurus wrapper
>>> model = CentaurusNeck(d_model=64, d_state=64, sub_state_dim=8, discretization="zoh")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])

__init__(d_model: int, d_state: int, sub_state_dim: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'async'] = 'zoh') → None

Initialize CentaurusBase.

Parameters:
  • d_model (int) – The model dimension.

  • d_state (int) – The state dimension.

  • sub_state_dim (int) – The sub-state dimension.

  • discretization (Literal["zoh", "bilinear", "dirac", "async"], optional) – Discretization method. Defaults to "zoh".
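The discretization argument selects how continuous-time parameters are turned into discrete-time ones. As a reference point only, here is a minimal NumPy sketch of the textbook zero-order-hold and bilinear (Tustin) rules for a diagonal state matrix. The names A, B and delta are illustrative, the "dirac" and "async" options are not covered, and this is not taken from the lrnnx source.

```python
import numpy as np

def discretize_zoh(A, B, delta):
    """Zero-order hold for a diagonal system x' = A x + B u (A, B: per-mode arrays)."""
    A_bar = np.exp(delta * A)
    # (exp(delta * A) - 1) / A is the exact integral of a held input over one step
    B_bar = (A_bar - 1.0) / A * B
    return A_bar, B_bar

def discretize_bilinear(A, B, delta):
    """Bilinear (Tustin) transform for a diagonal system."""
    half = 0.5 * delta * A
    A_bar = (1.0 + half) / (1.0 - half)
    B_bar = delta / (1.0 - half) * B
    return A_bar, B_bar

# Illustrative stable modes: negative real parts, arbitrary imaginary parts
A = np.array([-0.5 + 1.0j, -1.0 + 0.0j, -0.1 + 3.0j])
B = np.ones(3, dtype=complex)
A_zoh, B_zoh = discretize_zoh(A, B, delta=0.01)
A_bil, B_bil = discretize_bilinear(A, B, delta=0.01)
print(np.abs(A_zoh).max() < 1, np.abs(A_bil).max() < 1)  # both rules keep the modes stable
```

For a small step size the two rules agree closely; they differ in how they map the continuous-time spectrum, which matters more as delta grows.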

compute_kernel() → tuple[torch.Tensor, torch.Tensor]

Computes the discrete-time latent convolution kernel with intra-state mode mixing using the shared Centaurus formulation.

Returns:

A tuple containing:
  • k : Latent kernel of shape (N, L), where N is the number of state channels.

  • empty : Placeholder tensor, returned only for compatibility with the standard LTI interface.

Return type:

tuple[torch.Tensor, torch.Tensor]
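A latent kernel of this kind can be illustrated with a Vandermonde-style computation. This sketch assumes each of the N state channels holds M diagonal modes with discretized eigenvalues A_bar and input weights B_bar (both (N, M)), and that hypothetical weights E of shape (N, M) mix the modes of each state into one real channel; the actual Centaurus parameterization may differ. It checks the closed form against the impulse response of the recurrence.

```python
import numpy as np

def latent_kernel(A_bar, B_bar, E, L):
    """k[n, l] = Re(sum_m E[n, m] * A_bar[n, m]**l * B_bar[n, m]), shape (N, L)."""
    powers = A_bar[..., None] ** np.arange(L)  # (N, M, L) powers of each mode
    return np.real(np.einsum("nm,nml,nm->nl", E, powers, B_bar))

def impulse_response(A_bar, B_bar, E, L):
    """Same kernel from running h_l = A_bar * h_{l-1} + B_bar * u_l on a unit impulse."""
    N, M = A_bar.shape
    h = np.zeros((N, M), dtype=complex)
    out = np.zeros((N, L))
    for l in range(L):
        u = 1.0 if l == 0 else 0.0
        h = A_bar * h + B_bar * u
        out[:, l] = np.real((E * h).sum(axis=-1))  # mix the M modes of each state
    return out

rng = np.random.default_rng(0)
N, M, L = 4, 8, 16
# Stable modes: magnitude exp(-decay) < 1, random phase
A_bar = np.exp(-rng.uniform(0.01, 0.5, (N, M)) + 1j * rng.uniform(0, np.pi, (N, M)))
B_bar = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))
E = rng.standard_normal((N, M))
k = latent_kernel(A_bar, B_bar, E, L)
print(k.shape)  # (4, 16)
```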

discretize() → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

This method is intentionally not implemented for Centaurus variants.

Raises:

NotImplementedError – Always raised, since Centaurus does not support explicit discretization via this method.

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor

Forward pass through a Centaurus LTI mode variant.

Parameters:
  • x (torch.Tensor) – Input sequence of shape (B, L, H_in).

  • integration_timesteps (torch.Tensor, optional) – Placeholder for async models. Not used in the current implementation. Defaults to None.

  • lengths (torch.Tensor, optional) – Placeholder for future bidirectional models. Not used in the current implementation. Defaults to None.

Returns:

Output sequence of shape (B, L, H_out), where H_out is the output channel dimension.

Return type:

torch.Tensor
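The module frames Centaurus SSMs as convolutions, and forward() processes a whole sequence at once. As a hedged sketch of that general technique (not the lrnnx code path), a real latent kernel of shape (N, L) can be applied causally to a (B, L, N) signal with an FFT; zero-padding to length 2L turns the FFT's circular convolution into a linear one. The function names are illustrative, and the projections between H_in, N and H_out are left out.

```python
import numpy as np

def causal_conv_fft(u, k):
    """Causal per-channel convolution via FFT.

    u: (B, L, N) real signal, k: (N, L) real kernel.
    Returns y with y[b, t, n] = sum_{j<=t} k[n, j] * u[b, t - j, n].
    """
    L = u.shape[1]
    U = np.fft.rfft(u, n=2 * L, axis=1)    # (B, L + 1, N)
    K = np.fft.rfft(k.T, n=2 * L, axis=0)  # (L + 1, N), broadcasts over the batch
    return np.fft.irfft(U * K, n=2 * L, axis=1)[:, :L]  # drop the padded tail

def causal_conv_direct(u, k):
    """O(L^2) reference implementation of the same sum."""
    B, L, N = u.shape
    y = np.zeros_like(u)
    for t in range(L):
        for j in range(t + 1):
            y[:, t] += k[:, j] * u[:, t - j]
    return y

rng = np.random.default_rng(0)
u = rng.standard_normal((2, 128, 8))
k = rng.standard_normal((8, 128))
y = causal_conv_fft(u, k)
print(y.shape)  # (2, 128, 8)
```

The FFT route costs O(L log L) per channel instead of O(L^2) for the direct sum, which is what makes the convolutional view attractive for long training sequences.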

allocate_inference_cache(batch_size: int, max_seqlen: int = 1) → Dict[str, Any]

Allocate initial streaming state and cache matrices.

Parameters:
  • batch_size (int) – The batch size.

  • max_seqlen (int, optional) – Maximum sequence length. Defaults to 1.

Returns:

Cache dict with initial state and precomputed discrete parameters.

Return type:

Dict[str, Any]

step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]]

Single-timestep streaming update for a Centaurus variant.

This method performs one recurrent update of the Centaurus block using the cached discrete-time parameters in the (B, N, M) layout.

Parameters:
  • x (torch.Tensor) – Input for the current timestep, of shape (B, H_in).

  • inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache().

Returns:

A tuple containing:
  • y : Real-valued output tensor of shape (B, H_out).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]
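A single streaming update of a diagonal SSM in the (B, N, M) layout can be sketched as follows. This is an illustration under assumptions, not the lrnnx implementation: the in/out projections are omitted, so the input is already a latent (B, N) signal, and the cache keys ("state", "A_bar", "B_bar", "E") and mode-mixing weights E are hypothetical. Stepping through a sequence reproduces the causal convolution with the kernel k[n, l] = Re(sum_m E[n, m] A_bar[n, m]**l B_bar[n, m]).

```python
import numpy as np

def allocate_cache(batch_size, A_bar, B_bar, E):
    """Hypothetical cache: zero complex state in (B, N, M) layout plus fixed parameters."""
    N, M = A_bar.shape
    return {"state": np.zeros((batch_size, N, M), dtype=complex),
            "A_bar": A_bar, "B_bar": B_bar, "E": E}

def step(u, cache):
    """One recurrent update; u is the (B, N) latent input at the current timestep."""
    h = cache["A_bar"] * cache["state"] + cache["B_bar"] * u[..., None]  # (B, N, M)
    cache["state"] = h
    y = np.real((cache["E"] * h).sum(axis=-1))  # mix the M modes, keep the real part
    return y, cache

rng = np.random.default_rng(0)
batch, N, M, L = 2, 3, 4, 20
A_bar = 0.9 * np.exp(1j * rng.uniform(0, np.pi, (N, M)))  # stable: |A_bar| = 0.9
B_bar = rng.standard_normal((N, M)) + 1j * rng.standard_normal((N, M))
E = rng.standard_normal((N, M))
u = rng.standard_normal((batch, L, N))

cache = allocate_cache(batch, A_bar, B_bar, E)
outputs = []
for t in range(L):
    y_t, cache = step(u[:, t], cache)
    outputs.append(y_t)
y_stream = np.stack(outputs, axis=1)  # (B, L, N)

# Reference: causal convolution with the equivalent latent kernel
k = np.real(np.einsum("nm,nml,nm->nl", E, A_bar[..., None] ** np.arange(L), B_bar))
y_conv = np.zeros_like(u)
for t in range(L):
    for j in range(t + 1):
        y_conv[:, t] += k[:, j] * u[:, t - j]
print(np.allclose(y_stream, y_conv))  # True
```

The recurrent form needs only O(N*M) state per sequence and constant work per token, which is why it suits streaming inference, while the convolutional form suits parallel training.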

class CentaurusNeck

Bases: CentaurusBase

Bottleneck block with dense in/out projections.
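A bottleneck lets channels interact only through the N latent states. One way to see this, assuming the block is a dense H_in -> N projection, a per-state convolution with the (N, L) latent kernel, and a dense N -> H_out projection (the weight names W_in and W_out are hypothetical), is to contract the three factors into the equivalent (H_out, H_in, L) kernel: every lag slice then has rank at most N.

```python
import numpy as np

def neck_effective_kernel(W_in, k, W_out):
    """Equivalent (H_out, H_in, L) kernel of a bottleneck block.

    W_in: (H_in, N) dense input projection, k: (N, L) latent kernel,
    W_out: (N, H_out) dense output projection.
    K[o, i, l] = sum_n W_out[n, o] * k[n, l] * W_in[i, n]
    """
    return np.einsum("in,nl,no->oil", W_in, k, W_out)

rng = np.random.default_rng(0)
H, N, L = 64, 8, 32
W_in = rng.standard_normal((H, N))
k = rng.standard_normal((N, L))
W_out = rng.standard_normal((N, H))
K = neck_effective_kernel(W_in, k, W_out)
print(K.shape)                            # (64, 64, 32)
print(np.linalg.matrix_rank(K[:, :, 0]))  # at most N = 8
```

The factored form stores H*N + N*L + N*H numbers instead of the H*H*L of an unconstrained kernel.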

Example

>>> import torch
>>> from lrnnx.models.lti.centaurus import CentaurusNeck
>>> model = CentaurusNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])

class CentaurusDWS

Bases: CentaurusBase

Depthwise-separable block with one state per channel.
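With one state per channel, the temporal filtering never mixes channels. A minimal sketch, assuming the usual meaning of depthwise-separable (a depthwise stage followed by a pointwise channel mix, with hypothetical pointwise weights W_pw), builds the equivalent (H_out, H_in, L) kernel: the depthwise part is diagonal in (out, in), and the pointwise mix only rescales each input channel's kernel per output.

```python
import numpy as np

def dws_effective_kernel(k, W_pw=None):
    """Equivalent (H_out, H_in, L) kernel of a depthwise(-separable) block.

    k: (H, L), one latent kernel per channel.
    W_pw: optional (H, H_out) pointwise mix applied after the depthwise stage.
    """
    H = k.shape[0]
    if W_pw is None:
        # Depthwise only: K[o, i, l] = k[i, l] if o == i else 0
        return np.einsum("oi,il->oil", np.eye(H), k)
    # Depthwise then pointwise: K[o, i, l] = W_pw[i, o] * k[i, l]
    return np.einsum("io,il->oil", W_pw, k)

rng = np.random.default_rng(0)
H, L = 64, 32
k = rng.standard_normal((H, L))
W_pw = rng.standard_normal((H, H))
K_dw = dws_effective_kernel(k)
K_dws = dws_effective_kernel(k, W_pw)
print(K_dw.shape, K_dws.shape)  # (64, 64, 32) (64, 64, 32)
```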

Example

>>> import torch
>>> from lrnnx.models.lti.centaurus import CentaurusDWS
>>> model = CentaurusDWS(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])

class CentaurusFull

Bases: CentaurusBase

Fully connected block with a state per (in, out) pair.
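With a state per (in, out) pair, every input channel reaches every output channel through its own causal kernel. The direct form of that contraction can be sketched as below, assuming an effective kernel tensor K of shape (H_out, H_in, L); it illustrates the computation only and is not the lrnnx code.

```python
import numpy as np

def full_block(x, K):
    """Fully connected SSM convolution: one causal kernel per (output, input) pair.

    x: (batch, L, H_in), K: (H_out, H_in, L) -> y: (batch, L, H_out), with
    y[b, t, o] = sum_i sum_{j<=t} K[o, i, j] * x[b, t - j, i].
    """
    batch, L, H_in = x.shape
    H_out = K.shape[0]
    y = np.zeros((batch, L, H_out))
    for o in range(H_out):
        for i in range(H_in):
            for b in range(batch):
                # np.convolve returns the full linear convolution; its first L
                # samples are exactly the causal output
                y[b, :, o] += np.convolve(x[b, :, i], K[o, i])[:L]
    return y

rng = np.random.default_rng(0)
batch, L, H = 2, 16, 4
x = rng.standard_normal((batch, L, H))
K = rng.standard_normal((H, H, L))  # H * H independent kernels
y = full_block(x, K)
print(y.shape)  # (2, 16, 4)
```

If only lag 0 of K is nonzero, the block reduces to a per-timestep dense matrix multiply, which is a quick sanity check on the contraction.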

Example

>>> import torch
>>> from lrnnx.models.lti.centaurus import CentaurusFull
>>> model = CentaurusFull(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])

class CentaurusPWNeck

Bases: CentaurusBase

Pointwise bottleneck (called "s5" in public implementations) that flattens the (N, M) state layout into a single (N*M) axis.

This variant removes E-mixing and repeats the step size delta across the M sub-states of each state, yielding independent single-input, single-output (SISO) lanes along the flattened axis.
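The flattening can be made concrete with NumPy. This sketch assumes a row-major (N, M) layout and uses illustrative values for the mode parameters A and the per-state step sizes delta; after flattening, each of the N*M lanes has its own decay, and the M lanes that came from the same state share one delta.

```python
import numpy as np

N, M = 3, 4
A = -np.arange(1, N * M + 1).reshape(N, M).astype(float)  # illustrative (N, M) mode parameters
delta = np.array([0.1, 0.2, 0.3])                        # one step size per state, shape (N,)

A_flat = A.reshape(N * M)            # (N*M,): state n owns entries n*M .. n*M + M - 1
delta_flat = np.repeat(delta, M)     # (N*M,): each state's delta copied to its M sub-states
A_bar = np.exp(delta_flat * A_flat)  # zero-order-hold decay of each independent SISO lane
print(delta_flat)
```

np.repeat keeps copies of the same element adjacent, which matches the row-major flattening so that lane n*M + m still reads the delta of state n.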

Example

>>> import torch
>>> from lrnnx.models.lti.centaurus import CentaurusPWNeck
>>> model = CentaurusPWNeck(d_model=64, d_state=64, sub_state_dim=8)
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])

compute_kernel() → tuple[torch.Tensor, torch.Tensor]

Computes the discrete-time latent convolution kernel with intra-state mode mixing using the shared Centaurus formulation.

Returns:

A tuple containing:
  • k : Latent kernel of shape (N, L), where N is the number of state channels.

  • empty : Placeholder tensor, returned only for compatibility with the standard LTI interface.

Return type:

tuple[torch.Tensor, torch.Tensor]

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) → torch.Tensor

Forward pass through a Centaurus LTI mode variant.

Parameters:
  • x (torch.Tensor) – Input sequence of shape (B, L, H_in).

  • integration_timesteps (torch.Tensor, optional) – Placeholder for async models. Not used in the current implementation. Defaults to None.

  • lengths (torch.Tensor, optional) – Placeholder for future bidirectional models. Not used in the current implementation. Defaults to None.

Returns:

Output sequence of shape (B, L, H_out), where H_out is the output channel dimension.

Return type:

torch.Tensor

allocate_inference_cache(batch_size: int, max_seqlen: int = 1) → Dict[str, Any]

Allocate initial streaming state and cache matrices.

Parameters:
  • batch_size (int) – The batch size.

  • max_seqlen (int, optional) – Maximum sequence length. Defaults to 1.

Returns:

Cache dict with initial state and precomputed discrete parameters.

Return type:

Dict[str, Any]

step(x: torch.Tensor, inference_cache: Dict[str, Any]) → tuple[torch.Tensor, Dict[str, Any]]

Single-timestep streaming update for a Centaurus variant.

This method performs one recurrent update of the Centaurus block using the cached discrete-time parameters in the (B, N, M) layout.

Parameters:
  • x (torch.Tensor) – Input for the current timestep, of shape (B, H_in).

  • inference_cache (Dict[str, Any]) – Cache from allocate_inference_cache().

Returns:

A tuple containing:
  • y : Real-valued output tensor of shape (B, H_out).

  • inference_cache : Updated cache dictionary.

Return type:

tuple[torch.Tensor, Dict[str, Any]]

class Centaurus(d_model: int, d_state: int, sub_state_dim: int, discretization: Literal['zoh', 'bilinear', 'dirac', 'async'] = 'zoh', mode: Literal['neck', 'pointwise', 'pw', 's5', 'dws', 'full'] = 'neck')

Bases: object

Backwards-compatible wrapper: constructing Centaurus(..., mode=...) returns an instance of the corresponding mode-specific subclass, not a Centaurus object.

Example

>>> import torch
>>> from lrnnx.models.lti.centaurus import Centaurus
>>> model = Centaurus(d_model=64, d_state=64, sub_state_dim=8, mode="neck")
>>> x = torch.randn(2, 128, 64)
>>> y = model(x)
>>> y.shape
torch.Size([2, 128, 64])
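Centaurus derives directly from object, yet calling it yields a working mode-specific module, which suggests a factory built on __new__. A minimal sketch of that pattern follows, with hypothetical stand-in classes; the mode-to-class mapping is inferred from the mode names and class descriptions above and is not confirmed by the source.

```python
class _BaseSketch:
    """Hypothetical stand-in for CentaurusBase; stores the constructor arguments."""
    def __init__(self, d_model, d_state, sub_state_dim, discretization="zoh"):
        self.d_model, self.d_state = d_model, d_state
        self.sub_state_dim, self.discretization = sub_state_dim, discretization

class NeckSketch(_BaseSketch): pass
class PWNeckSketch(_BaseSketch): pass
class DWSSketch(_BaseSketch): pass
class FullSketch(_BaseSketch): pass

# Assumed mapping: the pointwise variant accepts three aliases
_MODES = {"neck": NeckSketch, "pointwise": PWNeckSketch, "pw": PWNeckSketch,
          "s5": PWNeckSketch, "dws": DWSSketch, "full": FullSketch}

class CentaurusSketch:
    """Factory: CentaurusSketch(..., mode=m) returns an instance of the class for m."""
    def __new__(cls, d_model, d_state, sub_state_dim, discretization="zoh", mode="neck"):
        try:
            variant = _MODES[mode]
        except KeyError:
            raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(_MODES)}") from None
        # The returned object is not an instance of cls, so Python does not call
        # CentaurusSketch.__init__ on it; the variant has already been fully built
        return variant(d_model, d_state, sub_state_dim, discretization)

model = CentaurusSketch(64, 64, 8, mode="s5")
print(type(model).__name__)  # PWNeckSketch
```

Returning the subclass instance directly keeps old call sites working while the variants themselves stay ordinary classes.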