Skip to content

Commit

Permalink
adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Dec 17, 2024
1 parent 6367a43 commit f4c2399
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import pytest

from llmfoundry.data.finetuning.tasks import (
QA_format_preprocessor,
_get_num_processes,
dataset_constructor,
messages_format_preprocessor,
)
from llmfoundry.utils.exceptions import DatasetTooSmallError

Expand Down Expand Up @@ -60,3 +62,48 @@ def get_local_world_size(self):
new=MockDataset,
):
dataset_constructor.build_from_streaming()


def test_QA_format_preprocessor():
    """Check that a Q/A example is converted into a chat-style ``messages`` dict.

    The input carries an extra ``meta`` key; the output is expected to hold
    exactly one key, ``messages``, with a user turn built from ``Q`` followed
    by an assistant turn built from ``A``.
    """
    inp = {
        'Q': 'What is the capital of France?',
        'A': 'Paris',
        'meta': {
            'a': 'b',
        },
    }

    expected_messages = [{
        'role': 'user',
        'content': 'What is the capital of France?',
    }, {
        'role': 'assistant',
        'content': 'Paris',
    }]
    output = QA_format_preprocessor(inp)
    assert len(output) == 1
    assert 'messages' in output

    # Without this length check the loop below would pass vacuously when the
    # preprocessor returns too few (or zero) messages.
    assert len(output['messages']) == len(expected_messages)

    for message, expected_message in zip(
        output['messages'],
        expected_messages,
    ):
        for k, v in message.items():
            assert k in expected_message
            assert v == expected_message[k]


def test_messages_format_preprocessor():
    """Check that only the ``messages`` entry survives preprocessing.

    An example holding a two-turn conversation plus an unrelated key is fed to
    ``messages_format_preprocessor``; the result must contain a single key,
    ``messages``, whose value equals the conversation that was passed in.
    """
    user_turn = {
        'role': 'user',
        'content': 'What is the capital of France?',
    }
    assistant_turn = {
        'role': 'assistant',
        'content': 'Paris',
    }
    conversation = [user_turn, assistant_turn]

    # The same list object is used both as input and as the comparison target.
    example = {
        'messages': conversation,
        'other_key': 'other_value',
    }

    result = messages_format_preprocessor(example)

    assert len(result) == 1
    assert 'messages' in result
    assert result['messages'] == conversation

0 comments on commit f4c2399

Please sign in to comment.