
xpu device is not used running pipeline(device_map="auto") #31922

Open
dvrogozh opened this issue Jul 12, 2024 · 5 comments · May be fixed by huggingface/accelerate#3275
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Comments

@dvrogozh
Contributor

dvrogozh commented Jul 12, 2024

Found on these code versions: 5258501, huggingface/accelerate@12a007d, pytorch/pytorch@3477ee3. This is an issue with XPU support in stock PyTorch (i.e., without IPEX).

HF model pipelines with device_map="auto" (or device_map="sequential") do not actually run on XPU, even when the model fits in device memory. I spotted this while trying to run Llama 3 models:

Example script:

import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
outputs = pipeline(
    messages,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
print(outputs[0]["generated_text"][-1])
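To check where the pipeline actually placed the model, one can inspect the device map that Transformers records on a model loaded with a device_map (the model.hf_device_map attribute). The helper below is a hypothetical convenience function, not part of Transformers or Accelerate; it only counts how many modules ended up on each device:

```python
from collections import Counter


def summarize_device_map(device_map):
    """Count how many modules were placed on each device.

    Hypothetical helper. `device_map` maps module names to devices; values
    may be ints (accelerator indices), strings such as "cpu" or "disk", or
    torch.device objects, so every value is normalized with str().
    """
    return Counter(str(device) for device in device_map.values())


if __name__ == "__main__":
    # Illustrative map only: no module is assigned to an accelerator index.
    example = {"model.embed_tokens": "cpu", "model.layers.0": "cpu", "lm_head": "cpu"}
    print(summarize_device_map(example))
```

With the script above, summarize_device_map(pipeline.model.hf_device_map) shows whether any module was assigned to an accelerator; on an affected setup one would presumably see only "cpu" (or "disk") entries.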

Workarounds and findings:

  • If the model fits in device memory, changing device_map="auto" to device_map="xpu" allows the model to run (this is easiest to check with the 8B model).
  • The model also starts to work (but see the note below) if you add max_memory to the model kwargs:
model_kwargs={"torch_dtype": torch.bfloat16, "max_memory": {0: 5.0e+10}}, device_map="auto",
...
  File "/home/gta/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 118, in __getitem__
    return self.dataset[f"{self.prefix}{key}"]
  File "/home/gta/git/huggingface/accelerate/src/accelerate/utils/offload.py", line 171, in __getitem__
    tensor = f.get_tensor(weight_info.get("weight_name", key))
  File "/home/gta/git/pytorch/pytorch/torch/cuda/__init__.py", line 305, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
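The max_memory workaround in the second bullet hard-codes the XPU budget (5.0e+10 bytes). A hypothetical helper like the one below derives it from the device's total memory minus a safety margin instead; the 10% headroom and the optional CPU budget are illustrative assumptions, not values taken from Accelerate:

```python
def build_max_memory(total_device_bytes, device_index=0, headroom=0.1, cpu_budget=None):
    """Build a max_memory dict for from_pretrained(..., device_map="auto").

    Hypothetical helper: reserves a `headroom` fraction of the device memory
    (e.g. for activations and the KV cache) and optionally adds a CPU budget.
    """
    if total_device_bytes <= 0:
        raise ValueError("total_device_bytes must be positive")
    if not 0 <= headroom < 1:
        raise ValueError("headroom must be in [0, 1)")
    usable = total_device_bytes - int(total_device_bytes * headroom)
    max_memory = {device_index: usable}
    if cpu_budget is not None:
        # Accelerate accepts either byte counts or strings such as "64GiB".
        max_memory["cpu"] = cpu_budget
    return max_memory
```

For example, model_kwargs={"torch_dtype": torch.bfloat16, "max_memory": build_max_memory(torch.xpu.get_device_properties(0).total_memory)}, assuming your PyTorch build exposes total_memory on the XPU device properties. This only sidesteps the missing free-memory query; it does not address the offloading failure shown in the traceback.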

CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5 @sywangyi @yao-matrix

@dvrogozh
Contributor Author

I think the root cause of this issue is that HF currently can't query the total and free XPU device memory, and consequently does not use XPU for dispatching. Note, however, that there also seems to be a separate issue in HF Accelerate around querying free XPU memory: the wrong function is used for the query, so HF gets the allocated memory size instead of the free memory size. See huggingface/accelerate#2929.
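To illustrate why an unknown (effectively zero) XPU budget keeps the model off the device, here is a deliberately simplified, hypothetical sequential placement routine. It is not Accelerate's infer_auto_device_map, which also handles tied weights, no-split modules and reserved memory; it only shows how per-device budgets drive placement, and the device labels are illustrative:

```python
def place_sequentially(module_sizes, budgets):
    """Assign modules to devices in order, moving on when a device is full.

    module_sizes: list of (module_name, size_in_bytes) in execution order.
    budgets: dict of device -> available bytes, in preference order.
    Modules that fit on no device are assigned to "disk".
    """
    devices = list(budgets.items())
    placement = {}
    idx = 0
    remaining = devices[0][1] if devices else 0
    for name, size in module_sizes:
        # Skip devices that cannot hold this module; they are never revisited.
        while idx < len(devices) and size > remaining:
            idx += 1
            if idx < len(devices):
                remaining = devices[idx][1]
        if idx == len(devices):
            placement[name] = "disk"
        else:
            placement[name] = devices[idx][0]
            remaining -= size
    return placement


modules = [("layer0", 40), ("layer1", 40), ("layer2", 40)]
# Free XPU memory known: the first layers land on the accelerator.
print(place_sequentially(modules, {"xpu:0": 100, "cpu": 100}))
# Free XPU memory unknown and treated as 0: every layer skips the XPU.
print(place_sequentially(modules, {"xpu:0": 0, "cpu": 100}))
```

In the second call the XPU gets nothing, the CPU fills up, and the last layer spills to disk, which mirrors the symptoms described in this issue.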

@amyeroberts
Collaborator

cc @muellerzr @SunMarc

@dvrogozh
Contributor Author

I forgot to mention that I opened a request in PyTorch to support the torch.xpu.mem_get_info() API, which would allow the Hugging Face device_map="auto" mode to work. Basically, the root cause of this issue is the inability to query free/total XPU device memory.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@SunMarc SunMarc reopened this Aug 20, 2024
@muellerzr muellerzr added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Aug 21, 2024
dvrogozh added a commit to dvrogozh/accelerate that referenced this issue Dec 6, 2024
torch.xpu.mem_get_info API is available starting from PyTorch 2.6 (and
in nightly 2.6.0.dev20241206+xpu or later). To work properly, this method
requires PyTorch built with a SYCL runtime that supports the API for querying
device memory stats. If it is not available, an exception will be raised.

Requires: pytorch/pytorch#141230
Fixes: huggingface#2929
Fixes: huggingface/transformers#31922
Signed-off-by: Dmitry Rogozhkin <[email protected]>
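Given the constraints in the commit message (PyTorch 2.6 or a matching nightly, plus a SYCL runtime that can report memory stats), code that wants to use the new API has to degrade gracefully on older or unsupported builds. A minimal sketch of such a guard, assuming mem_get_info returns a (free, total) tuple in bytes like its CUDA counterpart; the function name is hypothetical and this is not the code from the Accelerate PR:

```python
def xpu_free_total(device_index=0):
    """Return (free_bytes, total_bytes) for an XPU device, or None.

    Hypothetical helper: returns None when PyTorch is missing, no XPU is
    available, or the build predates torch.xpu.mem_get_info (PyTorch < 2.6).
    """
    try:
        import torch
    except ImportError:
        return None
    xpu = getattr(torch, "xpu", None)
    if xpu is None or not xpu.is_available():
        return None
    mem_get_info = getattr(xpu, "mem_get_info", None)
    if mem_get_info is None:
        # Older PyTorch: the API does not exist yet.
        return None
    try:
        free, total = mem_get_info(device_index)
    except Exception:
        # The commit only says "an exception will be raised" when the SYCL
        # runtime lacks memory stats; the exact type is an assumption, so
        # catch broadly.
        return None
    return free, total
```

A caller such as a device-map planner could fall back to an explicit max_memory budget whenever this returns None.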
@dvrogozh
Contributor Author

dvrogozh commented Dec 6, 2024

The torch.xpu.mem_get_info() API landed in PyTorch this week (through pytorch/pytorch#141230), in time for the upcoming PyTorch 2.6 release. Here is the corresponding fix on the Accelerate side that addresses the issue:

dvrogozh added a commit to dvrogozh/accelerate that referenced this issue Dec 9, 2024
4 participants