
Raise an error if using TorchAO quantizer when using device_map with sharded checkpoint #10256

Open
wants to merge 1 commit into main
Conversation

a-r-r-o-w
Member

When passing a device_map, whether as a fine-grained dict or simply "auto", we cannot currently use the torchao quantization method with a sharded checkpoint.
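For reference, a minimal sketch of the failing call (the repo id, subfolder, and dtype mirror the repro script later in this thread; the fine-grained dict shown in the comment is only illustrative):

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")

# Loading a sharded checkpoint with torchao quantization and a device_map
# currently fails with the traceback below.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/Flux.1-Dev",   # sharded checkpoint
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",                # or a fine-grained dict, e.g. {"": 0}
    torch_dtype=torch.bfloat16,
)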

Error
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump5.py", line 91, in <module>
    transformer = FluxTransformer2DModel.from_pretrained(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/aryan/work/diffusers/src/diffusers/models/modeling_utils.py", line 920, in from_pretrained
    accelerate.load_checkpoint_and_dispatch(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
    load_checkpoint_in_model(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 1690, in load_checkpoint_in_model
    if os.path.isfile(checkpoint):
  File "/home/aryan/.pyenv/versions/3.10.14/lib/python3.10/genericpath.py", line 30, in isfile
    st = os.stat(path)
TypeError: stat: path should be string, bytes, os.PathLike or integer, not dict

This is a quick workaround to throw a cleaner error message. The reason we error out is that, when the checkpoint is sharded and a device_map is passed, the loading path ends up handing a state dict to accelerate.load_checkpoint_and_dispatch, which expects a checkpoint path (hence the TypeError above).

Related:

Accelerate also provides load_checkpoint_in_model, which might be usable here since we are working with a state dict. Until we figure out the best way to support this, let's raise a clean error. We can tackle this in #10013 and work on the refactoring too. model_file does not make sense as a variable name when it holds a state dict either, which caused some confusion during debugging.

The missing case was found by @DN6, thanks! This was not detected by our fast/slow tests, or by me while testing device_map-related changes, because I was using an unsharded single-safetensors file for both SDXL and Flux. If the state dict is unsharded, device_map should work just fine whether you pass a string like "auto" or "balanced", or a fine-grained dict.
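Roughly, the guard being added looks like the sketch below (not the exact diff; `is_sharded` stands in for whatever flag the loading code uses to detect a sharded checkpoint, and the error message wording is illustrative):

if hf_quantizer is not None:
    is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
    # Sketch: bail out early instead of handing a state dict to
    # accelerate.load_checkpoint_and_dispatch, which expects a path.
    if is_torchao_quantization_method and device_map is not None and is_sharded:
        raise NotImplementedError(
            "Loading sharded checkpoints with `device_map` is currently not supported "
            "with the torchao quantizer. Please use a single-file (unsharded) checkpoint "
            "or remove `device_map`."
        )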

a-r-r-o-w requested review from DN6 and yiyixuxu, December 17, 2024 06:12

@@ -803,6 +803,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder or "",
)
if hf_quantizer is not None:
is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
yiyixuxu (Collaborator), Dec 17, 2024
Can we consolidate this with the bnb check (i.e. remove the bnb check and extend this check to any quantization method)?

is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"

This should not be specific to any particular quantization method, no? I ran this test: for a non-sharded checkpoint, both work; for a sharded checkpoint, both throw the same error.

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig, BitsAndBytesConfig
import torch

sharded_model_id = "black-forest-labs/Flux.1-Dev"
single_model_path = "/raid/yiyi/flux_model_single"
dtype = torch.bfloat16

# create a non-sharded checkpoint
# transformer = FluxTransformer2DModel.from_pretrained(
#     sharded_model_id,
#     subfolder="transformer",
#     torch_dtype=dtype,
# )
# transformer.save_pretrained(single_model_path, max_shard_size="100GB")

torch_ao_quantization_config = TorchAoConfig("int8wo")
bnb_quantization_config = BitsAndBytesConfig(load_in_8bit=True)

print(f" testing non-sharded checkpoint")
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=torch_ao_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)

print(f"torchao hf_device_map: {transformer.hf_device_map}")

transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path, 
    quantization_config=bnb_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"bnb hf_device_map: {transformer.hf_device_map}")


print(f" testing sharded checkpoint")
## sharded checkpoint
try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id, 
        subfolder="transformer",
        quantization_config=torch_ao_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"torchao: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=bnb_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"bnb hf_device_map: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

a-r-r-o-w (Member, Author), Dec 17, 2024
I don't think non-sharded works for both, no? A non-sharded checkpoint only seems to work with torchao at the moment. These are my results:

method/shard    sharded    non-sharded
torchao         fails      works
bnb             fails      fails

I tried with your code as well and got the following error when using BnB with an unsharded checkpoint on this branch:

NotImplementedError: Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future.

Whatever the automatic inference of device_map is, we are still unable to pass device_map manually whether the state dict is sharded or unsharded, so I would put it in the same bucket as failing.

Consolidating the checks together sounds good. Will update
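For illustration, the consolidated check could look something like this (hypothetical sketch; `is_sharded` is again a stand-in, and the exact condition and message will be settled in the update):

# Hypothetical consolidation: gate on any quantization method rather than
# checking bitsandbytes and torchao separately.
if hf_quantizer is not None and device_map is not None and is_sharded:
    quant_method = hf_quantizer.quantization_config.quant_method
    raise NotImplementedError(
        f"Loading sharded checkpoints with `device_map` is currently not supported "
        f"with `{quant_method}` quantization. Please use an unsharded checkpoint or "
        "remove `device_map`."
    )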
