GPT2 CausalLM Inference crashes when using transformers v4.39.0 #6991

Closed
sechkova opened this issue Apr 29, 2024 · 6 comments · Fixed by #7010

sechkova commented Apr 29, 2024

🐛 Bug

When running LLM inference with the gpt2 model using HF transformers, upgrading to transformers v4.39.0 leads to the following error:

F0000 00:00:1714398323.905552  103900 debug_macros.h:20] Non-OK-status: status.status() status: INVALID_ARGUMENT: Expected pred or integral type in argument to and/or operation; got F32.

To Reproduce

Example test code:


import os

import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ["PJRT_DEVICE"] = "CPU"  # run on the CPU PJRT backend

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("My name is", return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained("gpt2", use_cache=False)

# Move the inputs and the model to the XLA device.
device = xm.xla_device(devkind="CPU")
inputs.to(device)
model.to(device)

# Generation crashes here with transformers v4.39.0.
gen_tokens = model.generate(**inputs, max_new_tokens=2)
decoded = tokenizer.batch_decode(gen_tokens)
print(decoded[0])

Expected behavior

Reverting to the previous version, transformers==4.38.0, fixes the error and inference runs fine.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch 2.3.0
  • torch-xla 2.3.0
  • transformers 4.39.0

Additional context

Full error log is in the comments.

@sechkova (Author)

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1714399208.839010  106661 cpu_client.cc:405] TfrtCpuClient created.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
F0000 00:00:1714399209.178748  106661 debug_macros.h:20] Non-OK-status: status.status() status: INVALID_ARGUMENT: Expected pred or integral type in argument to and/or operation; got F32.
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
	torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
	torch_xla::XlaHelpers::TypeOfXlaOp(xla::XlaOp)
	torch_xla::XlaHelpers::PromotedBinaryOp(xla::XlaOp, xla::XlaOp, std::function<xla::XlaOp (xla::XlaOp, xla::XlaOp)> const&)

	torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)

	torch_xla::BitwiseOrTensorOutputShape(torch::lazy::Value const&, torch::lazy::Value const&)
	std::_Function_handler<xla::Shape (), torch_xla::BitwiseOrTensor::BitwiseOrTensor(torch::lazy::Value const&, torch::lazy::Value const&)::{lambda()#1}>::_M_invoke(std::_Any_data const&)
	torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
	torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
	torch_xla::tensor_methods::bitwise_or(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
	torch_xla::XLANativeFunctions::bitwise_or(at::Tensor const&, at::Tensor const&)

	c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const




	at::_ops::bitwise_or_Tensor::call(at::Tensor const&, at::Tensor const&)
	at::native::__or__(at::Tensor const&, at::Tensor const&)

	at::_ops::__or___Tensor::call(at::Tensor const&, at::Tensor const&)







	_PyEval_EvalFrameDefault
	_PyObject_FastCallDictTstate
	_PyObject_Call_Prepend

	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault

	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault

	PyObject_Call
	_PyEval_EvalFrameDefault

	PyEval_EvalCode



	_PyRun_SimpleFileObject
	_PyRun_AnyFileObject
	Py_RunMain
	Py_BytesMain

	__libc_start_main
	_start
*** End stack trace ***

*** Check failure stack trace: ***
    @     0x7fbf12be15d9  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7fbf0ad92834  ConsumeValue<>()
    @     0x7fbf0ad9289e  torch_xla::ShapeHelper::ShapeOfXlaOp()
    @     0x7fbf0aa0d5c9  torch_xla::XlaHelpers::TypeOfXlaOp()
    @     0x7fbf0aa15f67  torch_xla::XlaHelpers::PromotedBinaryOp()
    @     0x7fbf0ad4a1c1  std::_Function_handler<>::_M_invoke()
    @     0x7fbf0ad05ba9  torch_xla::InferOutputShape()
    @     0x7fbf0ad50373  torch_xla::(anonymous namespace)::InferBinaryOpShape()
    @     0x7fbf0ad5056e  torch_xla::BitwiseOrTensorOutputShape()
    @     0x7fbf0aa5f119  std::_Function_handler<>::_M_invoke()
    @     0x7fbf0ad862a6  torch_xla::XlaNode::GetOpShape()
    @     0x7fbf0ad86b99  torch_xla::XlaNode::XlaNode()
    @     0x7fbf0aa72eaf  torch_xla::tensor_methods::bitwise_or()
    @     0x7fbf0a9a9101  torch_xla::XLANativeFunctions::bitwise_or()
    @     0x7fbf0ac11419  c10::impl::make_boxed_from_unboxed_functor<>::call()
    @     0x7fbfcd328028  c10::Dispatcher::callBoxed()
    @     0x5652c3acaf80  (unknown)
Aborted (core dumped)

@JackCaoG (Collaborator)

I was able to repro it with

>>> import torch
>>> import torch_xla
>>> t1 = torch.randn(3,3, device='xla:0')
>>> t1
tensor([[ 1.1453, -0.9900,  0.5783],
        [ 0.6340,  1.6611,  0.2455],
        [-1.1664,  0.5326,  1.7286]], device='xla:0')
>>> torch.bitwise_or(t1, t1)

but this is also not supported in native PyTorch:

>>> import torch
>>> t1 = torch.randn(3,3)
>>> torch.bitwise_or(t1, t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: "bitwise_or_cpu" not implemented for 'Float'

does the above code work in a native PyTorch GPU env?
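
For reference, a minimal sketch (example values assumed, not from the report) showing that bitwise_or is defined for bool and integer tensors, which is the dtype path the stopping-criteria accumulation in generate() relies on:

import torch

# bitwise_or works on bool (and integer) tensors; this is the combination
# the `is_done | criteria(...)` accumulation in generate() expects.
a = torch.tensor([True, False, True])
b = torch.tensor([False, False, True])
print(torch.bitwise_or(a, b))  # tensor([ True, False,  True])
print(a | b)                   # `|` dispatches to bitwise_or for tensors

# On float tensors it raises, matching the native-PyTorch error above:
#   torch.bitwise_or(torch.randn(3), torch.randn(3))
#   -> RuntimeError: "bitwise_or_cpu" not implemented for 'Float'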

@sechkova (Author)

@JackCaoG interesting ...
Yes, my example above works in a native PyTorch GPU or CPU env (device = "cuda" or device = "cpu").
I get this error only when I use xla_device + transformers > 4.38.

I saw that the changes in huggingface/transformers#29334 were part of the v4.39.0 release. Do you think they could be related?

@JackCaoG (Collaborator)

JackCaoG commented May 1, 2024

yea.. I can repro this issue, let me look into it a bit..

@JackCaoG (Collaborator)

JackCaoG commented May 1, 2024

Ah ok, the issue is from torch.full in

class StoppingCriteriaList(list):
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
        for criteria in self:
            is_done = is_done | criteria(input_ids, scores, **kwargs)
        return is_done

(Pdb) torch.full((input_ids.shape[0],), False, device=input_ids.device)
tensor([0.], device='xla:0')
(Pdb) torch.full((input_ids.shape[0],), False, device="cpu")
tensor([False])

Let me fix it...
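
A minimal standalone sketch (not from the thread; the device setup and tensor sizes are illustrative) that isolates the mismatch on torch-xla 2.3.0 and shows that passing dtype=torch.bool explicitly sidesteps it:

import os
os.environ.setdefault("PJRT_DEVICE", "CPU")  # CPU PJRT backend, as in the report

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Before the fix, a Python bool fill value was promoted to float32 on the
# XLA device, while eager CPU infers torch.bool from it.
print(torch.full((4,), False, device=device).dtype)  # float32 on affected versions
print(torch.full((4,), False, device="cpu").dtype)   # torch.bool

# Passing dtype explicitly avoids the and/or type error, since bitwise_or
# only accepts pred/integral operands.
is_done = torch.full((4,), False, dtype=torch.bool, device=device)
other = torch.zeros(4, dtype=torch.bool, device=device)
print((is_done | other).dtype)  # torch.bool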

@JackCaoG (Collaborator)

This should be fixed now, let me close the issue.
