From b6062082afa8c6e638d9c4e55adb5fcf466a4f5a Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Mon, 25 Nov 2024 15:45:12 +0100
Subject: [PATCH] default to `"auto"` dtype

---
 src/transformers/modeling_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 4703c415e42fbb..8fe3ab5b1f9749 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3273,7 +3273,7 @@ def from_pretrained(
                 `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However,
                 this should still be enabled if you are passing in a `device_map`.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
+            torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
                 Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
                 are:

@@ -3407,7 +3407,7 @@ def from_pretrained(
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
         _fast_init = kwargs.pop("_fast_init", True)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", "auto")
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)