diff --git a/src/accelerate/test_utils/other.py b/src/accelerate/test_utils/other.py index 3a0d314a810..e9ee04306ef 100644 --- a/src/accelerate/test_utils/other.py +++ b/src/accelerate/test_utils/other.py @@ -53,17 +53,7 @@ # This dispatches a defined function according to the hardware accelerator from the function definitions. def device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): - if device not in dispatch_table: - return dispatch_table["default"](*args, **kwargs) - - fn = dispatch_table[device] - - # Some device agnostic functions return values. Need to guard against 'None' instead at - # user level - if fn is None: - return None - - return fn(*args, **kwargs) + return dispatch_table.get(device, dispatch_table["default"])(*args, **kwargs) # These are callables which automatically dispatch the function specific to the hardware accelerator