lrnnx.ops.torch module

custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) → Callable[[...], Any]

Wrapper for Automatic Mixed Precision (AMP) decorators to handle deprecation.

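PyTorch deprecated the torch.cuda.amp decorators in favor of their torch.amp equivalents, which require an explicit device_type argument. To keep existing call sites working, the returned wrapper injects the keyword argument device_type="cuda" into each call to dec when cuda_amp_deprecated is True, i.e. when the new torch.amp decorator is in use rather than the deprecated torch.cuda.amp one.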
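A minimal sketch of this behaviour, assuming the wrapper simply forwards its arguments to dec and adds the keyword only when the flag is set (lrnnx's actual implementation may differ). fake_amp_fwd and fake_old_fwd are hypothetical stand-ins so the example runs without PyTorch: fake_amp_fwd mimics the required keyword-only device_type parameter of torch.amp.custom_fwd, and fake_old_fwd mimics an older decorator that has no such parameter.

```python
from typing import Any, Callable


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) -> Callable[..., Any]:
    """Sketch: forward every call to `dec`, adding device_type="cuda" when needed."""
    def decorator(*args: Any, **kwargs: Any) -> Any:
        if cuda_amp_deprecated:
            # The torch.amp decorators take a required keyword-only device_type.
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


# Hypothetical stand-in shaped like torch.amp.custom_fwd:
# usable bare (@dec) or with keyword arguments (@dec(cast_inputs=...)).
def fake_amp_fwd(fwd=None, *, device_type: str, cast_inputs=None):
    if fwd is None:
        return lambda f: fake_amp_fwd(f, device_type=device_type, cast_inputs=cast_inputs)
    fwd.amp_device_type = device_type  # record what the wrapper injected
    return fwd


custom_fwd = custom_amp_decorator(fake_amp_fwd, cuda_amp_deprecated=True)


@custom_fwd  # bare form: becomes fake_amp_fwd(forward, device_type="cuda")
def forward(x):
    return x * 2


@custom_fwd(cast_inputs="float32")  # keyword form: injection happens on the outer call
def forward_cast(x):
    return x * 2


# Hypothetical stand-in for an older decorator with no device_type parameter.
def fake_old_fwd(fwd):
    fwd.amp_device_type = None
    return fwd


# With the flag off, nothing is injected, so the older signature still works.
legacy_fwd = custom_amp_decorator(fake_old_fwd, cuda_amp_deprecated=False)


@legacy_fwd
def forward_legacy(x):
    return x * 2


print(forward.amp_device_type, forward_cast.amp_device_type)  # cuda cuda
print(forward_legacy.amp_device_type)  # None
```

Injecting the keyword inside the returned closure, rather than at wrap time, is what lets both the bare and the parameterized decorator forms work unchanged.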

Parameters:
  • dec (Callable) – The original AMP decorator function (e.g., custom_fwd or custom_bwd).

  • cuda_amp_deprecated (bool) – A flag indicating whether the torch.cuda.amp module is deprecated.

Returns:

The wrapped decorator function.

Return type:

Callable

custom_fwd() → Any
custom_bwd() → Any
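The snippet below sketches how module-level custom_fwd and custom_bwd decorators of this kind are typically built and applied to a torch.autograd.Function. It assumes the flag is derived from whether torch.amp exposes custom_fwd (the case from PyTorch 2.4 onward); the Square op is hypothetical, and lrnnx's own setup code may differ. On a CPU-only machine PyTorch may warn that CUDA autocast is unavailable, but the function still runs.

```python
import torch


def custom_amp_decorator(dec, cuda_amp_deprecated):
    # Sketch of the wrapper: add device_type="cuda" only for the torch.amp decorators.
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


# Assumption: the flag is True when the newer torch.amp decorators exist (PyTorch >= 2.4).
if hasattr(torch.amp, "custom_fwd"):
    from torch.amp import custom_bwd as _bwd, custom_fwd as _fwd
    _cuda_amp_deprecated = True
else:
    from torch.cuda.amp import custom_bwd as _bwd, custom_fwd as _fwd
    _cuda_amp_deprecated = False

custom_fwd = custom_amp_decorator(_fwd, _cuda_amp_deprecated)
custom_bwd = custom_amp_decorator(_bwd, _cuda_amp_deprecated)


class Square(torch.autograd.Function):
    """Hypothetical op computing x**2 with a hand-written backward pass."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out  # d(x^2)/dx = 2x


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # tensor([2., 4., 6.])
```

Because the wrapper only adds a keyword argument, the same decorators also accept the parameterized form, e.g. @custom_fwd(cast_inputs=torch.float32), on either PyTorch code path.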