diff --git a/engines/python/setup/djl_python/transformers_neuronx.py b/engines/python/setup/djl_python/transformers_neuronx.py index ed682f638..2cf781b4c 100644 --- a/engines/python/setup/djl_python/transformers_neuronx.py +++ b/engines/python/setup/djl_python/transformers_neuronx.py @@ -180,12 +180,7 @@ def load_model(self, model_type): # TODO: workaround on Neuron Compiler bug for SM path = os.getcwd() os.chdir("/tmp") - if model_type == "gpt2": - self.model._load_compiled_artifacts(load_path) - self.model.to_neuron() - self.model._save_compiled_artifacts(load_path) - else: - self.model.to_neuron() + self.model.to_neuron() os.chdir(path) elapsed = time.time() - start logging.info(