Skip to content

Commit

Permalink
add tests
Browse files: browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Jun 5, 2024
1 parent 7efd189 commit 91ad932
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions tests/data/test_text_data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,33 +2,43 @@
# SPDX-License-Identifier: Apache-2.0
import os
import pathlib
from streaming import MDSWriter
from llmfoundry.data import StreamingTextDataset
import pytest

import numpy as np
import pytest
import torch
from streaming import MDSWriter

from llmfoundry.data import StreamingTextDataset


@pytest.mark.parametrize('token_encoding_type', ['int16', 'int32', 'int64'])
@pytest.mark.parametrize('samples', [10])
@pytest.mark.parametrize('max_seq_len', [2048])
@pytest.mark.parametrize('vocab_size', [10000])
def test_encoding_types(tmp_path: pathlib.Path,
token_encoding_type: str,
samples: int,
max_seq_len: int,
vocab_size: int):
def test_encoding_types(
tmp_path: pathlib.Path,
token_encoding_type: str,
samples: int,
max_seq_len: int,
vocab_size: int,
):
dataset_local_path = str(tmp_path)
encoding_dtype = getattr(np, token_encoding_type)

columns = {
'tokens': 'ndarray:'+token_encoding_type,
'tokens': 'ndarray:' + token_encoding_type,
}

with MDSWriter(out=dataset_local_path, columns=columns) as writer:
for _ in range(samples):
tokens = np.random.randint(0, vocab_size, max_seq_len, dtype=encoding_dtype)
tokens = np.random.randint(
0,
vocab_size,
max_seq_len,
dtype=encoding_dtype,
)
writer.write({'tokens': tokens})

print('Dataset local path:', dataset_local_path)
print(os.listdir(dataset_local_path))

Expand All @@ -45,7 +55,11 @@ def test_encoding_types(tmp_path: pathlib.Path,
assert sample.dtype == getattr(torch, token_encoding_type)
assert sample.shape == (max_seq_len,)

@pytest.mark.parametrize('token_encoding_type', ['int17', 'float32', 'complex', 'int8'])

@pytest.mark.parametrize(
'token_encoding_type',
['int17', 'float32', 'complex', 'int8'],
)
def test_unsupported_encoding_type(token_encoding_type: str):
with pytest.raises(ValueError, match='The token_encoding_type*'):
StreamingTextDataset(
Expand All @@ -54,4 +68,4 @@ def test_unsupported_encoding_type(token_encoding_type: str):
max_seq_len=2048,
local='dataset/path',
batch_size=1,
)
)

0 comments on commit 91ad932

Please sign in to comment.