# Source code for lrnnx.architectures.embedding
"""Embedding modules for sequence models."""
from typing import Optional
import torch
import torch.nn as nn
[docs]
class PositionEmbedding(nn.Module):
    """
    Lookup table mapping integer position indices to learned vectors.

    Args:
        max_position_embeddings (int): Number of rows in the table, i.e. how
            many distinct position indices can be looked up.
        embedding_dim (int): Size of each positional vector.
    """

    def __init__(self, max_position_embeddings: int, embedding_dim: int):
        super().__init__()
        # The attribute name is part of the state_dict key
        # ("position_embedding.weight"), so it is kept unchanged.
        table = nn.Embedding(
            num_embeddings=max_position_embeddings,
            embedding_dim=embedding_dim,
        )
        self.position_embedding = table

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        """
        Look up the learned vector for every index in ``positions``.

        Args:
            positions (torch.Tensor): Integer tensor of position indices.

        Returns:
            torch.Tensor: The looked-up vectors; the shape of ``positions``
            with a trailing ``embedding_dim`` axis appended.
        """
        looked_up = self.position_embedding(positions)
        return looked_up
[docs]
class TokenEmbedding(nn.Module):
    """
    Token embedding module. Positional embeddings are optional and explicit.

    By default this returns token lookups only. Enable learned positional
    embeddings by passing ``use_position=True`` together with
    ``max_position_embeddings``.

    Args:
        vocab_size (int): Size of the vocabulary.
        embedding_dim (int): Dimension of the embedding vectors.
        padding_idx (int, optional): Index for padding tokens. Defaults to None.
        max_position_embeddings (int, optional): Max sequence length for
            positional embeddings. Required if ``use_position=True``.
            Defaults to None.
        use_position (bool, optional): Whether to include learned positional
            embeddings. Defaults to False.
        dropout (float, optional): Dropout probability. Defaults to 0.1.

    Raises:
        ValueError: If ``use_position=True`` and ``max_position_embeddings``
            is None.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_position_embeddings: Optional[int] = None,
        use_position: bool = False,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx
        )

        # Optional positional embeddings
        self.position_embedding: Optional[PositionEmbedding] = None
        if use_position:
            if max_position_embeddings is None:
                raise ValueError(
                    "max_position_embeddings must be set when use_position=True"
                )
            self.position_embedding = PositionEmbedding(
                max_position_embeddings, embedding_dim
            )
        # Kept so forward() can reject over-long sequences with a clear error
        # before the embedding lookup is attempted.
        self.max_position_embeddings = max_position_embeddings

        self.dropout = nn.Dropout(dropout)
        self.embedding_dim = embedding_dim

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Convert token IDs to embeddings.

        When positional embeddings are enabled, the last dimension of
        ``token_ids`` is treated as the sequence axis and position ``i`` along
        it receives the ``i``-th learned positional vector.

        Args:
            token_ids (torch.Tensor): Tensor of token IDs of shape
                ``(batch_size, seq_len)``. An unbatched ``(seq_len,)`` tensor
                is accepted as well.

        Returns:
            torch.Tensor: Embedded tokens of shape
            ``(batch_size, seq_len, embedding_dim)`` (``token_ids.shape`` with
            a trailing ``embedding_dim`` axis), after dropout.

        Raises:
            IndexError: If positional embeddings are enabled and ``seq_len``
                exceeds ``max_position_embeddings``.
        """
        embeddings = self.token_embedding(token_ids)

        if self.position_embedding is not None:
            seq_len = token_ids.size(-1)
            limit = self.max_position_embeddings
            if limit is not None and seq_len > limit:
                # Fail early with a readable message rather than an opaque
                # out-of-range error raised inside the embedding lookup.
                raise IndexError(
                    f"sequence length {seq_len} exceeds "
                    f"max_position_embeddings={limit}"
                )
            positions = torch.arange(seq_len, device=token_ids.device)
            # (seq_len, embedding_dim) broadcasts across any leading batch
            # dimensions of ``embeddings``.
            embeddings = embeddings + self.position_embedding(positions)

        return self.dropout(embeddings)