Source code for lrnnx.architectures.language_model

"""
Language Model architecture.
Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
"""

import inspect
import json
import math
import os
from collections import namedtuple
from functools import partial
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor

from lrnnx.layers.block import Block
from lrnnx.layers.mha import MHA
from lrnnx.layers.mlp import GatedMLP

if torch.cuda.is_available():
    from lrnnx.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
else:
    RMSNorm, layer_norm_fn, rms_norm_fn = nn.RMSNorm, None, None  # type: ignore[assignment, misc]


def _get_mixer_class_from_string(mixer_type: str):
    """
    Resolve a mixer type name to the module class that implements it.

    The name ``"attn"`` is not a registry entry: it is returned unchanged as a
    sentinel string so callers can build attention layers themselves.

    Args:
        mixer_type (str): Name of the mixer type. One of "LRU", "S4", "S4D",
            "S5", "Centaurus", "Mamba", "RGLRU", "S7", or "attn".

    Returns:
        type | str: The corresponding PyTorch module class, or the string
        ``"attn"`` when ``mixer_type == "attn"``.

    Raises:
        ValueError: If ``mixer_type`` is not a known mixer name.
    """
    # Short-circuit the attention sentinel before importing every model module,
    # so resolving "attn" does not depend on those imports succeeding.
    if mixer_type == "attn":
        return "attn"

    # Model classes are imported locally (not at module level).
    from lrnnx.models.lti.centaurus import Centaurus
    from lrnnx.models.lti.lru import LRU
    from lrnnx.models.lti.s4 import S4
    from lrnnx.models.lti.s4d import S4D
    from lrnnx.models.lti.s5 import S5
    from lrnnx.models.ltv.mamba import Mamba
    from lrnnx.models.ltv.rglru import RGLRU
    from lrnnx.models.ltv.s7 import S7

    mixer_registry = {
        "LRU": LRU,
        "S4": S4,
        "S4D": S4D,
        "S5": S5,
        "Centaurus": Centaurus,
        "Mamba": Mamba,
        "RGLRU": RGLRU,
        "S7": S7,
    }

    if mixer_type not in mixer_registry:
        raise ValueError(
            f"Unknown mixer type: {mixer_type}. "
            f"Available types: {list(mixer_registry.keys()) + ['attn']}"
        )

    return mixer_registry[mixer_type]


def create_block(
    d_model: int,
    d_state: int,
    d_intermediate: int,
    mixer_type: str,
    mixer_kwargs: Optional[Dict] = None,
    attn_cfg: Optional[Dict] = None,
    norm_epsilon: float = 1e-5,
    rms_norm: bool = False,
    residual_in_fp32: bool = False,
    fused_add_norm: bool = True,
    layer_idx: Optional[int] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Block:
    """
    Create a residual block (mixer + optional gated MLP + norms).

    Args:
        d_model (int): Model dimension.
        d_state (int): State dimension. Passed to the mixer only if its
            constructor has a ``d_state`` parameter; a ``d_state`` entry in
            ``mixer_kwargs`` takes precedence over this value.
        d_intermediate (int): Intermediate dimension for MLP layers
            (0 to disable MLP).
        mixer_type (str): Name of the mixer type (e.g., "LRU", "S5", "attn").
        mixer_kwargs (dict, optional): Additional arguments for the (non-attention)
            mixer. Defaults to None.
        attn_cfg (dict, optional): Configuration for attention layers
            (used only when ``mixer_type == "attn"``). Defaults to None.
        norm_epsilon (float, optional): Epsilon value for layer normalization.
            Defaults to 1e-5.
        rms_norm (bool, optional): Whether to use RMSNorm instead of LayerNorm.
            Defaults to False.
        residual_in_fp32 (bool, optional): Whether to compute residuals in
            float32. Defaults to False.
        fused_add_norm (bool, optional): Whether to use fused add+norm
            operations. Defaults to True.
        layer_idx (int, optional): Index of the current layer. Defaults to None.
        device (torch.device, optional): Device to place tensors on.
            Defaults to None.
        dtype (torch.dtype, optional): Data type for tensors. Defaults to None.

    Returns:
        Block: A configured block module, with ``layer_idx`` set on it.

    Raises:
        ValueError: If ``mixer_type`` is not a known mixer name.
    """
    attn_cfg = {} if attn_cfg is None else attn_cfg
    mixer_kwargs = {} if mixer_kwargs is None else mixer_kwargs
    factory_kwargs = {"device": device, "dtype": dtype}

    if mixer_type == "attn":
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    else:
        mixer_cls_from_str = _get_mixer_class_from_string(mixer_type)
        # Copy so the caller's dict is never mutated.
        kwargs = dict(mixer_kwargs)
        # Only pass d_state if the mixer class accepts it. setdefault lets a
        # per-mixer d_state in mixer_kwargs override the global one instead of
        # raising "got multiple values for keyword argument 'd_state'".
        if "d_state" in inspect.signature(mixer_cls_from_str).parameters:
            kwargs.setdefault("d_state", d_state)
        mixer_cls = partial(mixer_cls_from_str, **kwargs)

    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm,
        eps=norm_epsilon,
        **factory_kwargs,
    )

    if d_intermediate == 0:
        mlp_cls = nn.Identity  # type: ignore[assignment]
    else:
        mlp_cls = partial(  # type: ignore[assignment]
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,  # type: ignore[arg-type]
        )

    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx  # type: ignore[assignment]
    return block
def _init_weights(
    module: nn.Module,
    n_layer: int,
    initializer_range: float = 0.02,
    rescale_prenorm_residual: bool = True,
    n_residuals_per_layer: int = 1,  # change to 2 if we have MLP
) -> None:
    """
    Initialize one module's weights following the GPT-2 scheme.

    Meant to be used with ``nn.Module.apply``, which calls it on every
    submodule.

    - ``nn.Linear``: the bias is zeroed unless it carries a truthy
      ``_no_reinit`` attribute. The weight is left as-is.
    - ``nn.Embedding``: the weight is drawn from N(0, initializer_range).
    - When ``rescale_prenorm_residual`` is set, any direct child parameter named
      ``out_proj.weight`` or ``fc2.weight`` is re-drawn with PyTorch's default
      Linear init. It is then divided by sqrt(n_residuals_per_layer * n_layer).

    Args:
        module (nn.Module): Module to initialize.
        n_layer (int): Number of layers in the model.
        initializer_range (float, optional): Standard deviation for embedding
            initialization. Defaults to 0.02.
        rescale_prenorm_residual (bool, optional): Whether to rescale residual
            output projections. Defaults to True.
        n_residuals_per_layer (int, optional): Number of residual connections
            per layer. Defaults to 1.
    """
    if isinstance(module, nn.Linear):
        bias = module.bias
        if bias is not None and not getattr(bias, "_no_reinit", False):
            nn.init.zeros_(bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if not rescale_prenorm_residual:
        return

    # GPT-2 (https://openai.com/blog/better-language-models/): scale residual
    # output weights by 1/sqrt(N), N = number of residual layers, to offset
    # accumulation along the residual path. Cf. Megatron-LM gpt_model.py.
    residual_weight_names = ("out_proj.weight", "fc2.weight")
    for param_name, param in module.named_parameters():
        if param_name not in residual_weight_names:
            continue
        # Re-draw first rather than scaling in place: this function may run more
        # than once on the same module, and repeated in-place scaling would
        # keep shrinking the weights.
        nn.init.kaiming_uniform_(param, a=math.sqrt(5))
        with torch.no_grad():
            param /= math.sqrt(n_residuals_per_layer * n_layer)  # type: ignore[misc]
class LRNNModel(nn.Module):
    """
    Core LRNN backbone: token embedding, a stack of residual blocks, and a
    final normalization layer.

    Args:
        d_model (int): Model dimension.
        d_state (int): State dimension.
        n_layer (int): Number of layers in the model.
        vocab_size (int): Size of the vocabulary.
        mixer_types (list): List of mixer type names for each layer
            (e.g., ``["S5", "S7", "attn", ...]``). Must have ``n_layer`` entries.
        d_intermediate (int, optional): Intermediate dimension for MLP layers
            (0 to disable MLP). Defaults to 0.
        mixer_kwargs (dict, optional): Dict mapping mixer type names to their
            kwargs, e.g., ``{"S5": {"dt_min": 0.001}, "attn": {"num_heads": 8}}``.
            Each layer receives only the entry for its own mixer type (an empty
            dict if its type is not a key). The ``"attn"`` entry is used as the
            attention config for ``"attn"`` layers. Defaults to None.
        mlp_cls (type, optional): Accepted for API compatibility but currently
            unused: blocks always use ``GatedMLP`` (see ``create_block``).
            Defaults to None.
        norm_epsilon (float, optional): Epsilon value for layer normalization.
            Defaults to 1e-5.
        rms_norm (bool, optional): Whether to use RMSNorm instead of LayerNorm.
            Defaults to True.
        initializer_cfg (dict, optional): Extra keyword arguments for
            ``_init_weights``. Defaults to None.
        fused_add_norm (bool, optional): Whether to use fused add+norm
            operations. Defaults to True.
        residual_in_fp32 (bool, optional): Whether to compute residuals in
            float32. Defaults to False.
        device (torch.device, optional): Device to place tensors on.
            Defaults to None.
        dtype (torch.dtype, optional): Data type for tensors. Defaults to None.

    Raises:
        ImportError: If ``fused_add_norm`` is True but the Triton kernels could
            not be imported.
        ValueError: If ``len(mixer_types) != n_layer``.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int,
        n_layer: int,
        vocab_size: int,
        mixer_types: list,
        d_intermediate: int = 0,
        mixer_kwargs: Optional[Dict] = None,
        mlp_cls=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = True,
        initializer_cfg: Optional[Dict[str, Any]] = None,
        fused_add_norm: bool = True,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm

        if self.fused_add_norm and (layer_norm_fn is None or rms_norm_fn is None):
            raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        # mixer_types should have n_layer entries
        if len(mixer_types) != n_layer:
            raise ValueError(
                f"mixer_types must have length n_layer ({n_layer}), "
                f"got {len(mixer_types)}"
            )

        # None or a non-dict value means "no per-mixer kwargs".
        per_type_kwargs = mixer_kwargs if isinstance(mixer_kwargs, dict) else {}
        # NOTE: mlp_cls is intentionally not used; see the class docstring.

        # embedding layer
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)  # type: ignore[arg-type]

        # LRNN layers
        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_state=d_state,
                    d_intermediate=d_intermediate,
                    mixer_type=mixer_type,
                    mixer_kwargs=per_type_kwargs.get(mixer_type, {}),
                    attn_cfg=(
                        per_type_kwargs.get("attn", {}) if mixer_type == "attn" else {}
                    ),
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,  # type: ignore[arg-type]
                )
                for i, mixer_type in enumerate(mixer_types)
            ]
        )

        # final normalization
        norm_cls = RMSNorm if rms_norm else nn.LayerNorm
        self.norm_f = norm_cls(d_model, eps=norm_epsilon, **factory_kwargs)

        # Initialize weights. Two residual branches per layer when an MLP is
        # present. setdefault lets initializer_cfg override the count instead of
        # raising a duplicate-keyword TypeError.
        init_kwargs = dict(initializer_cfg) if initializer_cfg is not None else {}
        init_kwargs.setdefault("n_residuals_per_layer", 1 if d_intermediate == 0 else 2)
        self.apply(partial(_init_weights, n_layer=n_layer, **init_kwargs))

    def _add_and_norm(
        self, hidden_states: Tensor, residual: Optional[Tensor], norm: nn.Module
    ):
        """
        Add ``hidden_states`` into the residual stream and normalize it.

        Returns:
            tuple: ``(normed, residual)``. ``normed`` is the normalized sum;
            ``residual`` is the updated residual (float32 if
            ``residual_in_fp32``).
        """
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            normed = norm(residual.to(dtype=norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
            return normed, residual

        normed, residual = layer_norm_fn(
            hidden_states,
            norm.weight,
            norm.bias,
            residual=residual,
            prenorm=True,
            residual_in_fp32=self.residual_in_fp32,
            eps=norm.eps,
            is_rms_norm=isinstance(norm, RMSNorm),
        )
        return normed, residual

    def _final_norm(self, hidden_states: Tensor, residual: Optional[Tensor]) -> Tensor:
        """Add the last block's output to the residual and apply ``norm_f``."""
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            return self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))

        return layer_norm_fn(
            hidden_states,
            self.norm_f.weight,
            self.norm_f.bias,
            eps=self.norm_f.eps,
            residual=residual,
            prenorm=False,
            residual_in_fp32=self.residual_in_fp32,
            is_rms_norm=isinstance(self.norm_f, RMSNorm),
        )

    def allocate_inference_cache(
        self,
        batch_size: int,
        max_seqlen: int,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> Dict:
        """
        Allocate inference cache for autoregressive generation.

        Args:
            batch_size (int): Batch size for inference.
            max_seqlen (int): Maximum sequence length for inference.
            dtype (torch.dtype, optional): Data type for cache tensors.

        Returns:
            dict: Dictionary mapping layer indices to their allocated caches.
        """
        cache = {}
        for i, layer in enumerate(self.layers):
            try:
                cache[i] = layer.allocate_inference_cache(
                    batch_size, max_seqlen, dtype=dtype, **kwargs
                )
            except TypeError:
                # Fallback for layers whose allocate_inference_cache takes only
                # the batch size. NOTE(review): this also catches TypeErrors
                # raised *inside* a compatible implementation.
                cache[i] = layer.allocate_inference_cache(batch_size, **kwargs)
        return cache

    def step(
        self,
        input_ids: Tensor,
        caches: Dict,
        integration_timesteps: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Single-step inference for autoregressive generation.

        ``caches`` is updated in place with each layer's new cache.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape ``(B, 1)`` —
                single token.
            caches (Dict): Dictionary mapping layer indices to their cached states.
            integration_timesteps (torch.Tensor, optional): Integration timesteps
                for LTV models (shape: ``(B, 1)`` or ``(B,)``). Defaults to None.

        Returns:
            torch.Tensor: Hidden states of shape ``(B, 1, d_model)``.

        Raises:
            NotImplementedError: If a layer's mixer is not an LTV mixer and its
                cache is not a dict (e.g. attention layers or a missing cache).
                Such a layer cannot be stepped here; skipping it would silently
                corrupt the output.
        """
        from lrnnx.models.ltv.base import LTV_LRNN

        hidden_states = self.embedding(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            layer_cache = caches.get(i)
            normed, residual = self._add_and_norm(hidden_states, residual, layer.norm)

            if isinstance(layer.mixer, LTV_LRNN):
                # LTV mixers receive the un-squeezed normalized input.
                mixer_out, updated_cache = layer.mixer.step(
                    normed,
                    layer_cache,  # type: ignore[arg-type]
                    integration_timesteps=integration_timesteps,
                )
                hidden_states = mixer_out
            elif isinstance(layer_cache, dict):
                # Dict-based cache (S4, S4D, Centaurus, S5, LRU)
                mixer_out, updated_cache = layer.mixer.step(
                    normed.squeeze(1), layer_cache
                )
                hidden_states = mixer_out.unsqueeze(1)  # (B, H) -> (B, 1, H)
            else:
                raise NotImplementedError(
                    f"Layer {i} ({type(layer.mixer).__name__}) cannot be stepped: "
                    f"expected an LTV mixer or a dict cache, got cache of type "
                    f"{type(layer_cache).__name__}."
                )
            caches[i] = updated_cache

            if layer.mlp is not None:
                normed2, residual = self._add_and_norm(
                    hidden_states, residual, layer.norm2
                )
                hidden_states = layer.mlp(normed2)

        return self._final_norm(hidden_states, residual)

    def forward(
        self,
        input_ids: Tensor,
        inference_params: Optional[Dict] = None,
        integration_timesteps: Optional[Tensor] = None,
        lengths: Optional[Tensor] = None,
        **mixer_kwargs,
    ) -> Tensor:
        """
        Forward pass of the LRNN backbone.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape ``(B, L)``.
            inference_params (Dict, optional): Parameters for inference mode.
                Defaults to None.
            integration_timesteps (torch.Tensor, optional): Timesteps for LTV
                models (shape: ``(B, L)``). Defaults to None.
            lengths (torch.Tensor, optional): Sequence lengths for
                variable-length sequences (shape: ``(B,)``). Defaults to None.
            **mixer_kwargs: Extra keyword arguments forwarded to every block.

        Returns:
            torch.Tensor: Hidden states of shape ``(B, L, d_model)``.
        """
        # Forward the optional per-sequence inputs only when given, so mixers
        # that do not accept them still work.
        if integration_timesteps is not None:
            mixer_kwargs["integration_timesteps"] = integration_timesteps
        if lengths is not None:
            mixer_kwargs["lengths"] = lengths

        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states,
                residual,
                inference_params=inference_params,
                **mixer_kwargs,
            )
        return self._final_norm(hidden_states, residual)
class LRNNLMHeadModel(nn.Module):
    """
    LRNN Language Model with a language modeling head.

    Args:
        d_model (int): Model dimension.
        d_state (int): State dimension.
        n_layer (int): Number of layers in the model.
        vocab_size (int): Size of the vocabulary.
        mixer_types (list): List of mixer type names for each layer
            (e.g., ``["S5", "S7", "attn", ...]``).
        d_intermediate (int, optional): Intermediate dimension for MLP layers
            (0 to disable MLP). Defaults to 0.
        mixer_kwargs (dict, optional): Per-mixer-type kwargs, forwarded to the
            backbone. Defaults to None.
        mlp_cls (type, optional): MLP class, forwarded to the backbone.
            Defaults to None.
        norm_epsilon (float, optional): Epsilon value for layer normalization.
            Defaults to 1e-5.
        rms_norm (bool, optional): Whether to use RMSNorm instead of LayerNorm.
            Defaults to True.
        fused_add_norm (bool, optional): Whether to use fused add+norm
            operations. Defaults to True.
        residual_in_fp32 (bool, optional): Whether to compute residuals in
            float32. Defaults to False.
        tie_embeddings (bool, optional): Whether to tie input and output
            embeddings. Defaults to True.
        pad_vocab_size_multiple (int, optional): Pad vocabulary size up to a
            multiple of this value (values <= 0 disable padding). Defaults to 8.
        initializer_cfg (dict, optional): Extra keyword arguments for
            ``_init_weights``. Defaults to None.
        device (torch.device, optional): Device to place tensors on.
            Defaults to None.
        dtype (torch.dtype, optional): Data type for tensors. Defaults to None.
    """

    # Output container shared by forward() and step(). It is defined once, so
    # all returned objects share one type instead of a new class per call.
    CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(
        self,
        d_model: int,
        d_state: int,
        n_layer: int,
        vocab_size: int,
        mixer_types: list,
        d_intermediate: int = 0,
        mixer_kwargs: Optional[Dict] = None,
        mlp_cls=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = True,
        fused_add_norm: bool = True,
        residual_in_fp32: bool = False,
        tie_embeddings: bool = True,
        pad_vocab_size_multiple: int = 8,
        initializer_cfg: Optional[Dict[str, Any]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        # Hyperparameters kept for save_pretrained().
        self.d_model = d_model
        self.d_state = d_state
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.mixer_types = mixer_types
        self.d_intermediate = d_intermediate
        self.norm_epsilon = norm_epsilon
        self.rms_norm = rms_norm
        self.fused_add_norm = fused_add_norm
        self.residual_in_fp32 = residual_in_fp32
        self.tie_embeddings = tie_embeddings
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        # Pad the vocabulary size; a non-positive multiple means "no padding"
        # (previously 0 raised ZeroDivisionError).
        padded_vocab_size = vocab_size
        if pad_vocab_size_multiple > 0 and vocab_size % pad_vocab_size_multiple != 0:
            padded_vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )

        # core LRNN model
        self.backbone = LRNNModel(
            d_model=d_model,
            d_state=d_state,
            n_layer=n_layer,
            vocab_size=padded_vocab_size,
            mixer_types=mixer_types,
            mixer_kwargs=mixer_kwargs,
            d_intermediate=d_intermediate,
            mlp_cls=mlp_cls,
            norm_epsilon=norm_epsilon,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,  # type: ignore[arg-type]
        )

        # language modeling head
        self.lm_head = nn.Linear(
            d_model, padded_vocab_size, bias=False, **factory_kwargs  # type: ignore[arg-type]
        )

        # Initialize weights. _init_weights re-draws out_proj/fc2 weights, so
        # it must use the same residual count as the backbone. Otherwise the
        # default of 1 would replace the backbone's 1/sqrt(2 * n_layer) scaling
        # with 1/sqrt(n_layer) whenever MLPs are present.
        init_kwargs = dict(initializer_cfg) if initializer_cfg is not None else {}
        init_kwargs.setdefault("n_residuals_per_layer", 1 if d_intermediate == 0 else 2)
        self.apply(partial(_init_weights, n_layer=n_layer, **init_kwargs))

        if tie_embeddings:
            self.tie_weights()

    def tie_weights(self) -> None:
        """
        Tie input and output embeddings.

        Makes the LM head share the embedding layer's weight Parameter.
        """
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(
        self,
        batch_size: int,
        max_seqlen: int,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> Dict:
        """
        Allocate inference cache (delegates to the backbone).

        Args:
            batch_size (int): Batch size for inference.
            max_seqlen (int): Maximum sequence length for inference.
            dtype (torch.dtype, optional): Data type for cache tensors.

        Returns:
            dict: Dictionary mapping layer indices to their allocated caches.
        """
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def step(
        self,
        input_ids: Tensor,
        caches: Dict,
        integration_timesteps: Optional[Tensor] = None,
    ) -> namedtuple:  # type: ignore[valid-type]
        """
        Single-step inference for autoregressive generation.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape ``(B, 1)`` —
                single token.
            caches (Dict): Dictionary mapping layer indices to their cached
                states (updated in place by the backbone).
            integration_timesteps (torch.Tensor, optional): Integration timesteps
                for LTV models (shape: ``(B, 1)`` or ``(B,)``). Defaults to None.

        Returns:
            namedtuple: ``CausalLMOutput`` with ``logits`` of shape
            ``(B, 1, padded_vocab_size)``.
        """
        hidden_states = self.backbone.step(input_ids, caches, integration_timesteps)
        return self.CausalLMOutput(logits=self.lm_head(hidden_states))

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Optional[Tensor] = None,
        inference_params: Optional[Dict] = None,
        num_last_tokens: int = 0,
        integration_timesteps: Optional[Tensor] = None,
        lengths: Optional[Tensor] = None,
        **mixer_kwargs,
    ) -> namedtuple:  # type: ignore[valid-type]
        """
        Forward pass of the language model.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape ``(B, L)``.
            position_ids (torch.Tensor, optional): Position IDs (unused, for
                compatibility). Defaults to None.
            inference_params (Dict, optional): Parameters for inference mode.
                Defaults to None.
            num_last_tokens (int, optional): If > 0, only return logits for the
                last n tokens. Defaults to 0.
            integration_timesteps (torch.Tensor, optional): Timesteps for LTV
                models (shape: ``(B, L)``). Defaults to None.
            lengths (torch.Tensor, optional): Sequence lengths for
                variable-length sequences (shape: ``(B,)``). Defaults to None.

        Returns:
            namedtuple: ``CausalLMOutput`` with ``logits`` of shape
            ``(B, L, padded_vocab_size)`` (``L`` replaced by ``num_last_tokens``
            when that is > 0).
        """
        hidden_states = self.backbone(
            input_ids,
            inference_params=inference_params,
            integration_timesteps=integration_timesteps,
            lengths=lengths,
            **mixer_kwargs,
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        return self.CausalLMOutput(logits=self.lm_head(hidden_states))

    def save_pretrained(self, save_directory: str) -> None:
        """
        Save the model state dict and configuration to a directory.

        Writes ``pytorch_model.bin`` and ``config.json``. ``mixer_kwargs``,
        ``mlp_cls`` and ``initializer_cfg`` are NOT saved and must be passed
        again to :meth:`from_pretrained`.

        Args:
            save_directory (str): Directory path (created if missing).
        """
        os.makedirs(save_directory, exist_ok=True)

        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        config_dict = {
            "d_model": self.d_model,
            "d_state": self.d_state,
            "n_layer": self.n_layer,
            "vocab_size": self.vocab_size,
            "mixer_types": self.mixer_types,
            "d_intermediate": self.d_intermediate,
            "norm_epsilon": self.norm_epsilon,
            "rms_norm": self.rms_norm,
            "fused_add_norm": self.fused_add_norm,
            "residual_in_fp32": self.residual_in_fp32,
            "tie_embeddings": self.tie_embeddings,
            "pad_vocab_size_multiple": self.pad_vocab_size_multiple,
        }
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=4)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: str,
        mixer_kwargs: Optional[Dict] = None,
        mlp_cls=None,
        initializer_cfg: Optional[Dict[str, Any]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> "LRNNLMHeadModel":
        """
        Load a pretrained model from a directory written by :meth:`save_pretrained`.

        Args:
            pretrained_model_path (str): Path to directory containing saved
                model and config.
            mixer_kwargs (dict, optional): Additional keyword arguments for
                mixers. Defaults to None.
            mlp_cls (type, optional): MLP class to use. Defaults to None.
            initializer_cfg (dict, optional): Configuration for weight
                initialization. Defaults to None.
            device (torch.device, optional): Device to place tensors on.
                Defaults to None.
            dtype (torch.dtype, optional): Data type for tensors. Defaults to None.
            **kwargs: Extra constructor arguments. Keys that also appear in
                ``config.json`` override the saved values (e.g.
                ``fused_add_norm=False`` to load without Triton kernels).

        Returns:
            LRNNLMHeadModel: Loaded model instance.
        """
        config_path = os.path.join(pretrained_model_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        model_kwargs = {
            "d_model": config_dict["d_model"],
            "d_state": config_dict["d_state"],
            "n_layer": config_dict["n_layer"],
            "vocab_size": config_dict["vocab_size"],
            "mixer_types": config_dict["mixer_types"],
            "d_intermediate": config_dict.get("d_intermediate", 0),
            "norm_epsilon": config_dict.get("norm_epsilon", 1e-5),
            "rms_norm": config_dict.get("rms_norm", True),
            "fused_add_norm": config_dict.get("fused_add_norm", True),
            "residual_in_fp32": config_dict.get("residual_in_fp32", False),
            "tie_embeddings": config_dict.get("tie_embeddings", True),
            "pad_vocab_size_multiple": config_dict.get("pad_vocab_size_multiple", 8),
        }
        # Explicit keyword arguments win over saved values instead of raising
        # "got multiple values for keyword argument".
        model_kwargs.update(kwargs)

        model = cls(
            **model_kwargs,
            mixer_kwargs=mixer_kwargs,
            mlp_cls=mlp_cls,
            initializer_cfg=initializer_cfg,
            device=device,
            dtype=dtype,
        )

        # weights_only=True: the checkpoint is a plain state dict of tensors, so
        # there is no need to unpickle arbitrary objects from disk.
        model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        return model