diff --git a/prompting/__init__.py b/prompting/__init__.py
index 78e13264..8033bdbe 100644
--- a/prompting/__init__.py
+++ b/prompting/__init__.py
@@ -36,4 +36,29 @@
 from . import agent
 from . import conversation
 from . import dendrite
+
 from .llms import hf
+
+from .tasks import TASKS
+from .tools import DATASETS
+from .task_registry import TASK_REGISTRY
+
+# Assert that all tasks have a dataset, and all tasks/datasets are in the TASKS and DATASETS dictionaries.
+registry_missing_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys())
+registry_extra_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys())
+assert (
+    not registry_missing_task
+), f"Missing tasks in TASK_REGISTRY: {registry_missing_task}"
+assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}"
+
+registry_datasets = set(
+    [dataset for task, datasets in TASK_REGISTRY.items() for dataset in datasets]
+)
+registry_missing_dataset = registry_datasets - set(DATASETS.keys())
+registry_extra_dataset = set(DATASETS.keys()) - registry_datasets
+assert (
+    not registry_missing_dataset
+), f"Missing datasets in TASK_REGISTRY: {registry_missing_dataset}"
+assert (
+    not registry_extra_dataset
+), f"Extra datasets in TASK_REGISTRY: {registry_extra_dataset}"
diff --git a/prompting/conversation.py b/prompting/conversation.py
index 36256fd4..e17f5cd3 100644
--- a/prompting/conversation.py
+++ b/prompting/conversation.py
@@ -1,73 +1,50 @@
-from prompting.tasks import (
-    Task,
-    DebuggingTask,
-    QuestionAnsweringTask,
-    SummarizationTask,
-    MathTask,
-    DateQuestionAnsweringTask,
-)
-from prompting.tools import (
-    WikiDataset,
-    HFCodingDataset,
-    MathDataset,
-    WikiDateDataset,
-)
-
+import random
 from transformers import Pipeline
-
-
-def create_task(llm_pipeline: Pipeline, task_name: str, create_reference=True) -> Task:
-    wiki_based_tasks = ["summarization", "qa"]
-    coding_based_tasks = ["debugging"]
-    # TODO: Abstract dataset classes into common dynamic interface
-    if task_name in wiki_based_tasks:
-        dataset = WikiDataset()
-
-    elif task_name in coding_based_tasks:
-        dataset = HFCodingDataset()
-
-    elif task_name == "math":
-        dataset = MathDataset()
-
-    elif task_name == "date_qa":
-        dataset = WikiDateDataset()
-
-    if task_name == "summarization":
-        task = SummarizationTask(
-            llm_pipeline=llm_pipeline,
-            context=dataset.next(),
-            create_reference=create_reference,
-        )
-
-    elif task_name == "qa":
-        task = QuestionAnsweringTask(
-            llm_pipeline=llm_pipeline,
-            context=dataset.next(),
-            create_reference=create_reference,
-        )
-
-    elif task_name == "debugging":
-        task = DebuggingTask(
-            llm_pipeline=llm_pipeline,
-            context=dataset.next(),
-            create_reference=create_reference,
-        )
-
-    elif task_name == "math":
-        task = MathTask(
-            llm_pipeline=llm_pipeline,
-            context=dataset.next(),
-            create_reference=create_reference,
-        )
-
-    elif task_name == "date_qa":
-        task = DateQuestionAnsweringTask(
-            llm_pipeline=llm_pipeline,
-            context=dataset.next(),
-            create_reference=create_reference,
-        )
-
+from prompting.tasks import Task, TASKS
+from prompting.tools import Selector, DATASETS
+from prompting.task_registry import TASK_REGISTRY
+
+
+def create_task(
+    llm_pipeline: Pipeline,
+    task_name: str,
+    create_reference: bool = True,
+    selector: Selector = random.choice,
+) -> Task:
+    """Create a task from the given task name and LLM pipeline.
+
+    Args:
+        llm_pipeline (Pipeline): Pipeline to use for text generation
+        task_name (str): Name of the task to create
+        create_reference (bool, optional): Generate text for task reference answer upon creation. Defaults to True.
+        selector (Selector, optional): Selector function to choose a dataset. Defaults to random.choice.
+
+    Raises:
+        ValueError: If task_name is not a valid alias for a task, or if the task is not a subclass of Task
+        ValueError: If no datasets are available for the given task
+        ValueError: If the dataset for the given task is not found
+
+    Returns:
+        Task: Task instance
+    """
+
+    task = TASKS.get(task_name, None)
+    if task is None or not issubclass(task, Task):
+        raise ValueError(f"Task {task_name} not found")
+
+    dataset_choices = TASK_REGISTRY.get(task_name, None)
+    if not dataset_choices:
+        raise ValueError(f"No datasets available for task {task_name}")
+
+    dataset_name = selector(dataset_choices)
+    dataset = DATASETS.get(dataset_name, None)
+    if dataset is None:
+        raise ValueError(f"Dataset {dataset_name} not found")
     else:
-        raise ValueError(f"Task {task_name} not supported. Please choose a valid task")
+        dataset = dataset()
 
-    return task
+    return task(
+        llm_pipeline=llm_pipeline,
+        context=dataset.next(),
+        create_reference=create_reference,
+    )
\ No newline at end of file
diff --git a/prompting/task_registry.py b/prompting/task_registry.py
new file mode 100644
index 00000000..66236c90
--- /dev/null
+++ b/prompting/task_registry.py
@@ -0,0 +1,19 @@
+from .tasks import Task, MockTask, SummarizationTask, QuestionAnsweringTask, DebuggingTask, MathTask, DateQuestionAnsweringTask
+from .tools import MockDataset, WikiDataset, HFCodingDataset, StackOverflowDataset, MathDataset, WikiDateDataset
+
+# TODO: Expand this to include extra information beyond just the task and dataset names
+mock_task, mock_dataset = MockTask.name, [MockDataset.name]
+summarization_task, summarization_dataset = SummarizationTask.name, [WikiDataset.name]
+qa_task, qa_dataset = QuestionAnsweringTask.name, [WikiDataset.name]
+debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name]
+math_task, math_dataset = MathTask.name, [MathDataset.name]
+date_qa_task, date_qa_dataset = DateQuestionAnsweringTask.name, [WikiDateDataset.name]
+
+TASK_REGISTRY = {
+    mock_task: mock_dataset,
+    summarization_task: summarization_dataset,
+    qa_task: qa_dataset,
+    debugging_task: debugging_dataset,
+    math_task: math_dataset,
+    date_qa_task: date_qa_dataset,
+}
\ No newline at end of file
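Not part of the diff, but for reviewers: a minimal usage sketch of how the refactored create_task and the new TASK_REGISTRY fit together. It assumes the package imports cleanly; llm_pipeline=None is only viable here because the mock task added in this PR never touches the pipeline, and every other task needs a real transformers Pipeline.

import random

from prompting.conversation import create_task
from prompting.task_registry import TASK_REGISTRY

# Each registered task name maps to the list of dataset names it may draw a context from.
for task_name, dataset_names in TASK_REGISTRY.items():
    print(f"{task_name} -> {dataset_names}")

# The selector receives the task's dataset-name list and picks one (random.choice by default).
# "mock" ignores the pipeline, so None stands in for a real Pipeline in this sketch only.
task = create_task(llm_pipeline=None, task_name="mock", selector=random.choice)
print(task.query)
print(task.reference)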
diff --git a/prompting/tasks/__init__.py b/prompting/tasks/__init__.py
index d913935c..fa8115a3 100644
--- a/prompting/tasks/__init__.py
+++ b/prompting/tasks/__init__.py
@@ -5,12 +5,15 @@
 from .date_qa import DateQuestionAnsweringTask
 from .generic_instruction import GenericInstructionTask
 from .math import MathTask
+from .mock import MockTask
 
 TASKS = {
-    "qa": QuestionAnsweringTask,
-    "summarization": SummarizationTask,
-    "date_qa": DateQuestionAnsweringTask,
-    "debugging": DebuggingTask,
-    "math": MathTask,
+    MockTask.name: MockTask,
+    QuestionAnsweringTask.name: QuestionAnsweringTask,
+    DateQuestionAnsweringTask.name: DateQuestionAnsweringTask,
+    SummarizationTask.name: SummarizationTask,
+    DebuggingTask.name: DebuggingTask,
+    #GenericInstructionTask.name: GenericInstructionTask,
+    MathTask.name: MathTask,
 }
 
diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py
index 0c70e2de..0f13015e 100644
--- a/prompting/tasks/date_qa.py
+++ b/prompting/tasks/date_qa.py
@@ -8,7 +8,7 @@
 
 @dataclass
 class DateQuestionAnsweringTask(Task):
-    name = "date-based question answering"
+    name = "date_qa"
     desc = "get help answering a specific date-based question"
     goal = "to get the answer to the following date-based question"
     reward_definition = [
diff --git a/prompting/tasks/mock.py b/prompting/tasks/mock.py
new file mode 100644
index 00000000..668827fb
--- /dev/null
+++ b/prompting/tasks/mock.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+from prompting.tasks import Task
+
+@dataclass
+class MockTask(Task):
+    name = "mock"
+    desc = "get help solving a math problem"
+    goal = "to get the answer to the following math question"
+
+    reward_definition = [
+        dict(name="float_diff", weight=1.0),
+    ]
+    penalty_definition = []
+
+    static_reference = True
+    static_query = True
+
+    def __init__(self, llm_pipeline, context, create_reference=True):
+        self.context = context
+
+        self.query = (
+            "How can I solve the following problem, "
+            + context.content
+            + "?"
+        )
+        self.reference = "This is the reference answer"
+        self.topic = context.title
+        self.subtopic = context.topic
+        self.tags = context.tags
\ No newline at end of file
diff --git a/prompting/tasks/qa.py b/prompting/tasks/qa.py
index c2192c5c..41430f5d 100644
--- a/prompting/tasks/qa.py
+++ b/prompting/tasks/qa.py
@@ -40,7 +40,7 @@
 
 @dataclass
 class QuestionAnsweringTask(Task):
-    name = "question-answering"
+    name = "qa"
 
     desc = "get help on answering a question"
     goal = "to get the answer to the following question"
diff --git a/prompting/tools/__init__.py b/prompting/tools/__init__.py
index a8e395c2..6d7e5b41 100644
--- a/prompting/tools/__init__.py
+++ b/prompting/tools/__init__.py
@@ -9,3 +9,15 @@
     MathDataset,
 )
 from .selector import Selector
+
+DATASETS = {
+    "mock": MockDataset,
+    "hf_coding": HFCodingDataset,
+    "wiki": WikiDataset,
+    #"stack_overflow": StackOverflowDataset,
+    "wiki_date": WikiDateDataset,
+    "math": MathDataset,
+}
+
+
+
\ No newline at end of file
diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py
index ab4b6d1c..55cce07f 100644
--- a/prompting/tools/datasets/base.py
+++ b/prompting/tools/datasets/base.py
@@ -28,7 +28,7 @@
 
 class Dataset(ABC):
     """Base class for datasets."""
-
+    name = "dataset"
     max_tries: int = 10
 
     @abstractmethod
diff --git a/prompting/tools/datasets/code.py b/prompting/tools/datasets/code.py
index ec086026..bdc2948b 100644
--- a/prompting/tools/datasets/code.py
+++ b/prompting/tools/datasets/code.py
@@ -523,6 +523,7 @@ def filter_comments(code, language):
 
 # TODO: why not define the chain_in, chain_out logic in the class itself?
 class HFCodingDataset(Dataset):
+    name = "hf_coding"
     def __init__(
         self,
         dataset_id="codeparrot/github-code",
@@ -615,6 +616,7 @@ def get_special_contents(self, code, language, remove_comments=True):
 
 
 class StackOverflowDataset:
+    name = "stack_overflow"
     def __init__(self):
         # Stack Overflow API endpoint for a random article
         self.url = "https://api.stackexchange.com/2.3/questions"
diff --git a/prompting/tools/datasets/math.py b/prompting/tools/datasets/math.py
index af3096d4..e4eae92e 100644
--- a/prompting/tools/datasets/math.py
+++ b/prompting/tools/datasets/math.py
@@ -30,6 +30,7 @@
 
 
 class MathDataset(Dataset):
+    name = 'math'
     topics_list = mathgenerator.getGenList()
 
     def __init__(self, seed=None):
@@ -56,36 +57,38 @@ def get(
             Dict: _description_
         """
         bt.logging.info(f"Getting math problem {name!r}")
-        info = mathgenerator.generate_context(name, **kwargs)
-        if info["reward_type"] != "float" or info["topic"] == "computer_science":
-            return None
-
-        math_words = [
-            "math",
-            "mathematics",
-            "mathematical",
-            "math problem",
-            "math technique",
-        ]
-        external_links = []
-        # construct external links from randomly shuffled trigrams containing 2 words from the problem and 1 random math word
-        # binary_to_decimal -> ['binary to', 'to decimal']
-        for bigram in itertools.combinations(info["forward_words"], 2):
-            words = list(bigram) + [random.choice(math_words)]
-            # shuffle the words e.g. ['binary', 'decimal', 'math problem'] -> 'decimal binary math problem'
-            external_links.append(" ".join(random.sample(words, len(words))))
-
-        return {
-            "title": info["topic"],  # title of math problem
-            "topic": info["topic"],  # title of problem topic
-            "subtopic": info["subtopic"],  # title of problem subtopic
-            "content": info["problem"],  # problem statement
-            "internal_links": [info["topic"], info["subtopic"]],  # internal links
-            "external_links": external_links,
-            "tags": info["forward_words"],
-            "source": "Mathgenerator",
-            "extra": {"reward_type": info["reward_type"], "solution": info["solution"]},
-        }
+        max_tries = 10
+        for _ in range(max_tries):
+            info = mathgenerator.generate_context(name, **kwargs)
+            if info["reward_type"] != "float" or info["topic"] == "computer_science":
+                pass
+            else:
+                math_words = [
+                    "math",
+                    "mathematics",
+                    "mathematical",
+                    "math problem",
+                    "math technique",
+                ]
+                external_links = []
+                # construct external links from randomly shuffled trigrams containing 2 words from the problem and 1 random math word
+                # binary_to_decimal -> ['binary to', 'to decimal']
+                for bigram in itertools.combinations(info["forward_words"], 2):
+                    words = list(bigram) + [random.choice(math_words)]
+                    # shuffle the words e.g. ['binary', 'decimal', 'math problem'] -> 'decimal binary math problem'
+                    external_links.append(" ".join(random.sample(words, len(words))))
+
+                return {
+                    "title": info["topic"],  # title of math problem
+                    "topic": info["topic"],  # title of problem topic
+                    "subtopic": info["subtopic"],  # title of problem subtopic
+                    "content": info["problem"],  # problem statement
+                    "internal_links": [info["topic"], info["subtopic"]],  # internal links
+                    "external_links": external_links,
+                    "tags": info["forward_words"],
+                    "source": "Mathgenerator",
+                    "extra": {"reward_type": info["reward_type"], "solution": info["solution"]},
+                }
 
     def search(
         self, name, selector: Selector, include: List = None, exclude: List = None
diff --git a/prompting/tools/datasets/mock.py b/prompting/tools/datasets/mock.py
index 89e4d501..54008269 100644
--- a/prompting/tools/datasets/mock.py
+++ b/prompting/tools/datasets/mock.py
@@ -4,6 +4,7 @@
 
 
 class MockDataset(Dataset):
+    name = "mock"
     def get(self, name, exclude=None, selector=None):
         return {
             "title": name,
diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py
index dd516fd6..6da057f6 100644
--- a/prompting/tools/datasets/wiki.py
+++ b/prompting/tools/datasets/wiki.py
@@ -151,7 +151,7 @@ def filter_categories(categories, exclude=None, include=None):
 
 
 class WikiDataset(Dataset):
     """Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections."""
-
+    name = "wiki"
     EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links")
     EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1")
@@ -240,6 +240,7 @@ def random(self, pages=10, seed=None, selector: Selector = None, **kwargs) -> Di
 
 
 class WikiDateDataset(Dataset):
+    name = "wiki_date"
     INCLUDE_HEADERS = ("Events", "Births", "Deaths")
     MONTHS = (
         "January",
diff --git a/prompting/utils/config.py b/prompting/utils/config.py
index ae27921d..f21b89a9 100644
--- a/prompting/utils/config.py
+++ b/prompting/utils/config.py
@@ -21,6 +21,7 @@
 import argparse
 import bittensor as bt
 from loguru import logger
+from prompting.tasks import TASKS
 
 
 def check_config(cls, config: "bt.Config"):
@@ -277,7 +278,7 @@ def add_validator_args(cls, parser):
         type=str,
         nargs="+",
         help="The tasks to use for the validator.",
-        default=["summarization", "qa", "debugging", "math", "date_qa"],
+        default=list(TASKS.keys())[1:],
     )
 
     parser.add_argument(
@@ -285,7 +286,7 @@
         type=float,
         nargs="+",
         help="The probability of sampling each task.",
-        default=[.25, .25, 0, .25, .25],
+        default=[.25, .25, .25, 0, .25],
     )
 
     parser.add_argument(
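A quick illustration (not part of the diff) of what the new validator defaults in config.py evaluate to, assuming the task classes not touched by this PR keep their existing short names. TASKS is an insertion-ordered dict with MockTask registered first, so slicing off the first key leaves the tasks a validator actually runs:

from prompting.tasks import TASKS

# MockTask is registered first in prompting/tasks/__init__.py, so dropping it gives the
# default task list for the validator: the qa, date_qa, summarization, debugging and math tasks.
default_tasks = list(TASKS.keys())[1:]
print(default_tasks)

# The reordered default probabilities [.25, .25, .25, 0, .25] line up with that list,
# keeping the debugging task at zero weight just as the old hard-coded default did.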
diff --git a/tests/test_registry.py b/tests/test_registry.py
new file mode 100644
index 00000000..49c72fd2
--- /dev/null
+++ b/tests/test_registry.py
@@ -0,0 +1,26 @@
+from prompting.tasks import TASKS
+from prompting.tools import DATASETS
+from prompting.task_registry import TASK_REGISTRY
+
+
+# TODO: Create more detailed tests.
+def test_task_registry():
+    registry_missing_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys())
+    registry_extra_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys())
+    assert (
+        not registry_missing_task
+    ), f"Missing tasks in TASK_REGISTRY: {registry_missing_task}"
+    assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}"
+
+def test_task_registry_datasets():
+    registry_datasets = set(
+        [dataset for task, datasets in TASK_REGISTRY.items() for dataset in datasets]
+    )
+    registry_missing_dataset = registry_datasets - set(DATASETS.keys())
+    registry_extra_dataset = set(DATASETS.keys()) - registry_datasets
+    assert (
+        not registry_missing_dataset
+    ), f"Missing datasets in TASK_REGISTRY: {registry_missing_dataset}"
+    assert (
+        not registry_extra_dataset
+    ), f"Extra datasets in TASK_REGISTRY: {registry_extra_dataset}"
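One closing sketch (not part of the diff): the registry invariants are now enforced both at import time in prompting/__init__.py and by tests/test_registry.py, which can be run with pytest. The snippet below restates the invariant in one place and shows the failure mode it guards against; the helper name is illustrative only.

from prompting.tasks import TASKS
from prompting.tools import DATASETS
from prompting.task_registry import TASK_REGISTRY


def registry_is_consistent(tasks: dict, datasets: dict, registry: dict) -> bool:
    """Same invariants as the new asserts and tests: task sets match and every dataset is used."""
    used_datasets = {name for names in registry.values() for name in names}
    return set(tasks) == set(registry) and used_datasets == set(datasets)


assert registry_is_consistent(TASKS, DATASETS, TASK_REGISTRY)
# A task added to TASKS without a TASK_REGISTRY entry (or a dataset no task uses) breaks it:
assert not registry_is_consistent({**TASKS, "new_task": None}, DATASETS, TASK_REGISTRY)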