diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index ba17730f9d9..246179e9860 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -2127,6 +2127,7 @@ def get_model_from_task( use_auth_token = model_kwargs.pop("use_auth_token", None) token = model_kwargs.pop("token", None) trust_remote_code = model_kwargs.pop("trust_remote_code", False) + model_kwargs["torch_dtype"] = torch_dtype if use_auth_token is not None: warnings.warn( @@ -2144,6 +2145,7 @@ def get_model_from_task( token=token, revision=revision, trust_remote_code=trust_remote_code, + model_kwargs=model_kwargs, ) else: try: