Raise an error if using TorchAO quantizer when using device_map with sharded checkpoint #10256
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -803,6 +803,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            subfolder=subfolder or "",
        )
        if hf_quantizer is not None:
            is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
Can we consolidate this with the bnb check (remove the bnb check and extend this check to cover any quantization method)?
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
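For reference, a rough sketch of what such a consolidated guard could look like, assuming the surrounding `from_pretrained` scope exposes `device_map`, `is_sharded`, and `hf_quantizer`; the helper name and error wording here are illustrative, not the final implementation:

```python
def _check_sharded_device_map_support(device_map, is_sharded, hf_quantizer):
    """Sketch of a consolidated guard (hypothetical helper): reject `device_map`
    combined with a sharded checkpoint for any quantization backend, instead of
    keeping separate bitsandbytes / torchao checks."""
    if device_map is not None and is_sharded and hf_quantizer is not None:
        raise NotImplementedError(
            "Loading a sharded checkpoint with `device_map` is currently not supported for the "
            f"'{hf_quantizer.quantization_config.quant_method}' quantization method. "
            "Please load an unsharded checkpoint or omit `device_map`."
        )
```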
This should not be specific to any quantization method, no? I ran this test: for a non-sharded checkpoint, both work; for a sharded checkpoint, both throw the same error.
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig, BitsAndBytesConfig
import torch

sharded_model_id = "black-forest-labs/Flux.1-Dev"
single_model_path = "/raid/yiyi/flux_model_single"
dtype = torch.bfloat16

# create a non-sharded checkpoint
# transformer = FluxTransformer2DModel.from_pretrained(
#     sharded_model_id,
#     subfolder="transformer",
#     torch_dtype=dtype,
# )
# transformer.save_pretrained(single_model_path, max_shard_size="100GB")

torch_ao_quantization_config = TorchAoConfig("int8wo")
bnb_quantization_config = BitsAndBytesConfig(load_in_8bit=True)

print("testing non-sharded checkpoint")
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=torch_ao_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"torchao hf_device_map: {transformer.hf_device_map}")

transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=bnb_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"bnb hf_device_map: {transformer.hf_device_map}")

print("testing sharded checkpoint")
## sharded checkpoint
try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=torch_ao_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"torchao hf_device_map: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=bnb_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"bnb hf_device_map: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")
I don't think non-sharded works for both, no? A non-sharded checkpoint only seems to work with torchao at the moment. These are my results:
| method | sharded | non-sharded |
| --- | --- | --- |
| torchao | fails | works |
| bnb | fails | fails |
I tried with your code as well and got the following error when using BnB with an unsharded checkpoint on this branch:
NotImplementedError: Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future.
Whatever the automatic inference of device_map is doing, we are still unable to pass device_map manually whether the state dict is sharded or unsharded, so I would put it in the same bucket as failing.
Consolidating the checks sounds good. Will update.
When passing a device_map, whether as a fine-grained dict or simply "auto", we currently cannot use the torchao quantization method.

This is a quick workaround to throw a cleaner error message. The reason we error out:

diffusers/src/diffusers/models/modeling_utils.py, line 806 at 7ca64fd

Here, we merge the sharded checkpoints because hf_quantizer is not None and set is_sharded to False. This causes model_file to be a state dict instead of a string.

diffusers/src/diffusers/models/modeling_utils.py, line 914 at 7ca64fd

Accelerate expects a file path when load_checkpoint_and_dispatch is called, but we try to pass a state dict.

Related:
- Accelerate also provides load_checkpoint_in_model, which might be usable here since we are working with a state dict. Until we can figure out the best way to support this, let's raise a clean error. We can tackle it in #10013 and work on the refactoring too.
- model_file also does not make sense as a variable name when it holds a state dict, which caused some confusion during debugging.

The missing case was found by @DN6, thanks! It was not detected by our fast/slow tests, or by me while testing the device_map related changes, because I was using an unsharded single safetensors file for both SDXL and Flux. If the state dict is unsharded, device_map should work just fine whether you pass a string like "auto" or "balanced", or a fine-grained dict.
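To illustrate that last point, here is a minimal sketch of loading an unsharded checkpoint with a fine-grained device_map dict; the local path is a placeholder, and the expected behavior follows the description above rather than anything verified here:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Sketch only: per the description above, an unsharded (single-file) checkpoint
# should accept a device_map, either a string ("auto", "balanced") or a
# fine-grained dict. The path below is a hypothetical, locally saved,
# unsharded transformer checkpoint.
transformer = FluxTransformer2DModel.from_pretrained(
    "/path/to/unsharded/flux-transformer",  # hypothetical local path
    quantization_config=TorchAoConfig("int8wo"),
    device_map={"": 0},  # fine-grained dict: place the whole model on GPU 0
    torch_dtype=torch.bfloat16,
)
print(transformer.hf_device_map)
```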