diff --git a/src/accelerate/test_utils/other.py b/src/accelerate/test_utils/other.py index e9ee04306ef..6ec1e2d9b63 100644 --- a/src/accelerate/test_utils/other.py +++ b/src/accelerate/test_utils/other.py @@ -66,14 +66,14 @@ def backend_is_available(device: str): # Update device function dict mapping -def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): +def update_mapping_from_spec(dispatch_table: Dict[str, Callable], attribute_name: str): try: # Try to import the function directly spec_fn = getattr(device_spec_module, attribute_name) - device_fn_dict[torch_device] = spec_fn + dispatch_table[torch_device] = spec_fn except AttributeError as e: # If the function doesn't exist, and there is no default, throw an error - if "default" not in device_fn_dict: + if "default" not in dispatch_table: raise AttributeError( f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." ) from e