From 55dbc6fd1b06da4f63a7e9099c7a9c7c93c30f61 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 13 Dec 2023 16:28:31 -0800 Subject: [PATCH] fix type --- tests/data/test_dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 2ed6e8a5e7..2cf4c51a72 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -605,11 +605,11 @@ def test_malformed_data( @pytest.mark.parametrize('model_max_length', [1024, 2048]) @pytest.mark.parametrize('padding_side', ['left', 'right']) @pytest.mark.parametrize('add_decoder_input_ids', [True, False]) -def test_token_counting_func(pad_token_id: Optional[int], batch_size: int, +def test_token_counting_func(pad_token_id: int, batch_size: int, model_max_length: int, padding_side: str, add_decoder_input_ids: bool): gptt = transformers.AutoTokenizer.from_pretrained('gpt2') - gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id + gptt.pad_token_id = pad_token_id gptt.model_max_length = model_max_length gptt.padding_side = padding_side