lrnnx.architectures.classifier module

Classifier using Linear RNN models with support for token embeddings.

Reference: https://github.com/Efficient-Scalable-Machine-Learning/event-ssm

class SequencePooling[source]

Bases: Module

Pooling layer for sequence data with support for variable lengths.

Handles both intermediate pooling (reducing sequence length) and final pooling (creating a single vector representation).

__init__(pooling_type='last', stride=1)[source]

Initialize the pooling layer.

Parameters:
  • pooling_type (str) – Pooling mode (“last”, “mean”, “max”, “stride”)

  • stride (int) – Stride for pooling (only used for intermediate pooling)
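The final pooling modes can be illustrated with a short PyTorch sketch. It is not the library's implementation: the helper name final_pool and the masking strategy are assumptions, chosen to show how "last", "mean", and "max" can respect per-sequence lengths on padded batches.

```python
import torch


def final_pool(x, lengths=None, pooling_type="last"):
    """Reduce (B, L, D) to (B, D), ignoring padded positions.

    Hypothetical helper for illustration; not the library's implementation.
    """
    B, L, _ = x.shape
    if lengths is None:
        # Without lengths, treat every position as valid.
        lengths = torch.full((B,), L, dtype=torch.long, device=x.device)
    # mask[b, t] is True where position t lies inside sequence b.
    mask = torch.arange(L, device=x.device)[None, :] < lengths[:, None]

    if pooling_type == "last":
        # Pick the last valid timestep of each sequence.
        idx = (lengths - 1).clamp(min=0)
        return x[torch.arange(B, device=x.device), idx]
    if pooling_type == "mean":
        # Sum valid positions only, then divide by the true length.
        summed = (x * mask[..., None]).sum(dim=1)
        return summed / lengths.clamp(min=1)[:, None]
    if pooling_type == "max":
        # Padded positions become -inf so they never win the max.
        return x.masked_fill(~mask[..., None], float("-inf")).max(dim=1).values
    raise ValueError(f"unsupported pooling_type: {pooling_type!r}")


x = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
lengths = torch.tensor([3, 2])  # the second sequence has one padded step
print(final_pool(x, lengths, "last"))
print(final_pool(x, lengths, "mean"))
```

With lengths supplied, the padded third step of the second sequence does not contribute to any of the three reductions.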

forward(x: torch.Tensor, lengths: torch.Tensor | None = None, integration_timesteps: torch.Tensor | None = None) tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None][source]

Pool sequences, either reducing length (intermediate) or to a single vector (final).

Parameters:
  • x (torch.Tensor) – Input of shape (B, L, D).

  • lengths (torch.Tensor, optional) – Actual sequence lengths of shape (B,). Defaults to None.

  • integration_timesteps (torch.Tensor, optional) – Timesteps of shape (B, L). Defaults to None.

Returns:

Pooled tensor (and updated timesteps / lengths for intermediate pooling).

Return type:

tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]
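Intermediate pooling has to keep lengths and timesteps consistent with the shortened sequence. The sketch below shows one plausible way to do that for the "stride" mode. The helper stride_pool is hypothetical, and two choices in it are assumptions rather than documented behavior: timesteps are summed over each pooling window, and the return order (tensor, timesteps, lengths) mirrors the tuple documented for non-final ClassifierBlock outputs.

```python
import torch
import torch.nn.functional as F


def stride_pool(x, lengths=None, integration_timesteps=None, stride=2):
    """Keep every `stride`-th timestep of a (B, L, D) tensor.

    Hypothetical sketch; the library may update lengths and timesteps differently.
    """
    pooled = x[:, ::stride]

    new_lengths = None
    if lengths is not None:
        # Kept positions among the first `length` ones: ceil(length / stride).
        new_lengths = (lengths + stride - 1) // stride

    new_steps = None
    if integration_timesteps is not None:
        # Assumption: each kept step absorbs the intervals of the steps it replaces.
        B, L = integration_timesteps.shape
        pad = (-L) % stride  # right-pad with zeros so L divides evenly
        padded = F.pad(integration_timesteps, (0, pad))
        new_steps = padded.reshape(B, -1, stride).sum(dim=-1)

    return pooled, new_steps, new_lengths


x = torch.arange(10, dtype=torch.float32).reshape(1, 5, 2)
pooled, steps, lens = stride_pool(
    x, lengths=torch.tensor([5]), integration_timesteps=torch.ones(1, 5), stride=2
)
print(pooled.shape, steps, lens)
```

Summing the timesteps preserves the total elapsed time per sequence, which matters if a downstream time-varying layer interprets them as integration intervals.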

class ClassifierBlock[source]

Bases: Module

A single processing block in the Classifier.

Each block contains:
  • LRNN layer for temporal processing (instantiated from lrnn_cls)

  • Optional intermediate pooling for sequence length reduction

  • Dropout for regularization

  • Residual connection

  • Layer normalization
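The components of a block can be wired together as in the following minimal sketch. ToyBlock is a hypothetical stand-in, not the library's class: it omits intermediate pooling, uses nn.Linear in place of an LRNN layer, and assumes a post-norm ordering (residual addition followed by layer normalization), which the documentation above does not confirm.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical non-final block: temporal layer, dropout, residual, layer norm."""

    def __init__(self, d_model, layer, dropout=0.1):
        super().__init__()
        self.layer = layer  # stand-in for an instance of lrnn_cls
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Assumed ordering: add the residual first, then normalize (post-norm).
        y = self.dropout(self.layer(x))
        return self.norm(x + y)


# eval() disables dropout so the output is deterministic for a given input.
block = ToyBlock(d_model=8, layer=nn.Linear(8, 8)).eval()
out = block(torch.randn(2, 5, 8))
print(out.shape)  # shape is preserved: (B, L, D)
```

Because the residual connection adds the block input to the layer output, the temporal layer must preserve the feature dimension D.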

__init__(d_model, d_state, lrnn_cls: Type[torch.nn.Module], num_classes: int = 0, output_dim: int = 1, pooling: Literal['mean', 'last', 'max'] = 'last', dropout: float = 0.1, intermediate_pooling: Literal['none', 'stride', 'mean', 'max'] = 'none', pooling_factor: int = 2, is_final: bool = False)[source]

Initialize a processing block used inside the classifier.

The block performs sequence processing and, when is_final=True, produces a single output vector per sequence. Set num_classes > 0 to enable classification (the block returns logits over num_classes); otherwise the block produces regression outputs of size output_dim.

forward(x: torch.Tensor, integration_timesteps: torch.Tensor | None = None, lengths: torch.Tensor | None = None) torch.Tensor | Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None][source]

Forward pass through a single classifier block.

Parameters:
  • x (torch.Tensor) – Input of shape (B, L, D).

  • integration_timesteps (torch.Tensor, optional) – Timesteps for LTV models. Defaults to None.

  • lengths (torch.Tensor, optional) – Actual sequence lengths. Defaults to None.

Returns:

The final block returns logits of shape (B, num_classes), or regression outputs of shape (B, output_dim) when num_classes is 0; non-final blocks return (x, integration_timesteps, lengths).

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]

class Classifier[source]

Bases: Module

Classifier: Sequence classifier or regressor…

Parameters:
  • input_dim (int) – Number of input features.

  • num_classes (int) – Number of output classes.

  • d_model (int) – Hidden dimension of the model.

__init__(input_dim: int, num_classes: int = 0, output_dim: int = 1, d_model: int = 128, d_state: int = 64, n_layers: int = 4, lrnn_cls: Type[torch.nn.Module] | List[Type[torch.nn.Module]] = <class 'lrnnx.models.lti.lru.LRU'>, pooling: Literal['mean', 'last', 'max'] = 'last', dropout: float = 0.1, intermediate_pooling: Literal['none', 'stride', 'mean', 'max'] | List[Literal['none', 'stride', 'mean', 'max']] = 'none', pooling_factor: int | List[int] = 2, vocab_size: int | None = None, embedding_dim: int | None = None, max_position_embeddings: int | None = None, padding_idx: int | None = 0, lrnn_params: dict | None = None)[source]

Initializes the Classifier.

Parameters:
  • input_dim (int) – Number of input features (ignored when vocab_size is provided).

  • num_classes (int, optional) – Number of output classes. Defaults to 0.

  • output_dim (int, optional) – Number of regression outputs. Defaults to 1.

  • d_model (int, optional) – Hidden dimension of the model. Defaults to 128.

  • d_state (int, optional) – State dimension for the LRNN layers. Defaults to 64.

  • n_layers (int, optional) – Number of LRNN layers. Defaults to 4.

  • lrnn_cls (type | list[type], optional) – Custom LRNN class or list of classes (one per layer) to use. Defaults to LRU.

  • pooling (str, optional) – Pooling strategy for sequence outputs. Defaults to "last".

  • dropout (float, optional) – Dropout probability. Defaults to 0.1.

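  • intermediate_pooling (str | list[str], optional) – Intermediate pooling strategy, given either as one value shared by all layers or as a list with one entry per layer. Defaults to "none".

  • pooling_factor (int | list[int], optional) – Factor by which intermediate pooling reduces the sequence length, given either as one value shared by all layers or as a list with one entry per layer. Defaults to 2.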

  • vocab_size (int, optional) – Size of vocabulary for token embeddings. Defaults to None.

  • embedding_dim (int, optional) – Dimension of embeddings. Defaults to None, in which case d_model is used.

  • max_position_embeddings (int, optional) – Max sequence length for positional embeddings. Defaults to None.

  • padding_idx (int, optional) – Index of padding token for embedding layer. Defaults to 0.

  • lrnn_params (dict, optional) – Additional parameters for LRNN modules. Defaults to None.
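Several of these parameters accept either a single value or a per-layer list. The sketch below shows how such settings could be expanded and how the sequence length would shrink across layers. The helper per_layer is hypothetical, and the ceil-division length rule is an assumption; the library's own validation and rounding may differ.

```python
def per_layer(value, n_layers, name):
    """Expand a scalar setting to a list with one entry per layer.

    Hypothetical helper; the library's own validation may differ.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != n_layers:
            raise ValueError(f"{name} has {len(value)} entries, expected {n_layers}")
        return list(value)
    return [value] * n_layers


n_layers = 4
pooling_modes = per_layer(
    ["stride", "stride", "none", "none"], n_layers, "intermediate_pooling"
)
factors = per_layer(2, n_layers, "pooling_factor")

# Track how an input of length 100 shrinks, assuming ceil division per pooled layer.
seq_len = 100
for mode, factor in zip(pooling_modes, factors):
    if mode != "none":
        seq_len = -(-seq_len // factor)  # ceil(seq_len / factor)

print(factors)  # [2, 2, 2, 2]
print(seq_len)  # 25
```

Pooling in the early layers reduces the number of timesteps that all later layers have to process.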

forward(x: torch.Tensor, lengths: torch.Tensor | None = None, integration_timesteps: torch.Tensor | None = None) torch.Tensor | Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None][source]

Forward pass of the classifier/regressor.

Parameters:
  • x (torch.Tensor) – Input tensor. Token IDs of shape (B, L) when using embeddings, or continuous features of shape (B, L, input_dim) otherwise.

  • lengths (torch.Tensor, optional) – Actual sequence lengths of shape (B,). Defaults to None.

  • integration_timesteps (torch.Tensor, optional) – Timesteps of shape (B, L) for LTV models. Defaults to None.

Returns:

Logits of shape (B, num_classes) or regression values of shape (B, output_dim).

Return type:

torch.Tensor
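The two input layouts accepted by forward can be illustrated with a small front-end sketch. ToyInputEncoder is a hypothetical class, not part of the library: it assumes token IDs go through nn.Embedding and continuous features through a linear projection, and it leaves out positional embeddings entirely.

```python
import torch
from torch import nn


class ToyInputEncoder(nn.Module):
    """Hypothetical front end mapping either input kind to (B, L, d_model)."""

    def __init__(self, input_dim, d_model, vocab_size=None, padding_idx=0):
        super().__init__()
        if vocab_size is not None:
            # Token IDs: look up a learned vector per token; input_dim is unused.
            self.encoder = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        else:
            # Continuous features: project input_dim to d_model.
            self.encoder = nn.Linear(input_dim, d_model)

    def forward(self, x):
        return self.encoder(x)


tokens = torch.tensor([[5, 3, 0], [7, 0, 0]])  # (B, L), 0 is the padding index
features = torch.randn(2, 3, 4)                # (B, L, input_dim)

token_out = ToyInputEncoder(input_dim=4, d_model=16, vocab_size=10)(tokens)
feature_out = ToyInputEncoder(input_dim=4, d_model=16)(features)
print(token_out.shape, feature_out.shape)
```

In PyTorch, the embedding row at padding_idx is initialized to zeros and receives no gradient updates, so padded token positions enter the model as zero vectors.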