Skip to content

Commit

Permalink
fix type
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Dec 14, 2023
1 parent 801db4c commit 55dbc6f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,11 +605,11 @@ def test_malformed_data(
@pytest.mark.parametrize('model_max_length', [1024, 2048])
@pytest.mark.parametrize('padding_side', ['left', 'right'])
@pytest.mark.parametrize('add_decoder_input_ids', [True, False])
def test_token_counting_func(pad_token_id: Optional[int], batch_size: int,
def test_token_counting_func(pad_token_id: int, batch_size: int,
model_max_length: int, padding_side: str,
add_decoder_input_ids: bool):
gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id
gptt.pad_token_id = pad_token_id
gptt.model_max_length = model_max_length
gptt.padding_side = padding_side

Expand Down

0 comments on commit 55dbc6f

Please sign in to comment.