# Source code for lrnnx.ops.torch
from functools import partial, wraps
from typing import Any, Callable

import torch
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) -> Callable[..., Any]:
    """
    Wrap an Automatic Mixed Precision (AMP) decorator to handle the API deprecation.

    PyTorch deprecated ``torch.cuda.amp`` in favor of ``torch.amp``. When
    ``cuda_amp_deprecated`` is true (i.e. ``dec`` comes from the new
    ``torch.amp`` module), the returned wrapper supplies the keyword argument
    ``device_type="cuda"`` unless the caller already passed a ``device_type``
    of their own. When it is false, all arguments are forwarded to ``dec``
    unchanged.

    The wrapper forwards every positional and keyword argument to ``dec`` and
    returns ``dec``'s result, so it can be used exactly like ``dec`` itself.
    It also copies ``dec``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``, ...) via ``functools.wraps``.

    Args:
        dec (Callable): The original AMP decorator function (e.g., ``custom_fwd`` or ``custom_bwd``).
        cuda_amp_deprecated (bool): A flag indicating whether the ``torch.cuda.amp`` module is
            deprecated, i.e. whether ``dec`` expects a ``device_type`` keyword argument.

    Returns:
        Callable: The wrapped decorator function.
    """

    @wraps(dec)
    def decorator(*args: Any, **kwargs: Any) -> Any:
        if cuda_amp_deprecated:
            # Default to CUDA to match the legacy ``torch.cuda.amp`` behavior,
            # but never override a device_type the caller chose explicitly.
            kwargs.setdefault("device_type", "cuda")
        return dec(*args, **kwargs)

    return decorator
# Pick the AMP autograd decorators from whichever module this PyTorch build
# provides. Newer releases expose them under ``torch.amp`` (and expect a
# ``device_type`` argument); older ones only have ``torch.cuda.amp``.
deprecated = hasattr(torch.amp, "custom_fwd")  # type: ignore[attr-defined]
if not deprecated:
    from torch.cuda.amp import custom_bwd as _custom_bwd  # type: ignore[assignment]
    from torch.cuda.amp import custom_fwd as _custom_fwd  # type: ignore[assignment]
else:
    from torch.amp import custom_bwd as _custom_bwd  # type: ignore[attr-defined]
    from torch.amp import custom_fwd as _custom_fwd  # type: ignore[attr-defined]

# Public, version-agnostic decorators: they inject ``device_type="cuda"``
# only when the ``torch.amp`` variants were selected above.
custom_fwd: Callable[..., Any]
custom_bwd: Callable[..., Any]
custom_fwd, custom_bwd = (
    custom_amp_decorator(_custom_fwd, deprecated),
    custom_amp_decorator(_custom_bwd, deprecated),
)