Skip to content

Commit

Permalink
Merge branch 'main' into milo/harbor-checkpointer
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress authored Dec 17, 2024
2 parents 4301ceb + 3269c73 commit bdcf051
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
28 changes: 28 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,3 +1161,31 @@ def shareGPT_format_preprocessor(inp: dict) -> ChatFormattedDict:
except Exception as e:
raise UnableToProcessPromptResponseError(inp) from e
return {'messages': messages}


@dataset_constructor.register('math-ai/StackMathQA')
def QA_format_preprocessor(inp: dict) -> ChatFormattedDict:
"""Convert from QA format to our chat format."""
try:
Q = inp['Q']
A = inp['A']
messages: list[dict[str, str]] = [{
'role': 'user',
'content': Q,
}, {
'role': 'assistant',
'content': A,
}]
except Exception as e:
raise UnableToProcessPromptResponseError(inp) from e
return {'messages': messages}


@dataset_constructor.register('AI-MO/NuminaMath-CoT')
def messages_format_preprocessor(inp: dict) -> ChatFormattedDict:
"""Convert from QA format to our chat format."""
try:
messages = inp['messages']
except Exception as e:
raise UnableToProcessPromptResponseError(inp) from e
return {'messages': messages}
47 changes: 47 additions & 0 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import pytest

from llmfoundry.data.finetuning.tasks import (
QA_format_preprocessor,
_get_num_processes,
dataset_constructor,
messages_format_preprocessor,
)
from llmfoundry.utils.exceptions import DatasetTooSmallError

Expand Down Expand Up @@ -60,3 +62,48 @@ def get_local_world_size(self):
new=MockDataset,
):
dataset_constructor.build_from_streaming()


def test_QA_format_preprocessor():
inp = {
'Q': 'What is the capital of France?',
'A': 'Paris',
'meta': {
'a': 'b',
},
}

expected_messages = [{
'role': 'user',
'content': 'What is the capital of France?',
}, {
'role': 'assistant',
'content': 'Paris',
}]
output = QA_format_preprocessor(inp)
assert len(output) == 1
assert 'messages' in output
for i, message in enumerate(output['messages']):
expected_message = expected_messages[i]
for k, v in message.items():
assert k in expected_message
assert v == expected_message[k]


def test_messages_format_preprocessor():
messages = [{
'role': 'user',
'content': 'What is the capital of France?',
}, {
'role': 'assistant',
'content': 'Paris',
}]
inp = {
'messages': messages,
'other_key': 'other_value',
}

output = messages_format_preprocessor(inp)
assert len(output) == 1
assert 'messages' in output
assert output['messages'] == messages

0 comments on commit bdcf051

Please sign in to comment.