From 700d0ce310bb065875ec1551cc4c7a4afb7ffe73 Mon Sep 17 00:00:00 2001
From: Saaketh
Date: Wed, 5 Jun 2024 13:02:04 -0700
Subject: [PATCH] linting is shortening my lifespan

---
 tests/data/test_data_encodings.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/tests/data/test_data_encodings.py b/tests/data/test_data_encodings.py
index 8d971281db..7350731599 100644
--- a/tests/data/test_data_encodings.py
+++ b/tests/data/test_data_encodings.py
@@ -10,7 +10,11 @@
 from llmfoundry.data import StreamingTextDataset
 from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset
 
-@pytest.mark.parametrize('token_encoding_type', ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
+
+@pytest.mark.parametrize(
+    'token_encoding_type',
+    ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'],
+)
 @pytest.mark.parametrize('use_bytes', [True, False])
 @pytest.mark.parametrize('samples', [10])
 @pytest.mark.parametrize('max_seq_len', [2048])
@@ -61,14 +65,17 @@ def test_encoding_types_text(
         local=dataset_local_path,
         batch_size=1,
     )
-    
+
     for _, sample in enumerate(dataset):
         # StreamingTextDataset returns a torch Tensor, not numpy array
         assert sample.dtype == getattr(torch, token_encoding_type)
         assert sample.shape == (max_seq_len,)
 
 
-@pytest.mark.parametrize('token_encoding_type', ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'])
+@pytest.mark.parametrize(
+    'token_encoding_type',
+    ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64'],
+)
 @pytest.mark.parametrize('use_bytes', [True, False])
 @pytest.mark.parametrize('samples', [10])
 @pytest.mark.parametrize('max_seq_len', [2048])
@@ -136,12 +143,16 @@ def test_encoding_types_finetuning(
         assert isinstance(sample['turns'][0]['labels'][0], int)
         assert len(sample['turns'][0]['labels']) == max_seq_len
 
+
 @pytest.mark.parametrize(
     'token_encoding_type',
     ['int17', 'float32', 'complex', 'int4'],
 )
 @pytest.mark.parametrize('use_finetuning', [True, False])
-def test_unsupported_encoding_type(token_encoding_type: str, use_finetuning: bool):
+def test_unsupported_encoding_type(
+    token_encoding_type: str,
+    use_finetuning: bool,
+):
     with pytest.raises(ValueError, match='The token_encoding_type*'):
         if use_finetuning:
             StreamingFinetuningDataset(
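
Note for reviewers: below is a minimal sketch of the dtype-resolution behavior these tests exercise, not llm-foundry's actual implementation. Supported token_encoding_type strings resolve to the matching torch dtype (the tests themselves do getattr(torch, token_encoding_type)), and unsupported ones raise a ValueError whose message the tests match against 'The token_encoding_type*'. The names SUPPORTED_TOKEN_ENCODING_TYPES and resolve_token_dtype are hypothetical, and resolving the unsigned variants assumes a PyTorch version that exposes torch.uint16/uint32/uint64.

    import torch

    # Hypothetical names for illustration; the real validation lives inside
    # the llm-foundry dataset classes and may be structured differently.
    SUPPORTED_TOKEN_ENCODING_TYPES = {
        'int8', 'int16', 'int32', 'int64',
        'uint8', 'uint16', 'uint32', 'uint64',
    }

    def resolve_token_dtype(token_encoding_type: str) -> torch.dtype:
        """Map an encoding-type string to a torch dtype, rejecting others."""
        if token_encoding_type not in SUPPORTED_TOKEN_ENCODING_TYPES:
            # Message phrased so that the tests' pattern
            # pytest.raises(ValueError, match='The token_encoding_type*')
            # would match it.
            raise ValueError(
                f'The token_encoding_type {token_encoding_type} must be one '
                f'of {sorted(SUPPORTED_TOKEN_ENCODING_TYPES)}.',
            )
        # Assumption: getattr(torch, 'uint16') and friends only exist on
        # PyTorch versions that ship the limited unsigned integer dtypes.
        return getattr(torch, token_encoding_type)

    assert resolve_token_dtype('int32') is torch.int32

Parametrizing over the full signed/unsigned matrix, as the patched tests do, catches both halves of this contract: valid names round-trip to a dtype, and names like 'int17' or 'complex' fail loudly rather than silently truncating token ids.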