lrnnx.ops.torch module
- custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) → Callable[[...], Any]
Wrapper for Automatic Mixed Precision (AMP) decorators to handle deprecation.
PyTorch deprecated torch.cuda.amp in favor of torch.amp. This wrapper keeps existing call sites working across both APIs: when cuda_amp_deprecated is True (so dec comes from torch.amp), it injects the device_type="cuda" keyword argument, which the torch.amp decorators require; otherwise it forwards the arguments to dec unchanged.
- Parameters:
dec (Callable) – The original AMP decorator function (e.g., custom_fwd or custom_bwd).
cuda_amp_deprecated (bool) – A flag indicating whether the torch.cuda.amp module is deprecated. If True, device_type="cuda" is added to every call of dec.
- Returns:
The wrapped decorator function.
- Return type:
Callable
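The behavior described above can be sketched in plain Python. This is a minimal illustration under assumptions, not the lrnnx source: fake_custom_fwd is a hypothetical stand-in for torch.amp.custom_fwd that only records the keyword arguments it receives, so the example runs without PyTorch or a GPU.

```python
import functools
from typing import Any, Callable


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) -> Callable[..., Any]:
    """Sketch: wrap an AMP decorator so call sites work with torch.cuda.amp and torch.amp."""

    @functools.wraps(dec)
    def decorator(*args: Any, **kwargs: Any) -> Any:
        # torch.amp.custom_fwd/custom_bwd take a required device_type keyword,
        # while the deprecated torch.cuda.amp versions do not accept it,
        # so it is injected only when the new API is in use.
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)

    return decorator


def fake_custom_fwd(fwd=None, **kwargs):
    """Hypothetical stand-in for an AMP decorator; records the kwargs it receives."""

    def apply(fn):
        fn.amp_kwargs = dict(kwargs)
        return fn

    # Support both the bare @dec form and the @dec(...) form.
    return apply if fwd is None else apply(fwd)


custom_fwd = custom_amp_decorator(fake_custom_fwd, cuda_amp_deprecated=True)


@custom_fwd
def forward(x):
    return x * 2


# A string is used for cast_inputs only because the stub ignores its value;
# real code would pass a dtype such as torch.float32.
@custom_fwd(cast_inputs="float32")
def forward_cast(x):
    return x + 1


print(forward.amp_kwargs)       # {'device_type': 'cuda'}
print(forward_cast.amp_kwargs)  # {'cast_inputs': 'float32', 'device_type': 'cuda'}
```

With PyTorch installed, the flag would typically be computed once at import time, for example as hasattr(torch.amp, "custom_fwd") (true from PyTorch 2.4 onward), and the matching real decorator from torch.amp or torch.cuda.amp passed as dec. Wrapping both forms behind one name lets autograd Function definitions use a single decorator regardless of the installed PyTorch version.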