diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py index 42e944d550b..ce220c1b8e4 100644 --- a/src/accelerate/utils/memory.py +++ b/src/accelerate/utils/memory.py @@ -174,7 +174,19 @@ def get_xpu_available_memory(device_index: int): from intel_extension_for_pytorch.xpu import mem_get_info return mem_get_info(device_index)[0] + elif version.parse(torch.__version__).release >= version.parse("2.6").release: + # torch.xpu.mem_get_info API is available starting from PyTorch 2.6 + # It further requires PyTorch built with the SYCL runtime which supports API + # to query available device memory. If not available, exception will be + # raised. Version of SYCL runtime used to build PyTorch is being reported + # with print(torch.version.xpu) and corresponds to the version of Intel DPC++ + # SYCL compiler. First version to support required feature is 20250001. + try: + return torch.xpu.mem_get_info(device_index)[0] + except Exception: + pass + warnings.warn( - "The XPU `mem_get_info` API is available in IPEX version >=2.5. The current returned available memory is incorrect. Please consider upgrading your IPEX version." + "The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. The current returned available memory is incorrect. Please consider upgrading your IPEX or PyTorch version." ) return torch.xpu.max_memory_allocated(device_index)