Commit
fix test
echarlaix committed Oct 5, 2023
1 parent e6de5e7 · commit 4a32f7a
Showing 2 changed files with 30 additions and 32 deletions.
optimum/onnxruntime/modeling_decoder.py (2 changes: 1 addition & 1 deletion)
@@ -156,7 +156,7 @@ def __init__(
 
         # Reference: https://github.com/huggingface/optimum/pull/1381
         model_type = config.model_type.replace("_", "-")
-        if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.decoder.input_names:
+        if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names:
             logger.warning(
                 f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. "
                 "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support."
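
Here `inputs_names` is populated from the ONNX graph's declared inputs, so the guard simply asks whether the exported model takes a `position_ids` input; `self.decoder.input_names` was presumably a leftover from before the refactor referenced above (PR #1381). A minimal sketch of the same inspection done directly with onnxruntime (the file path is a placeholder, not part of this commit):

```python
import onnxruntime

# Hypothetical path to a legacy exported decoder; any ONNX file works here.
session = onnxruntime.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])

# get_inputs() returns one NodeArg per graph input; .name is the input's name.
input_names = {node.name for node in session.get_inputs()}

if "position_ids" not in input_names:
    print("Legacy export detected: no position_ids input, batched generation may misbehave.")
```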
tests/onnxruntime/test_modeling.py (60 changes: 29 additions & 31 deletions)
@@ -1967,25 +1967,22 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.1
 
-    def test_inference_old_onnx_model(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_inference_old_onnx_model(self, use_cache):
         model_id = "optimum/gpt2"
         model = AutoModelForCausalLM.from_pretrained("gpt2")
         tokenizer = get_preprocessor(model_id)
         text = "This is a sample output"
         tokens = tokenizer(text, return_tensors="pt")
+        onnx_model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, use_io_binding=use_cache)
 
-        for use_cache in (True, False):
-            onnx_model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, use_io_binding=use_cache)
-
-            self.assertEqual(onnx_model.use_cache, use_cache)
-            self.assertEqual(
-                onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME
-            )
-            outputs_onnx = onnx_model.generate(
-                **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30
-            )
-            outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30)
-            self.assertTrue(torch.allclose(outputs_onnx, outputs))
+        self.assertEqual(onnx_model.use_cache, use_cache)
+        self.assertEqual(onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME)
+        outputs_onnx = onnx_model.generate(
+            **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30
+        )
+        outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30)
+        self.assertTrue(torch.allclose(outputs_onnx, outputs))
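
The rewrite swaps the in-test for loop for parameterized.expand, so each use_cache value runs as its own test case and a failure reports which variant broke instead of the first failure masking the other. A minimal standalone sketch of the pattern (the class and method names here are illustrative, not from the diff):

```python
import unittest

from parameterized import parameterized


class ExampleCacheTest(unittest.TestCase):
    # Each tuple is unpacked into the test's arguments, producing one
    # generated test case per tuple (e.g. test_with_cache_0, test_with_cache_1).
    @parameterized.expand([(False,), (True,)])
    def test_with_cache(self, use_cache):
        self.assertIsInstance(use_cache, bool)


if __name__ == "__main__":
    unittest.main()
```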

     def test_load_model_from_hub_onnx(self):
         model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-without-merge")
@@ -2023,15 +2020,14 @@ def test_merge_from_onnx_and_save(self, model_arch):

             self.assertTrue(model.use_merged)
             self.assertIsInstance(model.model, onnxruntime.InferenceSession)
 
             model.save_pretrained(tmpdir + "_save")
             save_path = os.path.join(tmpdir + "_save", ONNX_DECODER_MERGED_NAME)
             self.assertTrue(has_onnx_input(save_path, "use_cache_branch"))
 
             folder_contents = os.listdir(tmpdir + "_save")
-            self.assertTrue(ONNX_DECODER_NAME not in folder_contents)
-            self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents)
-            self.assertTrue(ONNX_WEIGHTS_NAME not in folder_contents)
+            self.assertNotIn(ONNX_DECODER_NAME, folder_contents)
+            self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents)
+            self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents)
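
Switching from assertTrue(x not in y) to assertNotIn(x, y) does not change what is tested, but it makes failures diagnosable: assertTrue only reports that False is not true, while assertNotIn prints both the member and the container. An illustrative sketch (the names are hypothetical):

```python
import unittest


class FailureMessageDemo(unittest.TestCase):
    def test_messages(self):
        folder_contents = ["model.onnx", "config.json"]
        # Passes here; on failure it would only report "False is not true".
        self.assertTrue("decoder_model.onnx" not in folder_contents)
        # Passes here; on failure it would report member and container,
        # e.g. "'decoder_model.onnx' unexpectedly found in ['model.onnx', ...]".
        self.assertNotIn("decoder_model.onnx", folder_contents)


if __name__ == "__main__":
    unittest.main()
```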

     @parameterized.expand(grid_parameters(FULL_GRID))
     def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
@@ -2268,38 +2264,40 @@ def test_compare_with_and_without_past_key_values(self, model_arch):

     @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
     def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool):
-        model_args = {
-            "test_name": test_name + "_True",
-            "model_arch": model_arch,
-            "use_cache": use_cache,
-            "use_merged": True,
-        }
-        self._setup(model_args)
         model_args = {
             "test_name": test_name + "_False",
             "model_arch": model_arch,
             "use_cache": use_cache,
             "use_merged": False,
         }
         self._setup(model_args)
 
         model_id = MODEL_NAMES[model_arch]
         tokenizer = get_preprocessor(model_id)
         text = "My Name is Philipp and i live"
         tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)
 
         model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"]
-        model_merged_dir = self.onnx_model_dirs[test_name + "_True"]
 
         model_not_merged = ORTModelForCausalLM.from_pretrained(model_not_merged_dir)
         not_merged_onnx_path = Path(model_not_merged_dir, ONNX_WEIGHTS_NAME)
         self.assertFalse(has_onnx_input(not_merged_onnx_path, "use_cache_branch"))
         self.assertFalse(model_not_merged.use_merged)
 
+        model_merged_dir = Path(model_not_merged_dir) / "merged"
+        task = model_not_merged.export_feature
+        if use_cache:
+            task += "-with-past"
+
+        main_export(
+            model_id,
+            output=model_merged_dir,
+            task=task,
+            no_post_process=False,
+            legacy=True,
+        )
+
         model_merged = ORTModelForCausalLM.from_pretrained(model_merged_dir)
-        merged_onnx_path = Path(model_merged_dir, ONNX_WEIGHTS_NAME)
-        self.assertFalse(has_onnx_input(merged_onnx_path, "use_cache_branch"))
-        self.assertFalse(model_merged.use_merged)
+        merged_onnx_path = Path(model_merged_dir, ONNX_DECODER_MERGED_NAME)
+        self.assertTrue(has_onnx_input(merged_onnx_path, "use_cache_branch"))
+        self.assertTrue(model_merged.use_merged)
 
         outputs_model_not_merged = model_not_merged.generate(**tokens)
         outputs_model_merged = model_merged.generate(**tokens)
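
Since the default export path no longer produces merged decoders, the fixed test re-exports the model with main_export(..., legacy=True), which keeps the old post-processing step that merges the decoder and decoder-with-past graphs into a single decoder_model_merged.onnx gated by a boolean use_cache_branch input. A minimal sketch of such a legacy export outside the test harness (the output directory and model choice are illustrative):

```python
from pathlib import Path

from optimum.exporters.onnx import main_export

output_dir = Path("gpt2_legacy_onnx")  # hypothetical output location

# legacy=True requests the pre-optimum-1.14 layout; with post-processing
# enabled (no_post_process=False) the two decoder graphs are merged into one.
main_export(
    "gpt2",
    output=output_dir,
    task="text-generation-with-past",
    no_post_process=False,
    legacy=True,
)

# Expect decoder_model_merged.onnx among the exported files.
print(sorted(p.name for p in output_dir.glob("*.onnx")))
```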
