lrnnx.ops.triton.layer_norm module
Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layer_norm.py
Copyright (c) 2024, Tri Dao. Implements fused dropout + residual + layer_norm / rms_norm.
Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
For the backward pass, weight_grad and bias_grad are kept in registers and accumulated there. This is faster for hidden dimensions up to 8k, but beyond that it becomes much slower due to register spilling. The models we train have hidden dimensions of at most 8k anyway (e.g. Llama 70B), so this is fine.
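The register-accumulation strategy for the weight and bias gradients can be illustrated in plain PyTorch. This is a sketch of the arithmetic only, not the Triton kernel: the function name `layer_norm_weight_bias_grads` and the `rows_per_program` block size are hypothetical. Each simulated program accumulates partial dW/dB over its own block of rows, and the partials are summed in a final reduction.

```python
import torch


def layer_norm_weight_bias_grads(x, dy, eps=1e-6, rows_per_program=4):
    """dW and dB of y = x_hat * w + b, accumulated block-by-block over rows.

    Each simulated "program" owns a contiguous block of rows and keeps a
    running partial sum (the role registers play in the kernel); the
    partial sums are then reduced across programs.
    """
    n_rows, n_cols = x.shape
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance
    x_hat = (x - mean) * torch.rsqrt(var + eps)

    partial_dw, partial_db = [], []
    for start in range(0, n_rows, rows_per_program):
        acc_dw = torch.zeros(n_cols, dtype=x.dtype)  # per-program accumulators
        acc_db = torch.zeros(n_cols, dtype=x.dtype)
        for row in range(start, min(start + rows_per_program, n_rows)):
            acc_dw += dy[row] * x_hat[row]
            acc_db += dy[row]
        partial_dw.append(acc_dw)
        partial_db.append(acc_db)

    # Final cross-program reduction.
    return torch.stack(partial_dw).sum(dim=0), torch.stack(partial_db).sum(dim=0)
```

The accumulator size grows with the hidden dimension, which is why very wide rows eventually spill out of registers.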
- layer_norm_ref(x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-06, dropout_p=0.0, rowscale=None, prenorm=False, dropout_mask=None, dropout_mask1=None, upcast=False)
Reference (pure PyTorch) implementation of Layer Normalization with optional residual, dropout, and parallel branches.
- Parameters:
x (torch.Tensor) – Input tensor.
weight (torch.Tensor) – Layer norm weights.
bias (torch.Tensor | None) – Layer norm biases.
residual (torch.Tensor, optional) – Residual tensor to add before normalization. Defaults to None.
x1 (torch.Tensor, optional) – Optional parallel input branch. Defaults to None.
weight1 (torch.Tensor, optional) – Weights for parallel layer norm branch. Defaults to None.
bias1 (torch.Tensor, optional) – Biases for parallel layer norm branch. Defaults to None.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.
dropout_p (float, optional) – Dropout probability. Defaults to 0.0.
rowscale (torch.Tensor, optional) – Row-wise scaling factor. Defaults to None.
prenorm (bool, optional) – Whether to return the pre-normalized (residual) state. Defaults to False.
dropout_mask (torch.Tensor, optional) – Explicit mask for dropout on x. Defaults to None.
dropout_mask1 (torch.Tensor, optional) – Explicit mask for dropout on x1. Defaults to None.
upcast (bool, optional) – Whether to cast inputs to float32 before computation. Defaults to False.
- Returns:
The normalized output. If prenorm=True or weight1 is provided, returns a tuple.
- Return type:
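The tuple-returning cases can be summarized with a simplified sketch of the reference path. It assumes the upstream mamba_ssm semantics (x1 is summed into x, the residual is added next, and weight1/bias1 produce a second output from the same pre-norm state) and omits dropout, rowscale, and upcast; `layer_norm_ref_sketch` is a hypothetical name, not part of this module.

```python
import torch
import torch.nn.functional as F


def layer_norm_ref_sketch(x, weight, bias, residual=None, x1=None,
                          weight1=None, bias1=None, eps=1e-6, prenorm=False):
    """Simplified layer_norm_ref: no dropout, rowscale, or upcast."""
    dtype = x.dtype
    if x1 is not None:
        x = x + x1          # second input branch is summed into x
    if residual is not None:
        x = x + residual    # x is now the pre-norm (residual) state
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight, bias, eps).to(dtype)
    if weight1 is None:
        return (out, x) if prenorm else out
    # Second normalization of the same state with its own affine parameters.
    out1 = F.layer_norm(x.to(weight1.dtype), x.shape[-1:], weight1, bias1, eps).to(dtype)
    return (out, out1, x) if prenorm else (out, out1)
```

Under these assumptions the possible return shapes are `out`, `(out, state)`, `(out, out1)`, and `(out, out1, state)`.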
- rms_norm_ref(x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-06, dropout_p=0.0, rowscale=None, prenorm=False, dropout_mask=None, dropout_mask1=None, upcast=False)
Reference (pure PyTorch) implementation of RMS Normalization with optional residual, dropout, and parallel branches.
- Parameters:
x (torch.Tensor) – Input tensor.
weight (torch.Tensor) – RMS norm weights.
bias (torch.Tensor | None) – RMS norm biases (added after scaling).
residual (torch.Tensor, optional) – Residual tensor to add before normalization. Defaults to None.
x1 (torch.Tensor, optional) – Optional parallel input branch. Defaults to None.
weight1 (torch.Tensor, optional) – Weights for parallel RMS norm branch. Defaults to None.
bias1 (torch.Tensor, optional) – Biases for parallel RMS norm branch. Defaults to None.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.
dropout_p (float, optional) – Dropout probability. Defaults to 0.0.
rowscale (torch.Tensor, optional) – Row-wise scaling factor. Defaults to None.
prenorm (bool, optional) – Whether to return the pre-normalized (residual) state. Defaults to False.
dropout_mask (torch.Tensor, optional) – Explicit mask for dropout on x. Defaults to None.
dropout_mask1 (torch.Tensor, optional) – Explicit mask for dropout on x1. Defaults to None.
upcast (bool, optional) – Whether to cast inputs to float32 before computation. Defaults to False.
- Returns:
The normalized output. If prenorm=True or weight1 is provided, returns a tuple.
- Return type:
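The core RMS computation, together with the explicit-mask dropout path, can be sketched as follows. The mask convention (True keeps an element, kept elements are rescaled by 1/(1 - dropout_p)) is assumed from the upstream mamba_ssm reference; `rms_norm_sketch` is a hypothetical helper that omits the residual, x1, and rowscale handling.

```python
import torch
import torch.nn.functional as F


def rms_norm_sketch(x, weight, bias=None, eps=1e-6, dropout_p=0.0, dropout_mask=None):
    """y = x / sqrt(mean(x**2) + eps) * weight (+ bias), with optional dropout."""
    if dropout_p > 0.0:
        if dropout_mask is not None:
            # Explicit mask: zero dropped entries, rescale kept ones.
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x * rstd * weight
    return out + bias if bias is not None else out
```

Unlike layer norm, no mean is subtracted; the bias, when present, is added after scaling.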
- config_prune(configs)
Filters out Triton configurations that require more warps than the current device supports.
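A minimal sketch of the pruning rule, assuming (as in the upstream mamba_ssm version) a limit of 1024 threads per block and a warp size of 32, or 64 on AMD Instinct GPUs. `FakeConfig` is a hypothetical stand-in for a Triton config so the example runs without Triton; only its `num_warps` field is used.

```python
from dataclasses import dataclass


@dataclass
class FakeConfig:
    """Hypothetical stand-in for a Triton autotune config."""
    num_warps: int


def prune_by_warps(configs, warp_size=32, max_threads_per_block=1024):
    """Keep only configs whose num_warps fits in a single thread block."""
    max_num_warps = max_threads_per_block // warp_size
    return [c for c in configs if c.num_warps <= max_num_warps]
```

With 32-thread warps up to 32 warps fit in a block; with 64-thread wavefronts the limit drops to 16, so a 32-warp config would be discarded.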
- class LayerNormFn
Bases: Function
Autograd function for fused Layer/RMS Normalization with optional residual connections.
- static forward(ctx, x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-06, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, is_rms_norm=False, return_dropout_mask=False)
Forward pass for the LayerNormFn.
- Parameters:
ctx (Any) – Autograd context.
x (torch.Tensor) – Input tensor.
weight (torch.Tensor) – Normalization weights.
bias (torch.Tensor | None) – Normalization biases.
residual (torch.Tensor, optional) – Optional residual tensor. Defaults to None.
x1 (torch.Tensor, optional) – Optional parallel input branch. Defaults to None.
weight1 (torch.Tensor, optional) – Optional parallel branch weights. Defaults to None.
bias1 (torch.Tensor, optional) – Optional parallel branch biases. Defaults to None.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.
dropout_p (float, optional) – Dropout probability. Defaults to 0.0.
rowscale (torch.Tensor, optional) – Optional row-wise scaling factor. Defaults to None.
prenorm (bool, optional) – Whether to return the pre-normalized (residual) state. Defaults to False.
residual_in_fp32 (bool, optional) – Whether the residual should be maintained in FP32. Defaults to False.
is_rms_norm (bool, optional) – If True, computes RMS norm instead of Layer norm. Defaults to False.
return_dropout_mask (bool, optional) – If True, returns the generated dropout masks. Defaults to False.
- Returns:
Normalized output, optionally along with residual, y1, and dropout masks.
- Return type:
- static backward(ctx, dy)
Backward pass for the LayerNormFn.
- Parameters:
ctx (Any) – Autograd context.
dy (torch.Tensor) – Gradient of the output tensor.
*args – Additional gradients (e.g., dy1, dresidual).
- Returns:
Gradients with respect to all forward inputs.
- Return type:
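LayerNormFn follows the standard torch.autograd.Function pattern: forward saves what backward needs (in upstream mamba_ssm, the pre-norm state plus per-row mean/rstd), and backward returns one gradient per forward input, with None for non-tensor arguments. The sketch below shows that pattern for an unfused, bias-free RMS norm; `RMSNormFnSketch` is a hypothetical name, and its analytic backward stands in for the Triton kernel.

```python
import torch


class RMSNormFnSketch(torch.autograd.Function):
    """Pure-PyTorch analogue of a fused norm Function (RMS norm, no bias)."""

    @staticmethod
    def forward(ctx, x, weight, eps=1e-6):
        rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        ctx.save_for_backward(x, weight, rstd)  # keep per-row rstd, not the output
        return x * rstd * weight

    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        x_hat = x * rstd
        g = dy * weight
        # dx = rstd * (g - x_hat * mean(g * x_hat)) along the normalized dim.
        dx = rstd * (g - x_hat * (g * x_hat).mean(dim=-1, keepdim=True))
        # dweight reduces over every leading (row) dimension.
        dweight = (dy * x_hat).reshape(-1, x.shape[-1]).sum(dim=0)
        return dx, dweight, None  # no gradient for eps
```

`torch.autograd.gradcheck` on double-precision inputs is a convenient way to confirm that a hand-written backward like this matches finite differences.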
- layer_norm_fn(x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-06, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, is_rms_norm=False, return_dropout_mask=False)
Applies fused Layer Normalization (or RMS Normalization when is_rms_norm=True) using Triton.
- Parameters:
x (torch.Tensor) – Input tensor.
weight (torch.Tensor) – Normalization weights.
bias (torch.Tensor | None) – Normalization biases.
residual (torch.Tensor, optional) – Optional residual tensor to add before norm. Defaults to None.
x1 (torch.Tensor, optional) – Optional parallel input branch. Defaults to None.
weight1 (torch.Tensor, optional) – Optional parallel branch weights. Defaults to None.
bias1 (torch.Tensor, optional) – Optional parallel branch biases. Defaults to None.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.
dropout_p (float, optional) – Dropout probability. Defaults to 0.0.
rowscale (torch.Tensor, optional) – Optional row-wise scaling factor. Defaults to None.
prenorm (bool, optional) – Whether to return the pre-normalized (residual) state. Defaults to False.
residual_in_fp32 (bool, optional) – Maintain residual in FP32. Defaults to False.
is_rms_norm (bool, optional) – If True, computes RMS norm. Defaults to False.
return_dropout_mask (bool, optional) – If True, returns generated dropout masks. Defaults to False.
- Returns:
The normalized output. If prenorm=True, returns (out, prenorm_state).
- Return type:
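The prenorm=True return value is what makes a pre-norm residual stream cheap: each call both adds the incoming residual and normalizes, handing back the un-normalized sum for the next block. The unfused sketch below assumes layer_norm_fn(..., prenorm=True) returns the same (out, prenorm_state) pair; `add_and_norm` and `run_blocks` are hypothetical helpers written with plain torch.nn.functional.layer_norm.

```python
import torch
import torch.nn.functional as F


def add_and_norm(hidden, residual, weight, bias, eps=1e-6):
    """Unfused analogue of prenorm=True: (norm(hidden + residual), hidden + residual)."""
    residual = hidden if residual is None else hidden + residual
    out = F.layer_norm(residual, residual.shape[-1:], weight, bias, eps)
    return out, residual


def run_blocks(x, blocks, weight, bias):
    """Pre-norm residual stream: each block sees the normalized state,
    and the un-normalized sum is carried forward as the next residual."""
    hidden, residual = x, None
    for block in blocks:
        normed, residual = add_and_norm(hidden, residual, weight, bias)
        hidden = block(normed)
    # Final add (no norm) so the stream includes the last block's output.
    return hidden + residual
```

In the fused version the add and the normalization happen in one kernel pass, so the sum does not need a separate read-modify-write.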
- rms_norm_fn(x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-06, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, return_dropout_mask=False)
Applies fused RMS Normalization using Triton.
- Parameters:
x (torch.Tensor) – Input tensor.
weight (torch.Tensor) – Normalization weights.
bias (torch.Tensor | None) – Normalization biases.
residual (torch.Tensor, optional) – Optional residual tensor to add before norm. Defaults to None.
x1 (torch.Tensor, optional) – Optional parallel input branch. Defaults to None.
weight1 (torch.Tensor, optional) – Optional parallel branch weights. Defaults to None.
bias1 (torch.Tensor, optional) – Optional parallel branch biases. Defaults to None.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.
dropout_p (float, optional) – Dropout probability. Defaults to 0.0.
rowscale (torch.Tensor, optional) – Optional row-wise scaling factor. Defaults to None.
prenorm (bool, optional) – Whether to return the pre-normalized (residual) state. Defaults to False.
residual_in_fp32 (bool, optional) – Maintain residual in FP32. Defaults to False.
return_dropout_mask (bool, optional) – If True, returns generated dropout masks. Defaults to False.
- Returns:
The RMS normalized output. If prenorm=True, returns (out, prenorm_state).
- Return type:
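The dtype handling of residual_in_fp32 can be sketched as below. This assumes the upstream mamba_ssm behavior, where the flag only selects float32 when no residual is passed in, and an incoming residual keeps its own dtype; the output is cast back to the input dtype. `rms_norm_prenorm_sketch` is a hypothetical helper, not this module's implementation.

```python
import torch


def rms_norm_prenorm_sketch(x, weight, residual=None, eps=1e-6, residual_in_fp32=False):
    """Returns (out, prenorm_state) with assumed residual dtype rules."""
    if residual is not None:
        res_dtype = residual.dtype  # an incoming residual keeps its own dtype
    else:
        res_dtype = torch.float32 if residual_in_fp32 else x.dtype
    state = x.to(res_dtype) if residual is None else x.to(res_dtype) + residual
    h = state.float()  # normalize in float32
    out = h * torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps) * weight.float()
    return out.to(x.dtype), state
```

Keeping the residual stream in float32 limits rounding drift as many low-precision block outputs are summed into it, while activations passed to each block stay in the model dtype.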
- class RMSNorm
Bases: Module
RMS Normalization layer.
- Parameters:
hidden_size (int) – Dimension of the features to normalize.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-5.
dropout_p (float, optional) – Dropout probability. Defaults to 0.0.
device (torch.device, optional) – Device for parameters. Defaults to None.
dtype (torch.dtype, optional) – Data type for parameters. Defaults to None.
- forward(x, residual=None, prenorm=False, residual_in_fp32=False)
Forward pass for RMSNorm.
- Parameters:
x (torch.Tensor) – Input tensor.
residual (torch.Tensor, optional) – Optional residual connection. Defaults to None.
prenorm (bool, optional) – Whether to return the state before normalization. Defaults to False.
residual_in_fp32 (bool, optional) – Compute residual in float32. Defaults to False.
- Returns:
Normalized output. If prenorm=True, returns (out, prenorm_state).
- Return type:
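For reading or testing code that uses RMSNorm without a GPU, an unfused module with the same constructor arguments can serve as a reference. `RMSNormSketch` is hypothetical; it assumes, as in upstream mamba_ssm, that the weight is initialized to ones, there is no bias, and dropout applies to x (before the residual add) only in training mode. residual_in_fp32 is left out for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNormSketch(nn.Module):
    """Unfused stand-in mirroring RMSNorm's constructor arguments."""

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.dropout_p = dropout_p
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
        self.register_parameter("bias", None)  # no bias term

    def forward(self, x, residual=None, prenorm=False):
        # Dropout on the incoming branch only, and only while training.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        state = x if residual is None else x + residual
        rstd = torch.rsqrt(state.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        out = state * rstd * self.weight
        return (out, state) if prenorm else out
```

With the default unit weight, each output row has a root-mean-square close to 1.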
- class LayerNormLinearFn
Bases: Function
Autograd function for a fused Layer/RMS Normalization followed immediately by a Linear projection.
- static forward(ctx, x, norm_weight, norm_bias, linear_weight, linear_bias, residual=None, eps=1e-06, prenorm=False, residual_in_fp32=False, is_rms_norm=False)
Forward pass for the LayerNormLinearFn.
- Parameters:
ctx (Any) – Autograd context.
x (torch.Tensor) – Input tensor.
norm_weight (torch.Tensor) – Normalization weights.
norm_bias (torch.Tensor | None) – Normalization biases.
linear_weight (torch.Tensor) – Linear projection weights.
linear_bias (torch.Tensor | None) – Linear projection biases.
residual (torch.Tensor, optional) – Optional residual tensor. Defaults to None.
eps (float, optional) – Numerical stability epsilon. Defaults to 1e-6.
prenorm (bool, optional) – Whether to return the pre-normalized (residual) state. Defaults to False.
residual_in_fp32 (bool, optional) – Whether to maintain the residual in FP32. Defaults to False.
is_rms_norm (bool, optional) – If True, uses RMS norm instead of Layer norm. Defaults to False.
- Returns:
The projected output, optionally along with the prenorm residual state.
- Return type:
- static backward(ctx, dout)
Backward pass for LayerNormLinearFn.
- Parameters:
ctx (Any) – Autograd context.
dout (torch.Tensor) – Gradient of the output tensor.
*args – Additional gradients (e.g., dresidual).
- Returns:
Gradients with respect to all forward inputs.
- Return type:
- layer_norm_linear_fn(x, norm_weight, norm_bias, linear_weight, linear_bias, residual=None, eps=1e-06, prenorm=False, residual_in_fp32=False, is_rms_norm=False)
Applies fused Layer/RMS Normalization directly followed by a Linear projection using Triton.
- Parameters:
x (torch.Tensor) – Input tensor.
norm_weight (torch.Tensor) – Normalization weights.
norm_bias (torch.Tensor | None) – Normalization biases.
linear_weight (torch.Tensor) – Linear projection weights.
linear_bias (torch.Tensor | None) – Linear projection biases.
residual (torch.Tensor, optional) – Optional residual tensor to add before norm. Defaults to None.
eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.
prenorm (bool, optional) – Whether to return the pre-normalized (residual) state. Defaults to False.
residual_in_fp32 (bool, optional) – Maintain residual in FP32. Defaults to False.
is_rms_norm (bool, optional) – If True, computes RMS norm instead of Layer norm. Defaults to False.
- Returns:
The projected output. If prenorm=True, returns (out, prenorm_state).
- Return type:
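Functionally, layer_norm_linear_fn (layer-norm case) should match linear(layer_norm(x + residual)). The unfused sketch below is useful as a numerical reference; `layer_norm_linear_unfused` is a hypothetical helper. As we understand the upstream mamba_ssm code, the fused Function does not save the normalized activations for backward and instead recomputes them in the backward kernel, whereas this sketch materializes them.

```python
import torch
import torch.nn.functional as F


def layer_norm_linear_unfused(x, norm_weight, norm_bias, linear_weight, linear_bias,
                              residual=None, eps=1e-6, prenorm=False):
    """Reference for linear(layer_norm(x + residual)); returns (out, state) if prenorm."""
    state = x if residual is None else x + residual
    normed = F.layer_norm(state, state.shape[-1:], norm_weight, norm_bias, eps)
    out = F.linear(normed, linear_weight, linear_bias)  # normed @ W.T + b
    return (out, state) if prenorm else out
```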