Skip to content

Commit

Permalink
linting is shortening my lifespan
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Jun 5, 2024
1 parent d7f4c98 commit 700d0ce
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions tests/data/test_data_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from llmfoundry.data import StreamingTextDataset
from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

@pytest.mark.parametrize('token_encoding_type', ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])

@pytest.mark.parametrize(
'token_encoding_type',
['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'],
)
@pytest.mark.parametrize('use_bytes', [True, False])
@pytest.mark.parametrize('samples', [10])
@pytest.mark.parametrize('max_seq_len', [2048])
Expand Down Expand Up @@ -61,14 +65,17 @@ def test_encoding_types_text(
local=dataset_local_path,
batch_size=1,
)

for _, sample in enumerate(dataset):
# StreamingTextDataset returns a torch Tensor, not numpy array
assert sample.dtype == getattr(torch, token_encoding_type)
assert sample.shape == (max_seq_len,)


@pytest.mark.parametrize('token_encoding_type', ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
@pytest.mark.parametrize(
'token_encoding_type',
['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'],
)
@pytest.mark.parametrize('use_bytes', [True, False])
@pytest.mark.parametrize('samples', [10])
@pytest.mark.parametrize('max_seq_len', [2048])
Expand Down Expand Up @@ -136,12 +143,16 @@ def test_encoding_types_finetuning(
assert isinstance(sample['turns'][0]['labels'][0], int)
assert len(sample['turns'][0]['labels']) == max_seq_len


@pytest.mark.parametrize(
'token_encoding_type',
['int17', 'float32', 'complex', 'int4'],
)
@pytest.mark.parametrize('use_finetuning', [True, False])
def test_unsupported_encoding_type(token_encoding_type: str, use_finetuning: bool):
def test_unsupported_encoding_type(
token_encoding_type: str,
use_finetuning: bool,
):
with pytest.raises(ValueError, match='The token_encoding_type*'):
if use_finetuning:
StreamingFinetuningDataset(
Expand Down

0 comments on commit 700d0ce

Please sign in to comment.