diff --git a/models/demos/ttnn_falcon7b/tests/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/test_falcon_attention.py
index 89512cca80f..8cd2c7aa483 100644
--- a/models/demos/ttnn_falcon7b/tests/test_falcon_attention.py
+++ b/models/demos/ttnn_falcon7b/tests/test_falcon_attention.py
@@ -33,7 +33,7 @@ def get_model_prefix(layer_index: int = 0):
 @pytest.fixture(scope="module")
 def torch_model():
     hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
-        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
     ).eval()
     state_dict = hugging_face_reference_model.state_dict()
     filtered_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())
diff --git a/models/demos/ttnn_falcon7b/tests/test_falcon_decoder.py b/models/demos/ttnn_falcon7b/tests/test_falcon_decoder.py
index 243a2a64a79..5e2a769f0ee 100644
--- a/models/demos/ttnn_falcon7b/tests/test_falcon_decoder.py
+++ b/models/demos/ttnn_falcon7b/tests/test_falcon_decoder.py
@@ -33,7 +33,7 @@ def get_model_prefix(layer_index: int = 0):
 @pytest.fixture(scope="module")
 def torch_model():
     hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
-        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
     ).eval()
     state_dict = hugging_face_reference_model.state_dict()
     mlp_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())
diff --git a/models/demos/ttnn_falcon7b/tests/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/test_falcon_mlp.py
index 7ff91a2b76d..ef22acba48c 100644
--- a/models/demos/ttnn_falcon7b/tests/test_falcon_mlp.py
+++ b/models/demos/ttnn_falcon7b/tests/test_falcon_mlp.py
@@ -25,7 +25,7 @@ def get_model_prefix(layer_index: int = 0):
 @pytest.fixture(scope="module")
 def torch_model():
     hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
-        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
     ).eval()
     state_dict = hugging_face_reference_model.state_dict()
     mlp_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())
diff --git a/models/demos/ttnn_falcon7b/tests/test_falcon_rotary_embedding.py b/models/demos/ttnn_falcon7b/tests/test_falcon_rotary_embedding.py
index 29a50fd1bf8..c90bad66da1 100644
--- a/models/demos/ttnn_falcon7b/tests/test_falcon_rotary_embedding.py
+++ b/models/demos/ttnn_falcon7b/tests/test_falcon_rotary_embedding.py
@@ -29,7 +29,7 @@ def get_model_prefix(layer_index: int = 0):
 @pytest.fixture(scope="module")
 def torch_model():
     hugging_face_reference_model = transformers.FalconForCausalLM.from_pretrained(
-        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True
+        PRETRAINED_MODEL_NAME, low_cpu_mem_usage=True, device_map="auto"
     ).eval()
     state_dict = hugging_face_reference_model.state_dict()
     filtered_state_dict = strip_state_dict_prefix(state_dict, get_model_prefix())