
Commit

[CI] Fix bnb quantization tests with accelerate>=1.2.0 (#35172)
matthewdouglas authored Dec 9, 2024
1 parent fa8763c commit 34f4080
Showing 2 changed files with 10 additions and 10 deletions.
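The change replaces the hard-coded torch_device with model.device when placing the tokenized inputs, so the inputs follow wherever accelerate's device_map="auto" dispatch actually put the model under accelerate>=1.2.0. A minimal sketch of the resulting pattern outside the test harness (a sketch only: it assumes bitsandbytes, accelerate, and a supported GPU are available, and the prompt is illustrative rather than taken from the tests' fixtures):

# Sketch: keep inputs on whatever device accelerate dispatched the model to.
# Assumes bitsandbytes and accelerate are installed and a supported GPU exists;
# the checkpoint name comes from the tests, the prompt is made up for illustration.
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_name = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = T5ForConditionalGeneration.from_pretrained(
    model_name, load_in_4bit=True, device_map="auto"
)

# model.device reflects where the weights were actually placed, which may
# differ from a hard-coded torch_device when device_map="auto" is used.
encoded_input = tokenizer("translate English to German: hello", return_tensors="pt").to(model.device)
_ = model.generate(**encoded_input)

The same pattern applies to the 8-bit tests below, with load_in_8bit=True in place of load_in_4bit=True.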
8 changes: 4 additions & 4 deletions tests/quantization/bnb/test_4bit.py
@@ -385,14 +385,14 @@ def test_inference_without_keep_in_fp32(self):

         # test with `google-t5/t5-small`
         model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)

         # test with `flan-t5-small`
         model = T5ForConditionalGeneration.from_pretrained(
             self.dense_act_model_name, load_in_4bit=True, device_map="auto"
         )
-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)
         T5ForConditionalGeneration._keep_in_fp32_modules = modules

@@ -410,14 +410,14 @@ def test_inference_with_keep_in_fp32(self):
         # there was a bug with decoders - this test checks that it is fixed
         self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit))

-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)

         # test with `flan-t5-small`
         model = T5ForConditionalGeneration.from_pretrained(
             self.dense_act_model_name, load_in_4bit=True, device_map="auto"
         )
-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)


12 changes: 6 additions & 6 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -514,14 +514,14 @@ def test_inference_without_keep_in_fp32(self):

         # test with `google-t5/t5-small`
         model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)

         # test with `flan-t5-small`
         model = T5ForConditionalGeneration.from_pretrained(
             self.dense_act_model_name, load_in_8bit=True, device_map="auto"
         )
-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)
         T5ForConditionalGeneration._keep_in_fp32_modules = modules

@@ -540,14 +540,14 @@ def test_inference_with_keep_in_fp32(self):
         # there was a bug with decoders - this test checks that it is fixed
         self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))

-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)

         # test with `flan-t5-small`
         model = T5ForConditionalGeneration.from_pretrained(
             self.dense_act_model_name, load_in_8bit=True, device_map="auto"
         )
-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)

     def test_inference_with_keep_in_fp32_serialized(self):
@@ -571,14 +571,14 @@ def test_inference_with_keep_in_fp32_serialized(self):
         # there was a bug with decoders - this test checks that it is fixed
         self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))

-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)

         # test with `flan-t5-small`
         model = T5ForConditionalGeneration.from_pretrained(
             self.dense_act_model_name, load_in_8bit=True, device_map="auto"
         )
-        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
         _ = model.generate(**encoded_input)


