diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 73ada66544f..2b7e5fcaa19 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1101,6 +1101,7 @@ def _init_infer_auto_device_map( special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, ) -> Tuple[ List[Union[int, str]], + Dict[Union[int, str], Union[int, str]], List[Union[int, str]], List[int], Dict[str, int], @@ -1147,6 +1148,7 @@ def _init_infer_auto_device_map( return ( devices, + max_memory, main_devices, gpus, module_sizes, @@ -1356,6 +1358,7 @@ def infer_auto_device_map( # Initialize the variables ( devices, + max_memory, main_devices, gpus, module_sizes,