Commit: finish all todos (#1957)

younesbelkada authored Sep 12, 2023
1 parent 5dec654 commit 61a87ab
Showing 1 changed file with 0 additions and 21 deletions.
tests/test_big_modeling.py (21 changes: 0 additions & 21 deletions)
@@ -638,22 +638,16 @@ def test_dispatch_model_bnb(self):
         with init_empty_weights():
             model = AutoModel.from_config(AutoConfig.from_pretrained("bigscience/bloom-560m"))
 
-        # TODO: @younesbelkada remove the positional arg on the next `transformers` release
         quantization_config = BitsAndBytesConfig(load_in_8bit=True)
         model = replace_with_bnb_linear(
             model, modules_to_not_convert=["lm_head"], quantization_config=quantization_config
         )
 
-        # TODO: @younesbelkada remove this block on the next `transformers` release
-        for p in model.parameters():
-            p.requires_grad = False
-
         model_path = hf_hub_download("bigscience/bloom-560m", "pytorch_model.bin")
 
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            # device_map="auto",
             device_map="balanced",
         )

@@ -674,16 +668,11 @@ def test_dispatch_model_int8_simple(self):
         with init_empty_weights():
             model = AutoModel.from_config(AutoConfig.from_pretrained("bigscience/bloom-560m"))
 
-        # TODO: @younesbelkada remove the positional arg on the next `transformers` release
         quantization_config = BitsAndBytesConfig(load_in_8bit=True)
         model = replace_with_bnb_linear(
             model, modules_to_not_convert=["lm_head"], quantization_config=quantization_config
         )
 
-        # TODO: @younesbelkada remove this block on the next `transformers` release
-        for p in model.parameters():
-            p.requires_grad = False
-
         model_path = hf_hub_download("bigscience/bloom-560m", "pytorch_model.bin")
 
         # test with auto
@@ -699,14 +688,10 @@ def test_dispatch_model_int8_simple(self):
         with init_empty_weights():
             model = AutoModel.from_config(AutoConfig.from_pretrained("bigscience/bloom-560m"))
 
-        # TODO: @younesbelkada remove the positional arg on the next `transformers` release
         model = replace_with_bnb_linear(
             model, modules_to_not_convert=["lm_head"], quantization_config=quantization_config
         )
 
-        for p in model.parameters():
-            p.requires_grad = False
-
         # test with str device map
         model = load_checkpoint_and_dispatch(
             model,
@@ -720,15 +705,10 @@ def test_dispatch_model_int8_simple(self):
         with init_empty_weights():
             model = AutoModel.from_config(AutoConfig.from_pretrained("bigscience/bloom-560m"))
 
-        # TODO: @younesbelkada remove the positional arg on the next `transformers` release
         model = replace_with_bnb_linear(
             model, modules_to_not_convert=["lm_head"], quantization_config=quantization_config
         )
 
-        # TODO: @younesbelkada remove this block on the next `transformers` release
-        for p in model.parameters():
-            p.requires_grad = False
-
         # test with torch.device device map
         model = load_checkpoint_and_dispatch(
             model,
@@ -741,7 +721,6 @@ def test_dispatch_model_int8_simple(self):
 
     @slow
     @require_bnb
-    @unittest.skip("Un-skip in the next transformers release")
     def test_dipatch_model_fp4_simple(self):
         """Tests that `dispatch_model` quantizes fp4 layers"""
         from huggingface_hub import hf_hub_download

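For reference, the hunks above all exercise the same flow: build an empty bloom-560m skeleton, convert its Linear layers to 8-bit bitsandbytes layers, then load and place the real checkpoint with load_checkpoint_and_dispatch. A minimal standalone sketch of that flow follows; it assumes a CUDA GPU with bitsandbytes installed, and the replace_with_bnb_linear import path is an assumption that has moved between transformers releases.

# Minimal sketch of the dispatch-with-int8-quantization pattern the tests above exercise.
# Assumptions: `transformers`, `accelerate`, `bitsandbytes`, and a CUDA GPU are available;
# in older `transformers` releases `replace_with_bnb_linear` lived in
# `transformers.utils.bitsandbytes` rather than `transformers.integrations`.
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModel, BitsAndBytesConfig
from transformers.integrations import replace_with_bnb_linear

# Build the architecture on the meta device so no real weights are allocated yet.
with init_empty_weights():
    model = AutoModel.from_config(AutoConfig.from_pretrained("bigscience/bloom-560m"))

# Swap every Linear layer except the lm_head for an 8-bit bitsandbytes layer.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = replace_with_bnb_linear(
    model, modules_to_not_convert=["lm_head"], quantization_config=quantization_config
)

# Fetch the checkpoint, then load and place the real weights in one step.
# The tests vary only in the device_map they pass ("auto", "balanced", or an
# explicit mapping of module names to devices).
model_path = hf_hub_download("bigscience/bloom-560m", "pytorch_model.bin")
model = load_checkpoint_and_dispatch(model, checkpoint=model_path, device_map="balanced")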