From c8404cd59aa58a200f7119203839cbbc145a9143 Mon Sep 17 00:00:00 2001 From: Steffen Cruz Date: Sun, 17 Mar 2024 16:20:04 -0600 Subject: [PATCH 01/18] Add dataset registry --- prompting/tools/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/prompting/tools/__init__.py b/prompting/tools/__init__.py index a8e395c2..6b2911f6 100644 --- a/prompting/tools/__init__.py +++ b/prompting/tools/__init__.py @@ -9,3 +9,12 @@ MathDataset, ) from .selector import Selector + +DATASETS = { + "mock": MockDataset, + "hf_coding": HFCodingDataset, + "wiki": WikiDataset, + "stack_overflow": StackOverflowDataset, + "wiki_date": WikiDateDataset, + "math": MathDataset, +} From d285e083d397b5885df5795adda084abac5c6a50 Mon Sep 17 00:00:00 2001 From: Steffen Cruz Date: Sun, 17 Mar 2024 16:20:24 -0600 Subject: [PATCH 02/18] Add task-dataset registry, which will enable multi-dataset tasks --- prompting/__init__.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/prompting/__init__.py b/prompting/__init__.py index 585b4160..3d4f34d8 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -37,3 +37,34 @@ from . import conversation from . import dendrite from . import llm + +from tasks import TASKS +from tools import DATASETS + +# TODO: Expand this to include extra information beyond just the task and dataset names +TASK_REGISTRY = { + "summarization": ["wiki"], + "qa": ["wiki"], + "debugging": ["hf_coding"], + "math": ["math"], + "date_qa": ["wiki_date"], +} +# Assert that all tasks have a dataset, and all tasks/datasets are in the TASKS and DATASETS dictionaries. +registry_missing_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys()) +registry_extra_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys()) +assert ( + not registry_missing_task +), f"Missing tasks in TASK_REGISTRY: {registry_missing_task}" +assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}" + +registry_datasets = set( + [dataset for task, datasets in TASK_REGISTRY.items() for dataset in datasets] +) +registry_missing_dataset = registry_datasets - set(DATASETS.keys()) +registry_extra_dataset = set(DATASETS.keys()) - registry_datasets +assert ( + not registry_missing_dataset +), f"Missing datasets in TASK_REGISTRY: {registry_missing_dataset}" +assert ( + not registry_extra_dataset +), f"Extra datasets in TASK_REGISTRY: {registry_extra_dataset}" From 21beebc9371e45ac3bd43b1fccfafdd983b8dbbd Mon Sep 17 00:00:00 2001 From: Steffen Cruz Date: Sun, 17 Mar 2024 16:20:37 -0600 Subject: [PATCH 03/18] Use registry for task creation --- prompting/conversation.py | 98 +++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 56 deletions(-) diff --git a/prompting/conversation.py b/prompting/conversation.py index 263d65db..a90cffa7 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -1,57 +1,43 @@ -from prompting.tasks import ( - Task, - DebuggingTask, - QuestionAnsweringTask, - SummarizationTask, - MathTask, - DateQuestionAnsweringTask, -) -from prompting.tools import ( - WikiDataset, - HFCodingDataset, - MathDataset, - WikiDateDataset, -) - +import random from transformers import Pipeline - - -def create_task(llm_pipeline: Pipeline, task_name: str) -> Task: - wiki_based_tasks = ["summarization", "qa"] - coding_based_tasks = ["debugging"] - # TODO Add math and date_qa to this structure - - # TODO: Abstract dataset classes into common dynamic interface - if task_name in wiki_based_tasks: - dataset = WikiDataset() - - elif task_name in coding_based_tasks: - dataset = HFCodingDataset() - - elif task_name == "math": - dataset = MathDataset() - - elif task_name == "date_qa": - dataset = WikiDateDataset() - - if task_name == "summarization": - task = SummarizationTask(llm_pipeline=llm_pipeline, context=dataset.next()) - - elif task_name == "qa": - task = QuestionAnsweringTask(llm_pipeline=llm_pipeline, context=dataset.next()) - - elif task_name == "debugging": - task = DebuggingTask(llm_pipeline=llm_pipeline, context=dataset.next()) - - elif task_name == "math": - task = MathTask(llm_pipeline=llm_pipeline, context=dataset.next()) - - elif task_name == "date_qa": - task = DateQuestionAnsweringTask( - llm_pipeline=llm_pipeline, context=dataset.next() - ) - - else: - raise ValueError(f"Task {task_name} not supported. Please choose a valid task") - - return task +from prompting import TASK_REGISTRY, TASKS, DATASETS +from prompting.tasks import Task + + +def create_task( + llm_pipeline: Pipeline, task_name: str, create_reference: bool = True +) -> Task: + """Create a task from the given task name and LLM pipeline. + + Args: + llm_pipeline (Pipeline): Pipeline to use for text generation + task_name (str): Name of the task to create + create_reference (bool, optional): Generate text for task reference answer upon creation. Defaults to True. + + Raises: + ValueError: If task_name is not a valid alias for a task, or if the task is not a subclass of Task + ValueError: If no datasets are available for the given task + ValueError: If the dataset for the given task is not found + + Returns: + Task: Task instance + """ + + task = TASKS.get(task_name, None) + if task is None or not issubclass(task, Task): + raise ValueError(f"Task {task_name} not found") + + dataset_choices = TASK_REGISTRY.get(task_name, []) + if len(dataset_choices) == 0: + raise ValueError(f"No datasets available for task {task_name}") + + dataset_name = random.choice(dataset_choices) + dataset = DATASETS.get(dataset_name, None) + if dataset is None: + raise ValueError(f"Dataset {dataset_name} not found") + + return task( + llm_pipeline=llm_pipeline, + context=dataset.next(), + create_reference=create_reference, + ) From 58bdbd849260fe86f8b4f7457d1781d174dd2e58 Mon Sep 17 00:00:00 2001 From: Steffen Cruz Date: Sun, 17 Mar 2024 16:24:57 -0600 Subject: [PATCH 04/18] Add arbitrary selector for improved control --- prompting/conversation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/prompting/conversation.py b/prompting/conversation.py index a90cffa7..63c02f9d 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -2,10 +2,14 @@ from transformers import Pipeline from prompting import TASK_REGISTRY, TASKS, DATASETS from prompting.tasks import Task +from prompting.tools import Selector def create_task( - llm_pipeline: Pipeline, task_name: str, create_reference: bool = True + llm_pipeline: Pipeline, + task_name: str, + create_reference: bool = True, + selector: Selector = random.choice, ) -> Task: """Create a task from the given task name and LLM pipeline. @@ -13,6 +17,7 @@ def create_task( llm_pipeline (Pipeline): Pipeline to use for text generation task_name (str): Name of the task to create create_reference (bool, optional): Generate text for task reference answer upon creation. Defaults to True. + selector (Selector, optional): Selector function to choose a dataset. Defaults to random.choice. Raises: ValueError: If task_name is not a valid alias for a task, or if the task is not a subclass of Task @@ -31,7 +36,7 @@ def create_task( if len(dataset_choices) == 0: raise ValueError(f"No datasets available for task {task_name}") - dataset_name = random.choice(dataset_choices) + dataset_name = selector(dataset_choices) dataset = DATASETS.get(dataset_name, None) if dataset is None: raise ValueError(f"Dataset {dataset_name} not found") From 35878b05c02cc9d41b7fe90e46ea928929a6f103 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 8 Apr 2024 15:08:13 +0000 Subject: [PATCH 05/18] Resolve circular imports --- prompting/conversation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/prompting/conversation.py b/prompting/conversation.py index de9d321e..00218c0c 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -1,8 +1,7 @@ import random from transformers import Pipeline -from prompting import TASK_REGISTRY, TASKS, DATASETS -from prompting.tasks import Task -from prompting.tools import Selector +from prompting.tasks import Task, TASK_REGISTRY, TASKS +from prompting.tools import Selector, DATASETS def create_task( From 35a53b1d01a4b99ba4088ce7db28f0218451201e Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 8 Apr 2024 15:14:28 +0000 Subject: [PATCH 06/18] Resolve circular imports in prompting/conversation.py --- prompting/conversation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/prompting/conversation.py b/prompting/conversation.py index 00218c0c..b3643c8e 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -1,7 +1,8 @@ import random from transformers import Pipeline -from prompting.tasks import Task, TASK_REGISTRY, TASKS +from prompting.tasks import Task, TASKS from prompting.tools import Selector, DATASETS +from prompting import TASK_REGISTRY def create_task( From df1885dbd6859864bb095d25f2c2114b1eae9370 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 8 Apr 2024 15:26:14 +0000 Subject: [PATCH 07/18] Create separate task_registry file --- prompting/__init__.py | 9 +-------- prompting/conversation.py | 2 +- prompting/task_registry.py | 8 ++++++++ 3 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 prompting/task_registry.py diff --git a/prompting/__init__.py b/prompting/__init__.py index f51e87ac..38568119 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -41,15 +41,8 @@ from tasks import TASKS from tools import DATASETS +from task_registry import TASK_REGISTRY -# TODO: Expand this to include extra information beyond just the task and dataset names -TASK_REGISTRY = { - "summarization": ["wiki"], - "qa": ["wiki"], - "debugging": ["hf_coding"], - "math": ["math"], - "date_qa": ["wiki_date"], -} # Assert that all tasks have a dataset, and all tasks/datasets are in the TASKS and DATASETS dictionaries. registry_missing_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys()) registry_extra_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys()) diff --git a/prompting/conversation.py b/prompting/conversation.py index b3643c8e..aad84ad7 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -2,7 +2,7 @@ from transformers import Pipeline from prompting.tasks import Task, TASKS from prompting.tools import Selector, DATASETS -from prompting import TASK_REGISTRY +from prompting.task_registry import TASK_REGISTRY def create_task( diff --git a/prompting/task_registry.py b/prompting/task_registry.py new file mode 100644 index 00000000..5525b021 --- /dev/null +++ b/prompting/task_registry.py @@ -0,0 +1,8 @@ +# TODO: Expand this to include extra information beyond just the task and dataset names +TASK_REGISTRY = { + "summarization": ["wiki"], + "qa": ["wiki"], + "debugging": ["hf_coding"], + "math": ["math"], + "date_qa": ["wiki_date"], +} \ No newline at end of file From dfb65a2328788b7e35aadab9acb2d14c0f03aa12 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 8 Apr 2024 15:39:09 +0000 Subject: [PATCH 08/18] Remove llm import from init --- prompting/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/__init__.py b/prompting/__init__.py index 38568119..f56cb878 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -36,7 +36,7 @@ from . import agent from . import conversation from . import dendrite -from . import llm + from .llms import hf from tasks import TASKS From cd9f5926784599f20aaaea90d1285249cd135863 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 8 Apr 2024 15:46:38 +0000 Subject: [PATCH 09/18] Import TASKS and DATASETS --- prompting/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prompting/__init__.py b/prompting/__init__.py index f56cb878..74cbd311 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -39,9 +39,9 @@ from .llms import hf -from tasks import TASKS -from tools import DATASETS -from task_registry import TASK_REGISTRY +from .tasks import TASKS +from .tools import DATASETS +from .task_registry import TASK_REGISTRY # Assert that all tasks have a dataset, and all tasks/datasets are in the TASKS and DATASETS dictionaries. registry_missing_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys()) From dc00d3b82dee16928b8802b0ebaf1db4c25b0010 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 8 Apr 2024 15:58:39 +0000 Subject: [PATCH 10/18] Add Mock Task to Registry --- prompting/task_registry.py | 1 + prompting/tasks/__init__.py | 1 + prompting/tools/__init__.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/prompting/task_registry.py b/prompting/task_registry.py index 5525b021..187985f8 100644 --- a/prompting/task_registry.py +++ b/prompting/task_registry.py @@ -1,5 +1,6 @@ # TODO: Expand this to include extra information beyond just the task and dataset names TASK_REGISTRY = { + "mock": ["mock"], "summarization": ["wiki"], "qa": ["wiki"], "debugging": ["hf_coding"], diff --git a/prompting/tasks/__init__.py b/prompting/tasks/__init__.py index d913935c..eded446d 100644 --- a/prompting/tasks/__init__.py +++ b/prompting/tasks/__init__.py @@ -8,6 +8,7 @@ TASKS = { + "mock": Task, "qa": QuestionAnsweringTask, "summarization": SummarizationTask, "date_qa": DateQuestionAnsweringTask, diff --git a/prompting/tools/__init__.py b/prompting/tools/__init__.py index 6b2911f6..d88c6a7e 100644 --- a/prompting/tools/__init__.py +++ b/prompting/tools/__init__.py @@ -14,7 +14,7 @@ "mock": MockDataset, "hf_coding": HFCodingDataset, "wiki": WikiDataset, - "stack_overflow": StackOverflowDataset, + #"stack_overflow": StackOverflowDataset, "wiki_date": WikiDateDataset, "math": MathDataset, } From a1eb07bf95d31061b35eec398cdd680187ffa723 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 8 Apr 2024 16:13:31 +0000 Subject: [PATCH 11/18] Add registry unit tests --- tests/test_registry.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_registry.py diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 00000000..7cfe6f30 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,26 @@ +from prompting.tasks import TASKS +from prompting.tools import DATASETS +from prompting.task_registry import TASK_REGISTRY + + +# TODO: Improve more detailed tasks. +def test_task_registry(): + registry_missing_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys()) + registry_extra_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys()) + assert ( + not registry_missing_task + ), f"Missing tasks in TASK_REGISTRY: {registry_missing_task}" + assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}" + +def test_task_registry_datasets(): + registry_datasets = set( + [dataset for task, datasets in TASK_REGISTRY.items() for dataset in datasets] + ) + registry_missing_dataset = registry_datasets - set(DATASETS.keys()) + registry_extra_dataset = set(DATASETS.keys()) - registry_datasets + assert ( + not registry_missing_dataset + ), f"Missing datasets in TASK_REGISTRY: {registry_missing_dataset}" + assert ( + not registry_extra_dataset + ), f"Extra datasets in TASK_REGISTRY: {registry_extra_dataset}" \ No newline at end of file From b807cccf51ddfd44265cb8d30d3adfa9ee52d3be Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Tue, 9 Apr 2024 14:55:12 +0000 Subject: [PATCH 12/18] Instantiate the datasets before task creation --- prompting/conversation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/prompting/conversation.py b/prompting/conversation.py index aad84ad7..5c8f830d 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -40,6 +40,8 @@ def create_task( dataset = DATASETS.get(dataset_name, None) if dataset is None: raise ValueError(f"Dataset {dataset_name} not found") + else: + dataset = dataset() return task( llm_pipeline=llm_pipeline, From 6f2dda4c523b32fb0bce2394f64b89832927e29f Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Wed, 10 Apr 2024 14:04:10 +0000 Subject: [PATCH 13/18] Use class.name rather than hardcoded strings --- prompting/__init__.py | 8 +++++--- prompting/task_registry.py | 22 ++++++++++++++++------ prompting/tasks/__init__.py | 7 ++++--- prompting/tasks/mock.py | 29 +++++++++++++++++++++++++++++ prompting/tools/__init__.py | 3 +++ prompting/tools/datasets/base.py | 2 +- prompting/tools/datasets/code.py | 2 ++ prompting/tools/datasets/math.py | 1 + prompting/tools/datasets/mock.py | 1 + prompting/tools/datasets/wiki.py | 3 ++- 10 files changed, 64 insertions(+), 14 deletions(-) create mode 100644 prompting/tasks/mock.py diff --git a/prompting/__init__.py b/prompting/__init__.py index 74cbd311..a72ff6ef 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -44,16 +44,18 @@ from .task_registry import TASK_REGISTRY # Assert that all tasks have a dataset, and all tasks/datasets are in the TASKS and DATASETS dictionaries. -registry_missing_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys()) -registry_extra_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys()) +registry_missing_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys()) +registry_extra_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys()) assert ( not registry_missing_task ), f"Missing tasks in TASK_REGISTRY: {registry_missing_task}" assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}" registry_datasets = set( - [dataset for task, datasets in TASK_REGISTRY.items() for dataset in datasets] + [datasets for task, datasets in TASK_REGISTRY.items()] ) +print(registry_datasets) +print(set(DATASETS.keys())) registry_missing_dataset = registry_datasets - set(DATASETS.keys()) registry_extra_dataset = set(DATASETS.keys()) - registry_datasets assert ( diff --git a/prompting/task_registry.py b/prompting/task_registry.py index 187985f8..2a122d89 100644 --- a/prompting/task_registry.py +++ b/prompting/task_registry.py @@ -1,9 +1,19 @@ +from .tasks import Task, MockTask, SummarizationTask, QuestionAnsweringTask, DebuggingTask, MathTask, DateQuestionAnsweringTask +from .tools import MockDataset, WikiDataset, HFCodingDataset, StackOverflowDataset, MathDataset, WikiDateDataset + # TODO: Expand this to include extra information beyond just the task and dataset names +mock_task, mock_dataset = MockTask.name, MockDataset.name +summarization_task, summarization_dataset = SummarizationTask.name, WikiDataset.name +qa_task, qa_dataset = QuestionAnsweringTask.name, WikiDataset.name +debugging_task, debugging_dataset = DebuggingTask.name, HFCodingDataset.name +math_task, math_dataset = MathTask.name, MathDataset.name +date_qa_task, date_qa_dataset = DateQuestionAnsweringTask.name, WikiDateDataset.name + TASK_REGISTRY = { - "mock": ["mock"], - "summarization": ["wiki"], - "qa": ["wiki"], - "debugging": ["hf_coding"], - "math": ["math"], - "date_qa": ["wiki_date"], + mock_task: mock_dataset, + summarization_task: summarization_dataset, + qa_task: qa_dataset, + debugging_task: debugging_dataset, + math_task: math_dataset, + date_qa_task: date_qa_dataset, } \ No newline at end of file diff --git a/prompting/tasks/__init__.py b/prompting/tasks/__init__.py index eded446d..2f24307f 100644 --- a/prompting/tasks/__init__.py +++ b/prompting/tasks/__init__.py @@ -5,13 +5,14 @@ from .date_qa import DateQuestionAnsweringTask from .generic_instruction import GenericInstructionTask from .math import MathTask +from .mock import MockTask TASKS = { - "mock": Task, - "qa": QuestionAnsweringTask, + "mock": MockTask, + "question-answering": QuestionAnsweringTask, "summarization": SummarizationTask, - "date_qa": DateQuestionAnsweringTask, + "date-based question answering": DateQuestionAnsweringTask, "debugging": DebuggingTask, "math": MathTask, } diff --git a/prompting/tasks/mock.py b/prompting/tasks/mock.py new file mode 100644 index 00000000..668827fb --- /dev/null +++ b/prompting/tasks/mock.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from prompting.tasks import Task + +@dataclass +class MockTask(Task): + name = "mock" + desc = "get help solving a math problem" + goal = "to get the answer to the following math question" + + reward_definition = [ + dict(name="float_diff", weight=1.0), + ] + penalty_definition = [] + + static_reference = True + static_query = True + + def __init__(self, llm_pipeline, context, create_reference=True): + self.context = context + + self.query = ( + "How can I solve the following problem, " + + context.content + + "?" + ) + self.reference = "This is the reference answer" + self.topic = context.title + self.subtopic = context.topic + self.tags = context.tags \ No newline at end of file diff --git a/prompting/tools/__init__.py b/prompting/tools/__init__.py index d88c6a7e..6d7e5b41 100644 --- a/prompting/tools/__init__.py +++ b/prompting/tools/__init__.py @@ -18,3 +18,6 @@ "wiki_date": WikiDateDataset, "math": MathDataset, } + + + \ No newline at end of file diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index ab4b6d1c..55cce07f 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -28,7 +28,7 @@ class Dataset(ABC): """Base class for datasets.""" - + name = "dataset" max_tries: int = 10 @abstractmethod diff --git a/prompting/tools/datasets/code.py b/prompting/tools/datasets/code.py index ec086026..bdc2948b 100644 --- a/prompting/tools/datasets/code.py +++ b/prompting/tools/datasets/code.py @@ -523,6 +523,7 @@ def filter_comments(code, language): # TODO: why not define the chain_in, chain_out logic in the class itself? class HFCodingDataset(Dataset): + name = "hf_coding" def __init__( self, dataset_id="codeparrot/github-code", @@ -615,6 +616,7 @@ def get_special_contents(self, code, language, remove_comments=True): class StackOverflowDataset: + name = "stack_overflow" def __init__(self): # Stack Overflow API endpoint for a random article self.url = "https://api.stackexchange.com/2.3/questions" diff --git a/prompting/tools/datasets/math.py b/prompting/tools/datasets/math.py index af3096d4..ca2513a3 100644 --- a/prompting/tools/datasets/math.py +++ b/prompting/tools/datasets/math.py @@ -30,6 +30,7 @@ class MathDataset(Dataset): + name = 'math' topics_list = mathgenerator.getGenList() def __init__(self, seed=None): diff --git a/prompting/tools/datasets/mock.py b/prompting/tools/datasets/mock.py index 89e4d501..54008269 100644 --- a/prompting/tools/datasets/mock.py +++ b/prompting/tools/datasets/mock.py @@ -4,6 +4,7 @@ class MockDataset(Dataset): + name = "mock" def get(self, name, exclude=None, selector=None): return { "title": name, diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index dd516fd6..6da057f6 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -151,7 +151,7 @@ def filter_categories(categories, exclude=None, include=None): class WikiDataset(Dataset): """Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections.""" - + name = "wiki" EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links") EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1") @@ -240,6 +240,7 @@ def random(self, pages=10, seed=None, selector: Selector = None, **kwargs) -> Di class WikiDateDataset(Dataset): + name = "wiki_date" INCLUDE_HEADERS = ("Events", "Births", "Deaths") MONTHS = ( "January", From 1fb74fe4977ad28dd812414433642339974efb91 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Wed, 10 Apr 2024 17:22:41 +0000 Subject: [PATCH 14/18] Rename date_qa --- prompting/tasks/date_qa.py | 2 +- prompting/utils/config.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index 0c70e2de..0f13015e 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -8,7 +8,7 @@ @dataclass class DateQuestionAnsweringTask(Task): - name = "date-based question answering" + name = "date_qa" desc = "get help answering a specific date-based question" goal = "to get the answer to the following date-based question" reward_definition = [ diff --git a/prompting/utils/config.py b/prompting/utils/config.py index ae27921d..8de754af 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -21,6 +21,7 @@ import argparse import bittensor as bt from loguru import logger +from prompting.task_registry import TASK_REGISTRY def check_config(cls, config: "bt.Config"): @@ -277,7 +278,7 @@ def add_validator_args(cls, parser): type=str, nargs="+", help="The tasks to use for the validator.", - default=["summarization", "qa", "debugging", "math", "date_qa"], + default=TASK_REGISTRY.keys(), ) parser.add_argument( From 8cb53bd3997e094408027f3dfe0762ad6338b425 Mon Sep 17 00:00:00 2001 From: steffencruz Date: Wed, 10 Apr 2024 17:54:22 +0000 Subject: [PATCH 15/18] Update Registry --- prompting/__init__.py | 2 +- prompting/conversation.py | 2 +- prompting/task_registry.py | 12 ++++++------ prompting/tasks/__init__.py | 13 +++++++------ prompting/utils/config.py | 6 +++--- 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/prompting/__init__.py b/prompting/__init__.py index a72ff6ef..9d7aee53 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -52,7 +52,7 @@ assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}" registry_datasets = set( - [datasets for task, datasets in TASK_REGISTRY.items()] + [dataset for task, datasets in TASK_REGISTRY.items() for dataset in datasets] ) print(registry_datasets) print(set(DATASETS.keys())) diff --git a/prompting/conversation.py b/prompting/conversation.py index 5c8f830d..e17f5cd3 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -32,7 +32,7 @@ def create_task( if task is None or not issubclass(task, Task): raise ValueError(f"Task {task_name} not found") - dataset_choices = TASK_REGISTRY.get(task_name, []) + dataset_choices = TASK_REGISTRY.get(task_name, None) if len(dataset_choices) == 0: raise ValueError(f"No datasets available for task {task_name}") diff --git a/prompting/task_registry.py b/prompting/task_registry.py index 2a122d89..66236c90 100644 --- a/prompting/task_registry.py +++ b/prompting/task_registry.py @@ -2,12 +2,12 @@ from .tools import MockDataset, WikiDataset, HFCodingDataset, StackOverflowDataset, MathDataset, WikiDateDataset # TODO: Expand this to include extra information beyond just the task and dataset names -mock_task, mock_dataset = MockTask.name, MockDataset.name -summarization_task, summarization_dataset = SummarizationTask.name, WikiDataset.name -qa_task, qa_dataset = QuestionAnsweringTask.name, WikiDataset.name -debugging_task, debugging_dataset = DebuggingTask.name, HFCodingDataset.name -math_task, math_dataset = MathTask.name, MathDataset.name -date_qa_task, date_qa_dataset = DateQuestionAnsweringTask.name, WikiDateDataset.name +mock_task, mock_dataset = MockTask.name, [MockDataset.name] +summarization_task, summarization_dataset = SummarizationTask.name, [WikiDataset.name] +qa_task, qa_dataset = QuestionAnsweringTask.name, [WikiDataset.name] +debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name] +math_task, math_dataset = MathTask.name, [MathDataset.name] +date_qa_task, date_qa_dataset = DateQuestionAnsweringTask.name, [WikiDateDataset.name] TASK_REGISTRY = { mock_task: mock_dataset, diff --git a/prompting/tasks/__init__.py b/prompting/tasks/__init__.py index 2f24307f..fa8115a3 100644 --- a/prompting/tasks/__init__.py +++ b/prompting/tasks/__init__.py @@ -9,10 +9,11 @@ TASKS = { - "mock": MockTask, - "question-answering": QuestionAnsweringTask, - "summarization": SummarizationTask, - "date-based question answering": DateQuestionAnsweringTask, - "debugging": DebuggingTask, - "math": MathTask, + MockTask.name: MockTask, + QuestionAnsweringTask.name: QuestionAnsweringTask, + DateQuestionAnsweringTask.name: DateQuestionAnsweringTask, + SummarizationTask.name: SummarizationTask, + DebuggingTask.name: DebuggingTask, + #GenericInstructionTask.name: GenericInstructionTask, + MathTask.name: MathTask, } diff --git a/prompting/utils/config.py b/prompting/utils/config.py index 8de754af..f21b89a9 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -21,7 +21,7 @@ import argparse import bittensor as bt from loguru import logger -from prompting.task_registry import TASK_REGISTRY +from prompting.tasks import TASKS def check_config(cls, config: "bt.Config"): @@ -278,7 +278,7 @@ def add_validator_args(cls, parser): type=str, nargs="+", help="The tasks to use for the validator.", - default=TASK_REGISTRY.keys(), + default=list(TASKS.keys())[1:], ) parser.add_argument( @@ -286,7 +286,7 @@ def add_validator_args(cls, parser): type=float, nargs="+", help="The probability of sampling each task.", - default=[.25, .25, 0, .25, .25], + default=[.25, .25, .25, 0, .25], ) parser.add_argument( From 9aa843ebb16a3bef290db70aa20a6bb5d08686d4 Mon Sep 17 00:00:00 2001 From: steffencruz Date: Wed, 10 Apr 2024 20:28:24 +0000 Subject: [PATCH 16/18] Update qa name --- prompting/tasks/qa.py | 2 +- prompting/tools/datasets/math.py | 62 ++++++++++++++++---------------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/prompting/tasks/qa.py b/prompting/tasks/qa.py index c2192c5c..41430f5d 100644 --- a/prompting/tasks/qa.py +++ b/prompting/tasks/qa.py @@ -40,7 +40,7 @@ @dataclass class QuestionAnsweringTask(Task): - name = "question-answering" + name = "qa" desc = "get help on answering a question" goal = "to get the answer to the following question" diff --git a/prompting/tools/datasets/math.py b/prompting/tools/datasets/math.py index ca2513a3..e4eae92e 100644 --- a/prompting/tools/datasets/math.py +++ b/prompting/tools/datasets/math.py @@ -57,36 +57,38 @@ def get( Dict: _description_ """ bt.logging.info(f"Getting math problem {name!r}") - info = mathgenerator.generate_context(name, **kwargs) - if info["reward_type"] != "float" or info["topic"] == "computer_science": - return None - - math_words = [ - "math", - "mathematics", - "mathematical", - "math problem", - "math technique", - ] - external_links = [] - # construct external links from randomly shuffled trigrams containing 2 words from the problem and 1 random math word - # binary_to_decimal -> ['binary to', 'to decimal'] - for bigram in itertools.combinations(info["forward_words"], 2): - words = list(bigram) + [random.choice(math_words)] - # shuffle the words e.g. ['binary', 'decimal', 'math problem'] -> 'decimal binary math problem' - external_links.append(" ".join(random.sample(words, len(words)))) - - return { - "title": info["topic"], # title of math problem - "topic": info["topic"], # title of problem topic - "subtopic": info["subtopic"], # title of problem subtopic - "content": info["problem"], # problem statement - "internal_links": [info["topic"], info["subtopic"]], # internal links - "external_links": external_links, - "tags": info["forward_words"], - "source": "Mathgenerator", - "extra": {"reward_type": info["reward_type"], "solution": info["solution"]}, - } + max_tries = 10 + for _ in range(max_tries): + info = mathgenerator.generate_context(name, **kwargs) + if info["reward_type"] != "float" or info["topic"] == "computer_science": + pass + else: + math_words = [ + "math", + "mathematics", + "mathematical", + "math problem", + "math technique", + ] + external_links = [] + # construct external links from randomly shuffled trigrams containing 2 words from the problem and 1 random math word + # binary_to_decimal -> ['binary to', 'to decimal'] + for bigram in itertools.combinations(info["forward_words"], 2): + words = list(bigram) + [random.choice(math_words)] + # shuffle the words e.g. ['binary', 'decimal', 'math problem'] -> 'decimal binary math problem' + external_links.append(" ".join(random.sample(words, len(words)))) + + return { + "title": info["topic"], # title of math problem + "topic": info["topic"], # title of problem topic + "subtopic": info["subtopic"], # title of problem subtopic + "content": info["problem"], # problem statement + "internal_links": [info["topic"], info["subtopic"]], # internal links + "external_links": external_links, + "tags": info["forward_words"], + "source": "Mathgenerator", + "extra": {"reward_type": info["reward_type"], "solution": info["solution"]}, + } def search( self, name, selector: Selector, include: List = None, exclude: List = None From a41ff1a8e38ddfc077f3225572f47100eca6753f Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Mon, 15 Apr 2024 12:41:45 -0400 Subject: [PATCH 17/18] Remove print statements --- prompting/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/prompting/__init__.py b/prompting/__init__.py index 9d7aee53..8033bdbe 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -54,8 +54,6 @@ registry_datasets = set( [dataset for task, datasets in TASK_REGISTRY.items() for dataset in datasets] ) -print(registry_datasets) -print(set(DATASETS.keys())) registry_missing_dataset = registry_datasets - set(DATASETS.keys()) registry_extra_dataset = set(DATASETS.keys()) - registry_datasets assert ( From f06994ad985746f9006146bad5f1ab8dd3b82826 Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Mon, 15 Apr 2024 12:43:15 -0400 Subject: [PATCH 18/18] Update test_registry.py --- tests/test_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_registry.py b/tests/test_registry.py index 7cfe6f30..49c72fd2 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -3,7 +3,7 @@ from prompting.task_registry import TASK_REGISTRY -# TODO: Improve more detailed tasks. +# TODO: Create more detailed tests. def test_task_registry(): registry_missing_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys()) registry_extra_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys()) @@ -23,4 +23,4 @@ def test_task_registry_datasets(): ), f"Missing datasets in TASK_REGISTRY: {registry_missing_dataset}" assert ( not registry_extra_dataset - ), f"Extra datasets in TASK_REGISTRY: {registry_extra_dataset}" \ No newline at end of file + ), f"Extra datasets in TASK_REGISTRY: {registry_extra_dataset}"