diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 2ed6e8a5e7..2cf4c51a72 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -605,11 +605,11 @@ def test_malformed_data( @pytest.mark.parametrize('model_max_length', [1024, 2048]) @pytest.mark.parametrize('padding_side', ['left', 'right']) @pytest.mark.parametrize('add_decoder_input_ids', [True, False]) -def test_token_counting_func(pad_token_id: Optional[int], batch_size: int, +def test_token_counting_func(pad_token_id: int, batch_size: int, model_max_length: int, padding_side: str, add_decoder_input_ids: bool): gptt = transformers.AutoTokenizer.from_pretrained('gpt2') - gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id + gptt.pad_token_id = pad_token_id gptt.model_max_length = model_max_length gptt.padding_side = padding_side