Skip to content

Commit

Permalink
Optimize loading of falcon model checkpoints
Browse files (browse the repository at this point in the history)
  • Loading branch information
blozano-tt committed Dec 11, 2024
1 parent 2abd3fe commit 9099753
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion models/demos/ttnn_falcon7b/tests/test_falcon_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_model_prefix(layer_index: int = 0):
@pytest.fixture(scope="module")
def torch_model():
hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
- PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+ PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
).eval()
state_dict = hugging_face_reference_model.state_dict()
filtered_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())
Expand Down
2 changes: 1 addition & 1 deletion models/demos/ttnn_falcon7b/tests/test_falcon_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_model_prefix(layer_index: int = 0):
@pytest.fixture(scope="module")
def torch_model():
hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
- PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+ PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
).eval()
state_dict = hugging_face_reference_model.state_dict()
mlp_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())
Expand Down
2 changes: 1 addition & 1 deletion models/demos/ttnn_falcon7b/tests/test_falcon_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_model_prefix(layer_index: int = 0):
@pytest.fixture(scope="module")
def torch_model():
hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
- PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+ PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
).eval()
state_dict = hugging_face_reference_model.state_dict()
mlp_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_model_prefix(layer_index: int = 0):
@pytest.fixture(scope="module")
def torch_model():
hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
- PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+ PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
).eval()
state_dict = hugging_face_reference_model.state_dict()
filtered_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())
Expand Down

0 comments on commit 9099753

Please sign in to comment.