Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/registry #162

Merged
merged 19 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions prompting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,29 @@
from . import agent
from . import conversation
from . import dendrite

from .llms import hf

from .tasks import TASKS
from .tools import DATASETS
bkb2135 marked this conversation as resolved.
Show resolved Hide resolved
from .task_registry import TASK_REGISTRY

# Import-time consistency check: TASKS and TASK_REGISTRY must define exactly
# the same set of task names.
# NOTE: these are plain `assert` statements, so they are skipped when Python
# runs with -O.
# Names present in TASK_REGISTRY that have no entry in TASKS.
registry_missing_task = set(TASK_REGISTRY) - set(TASKS)
# Names present in TASKS that have no entry in TASK_REGISTRY.
registry_extra_task = set(TASKS) - set(TASK_REGISTRY)
assert (
    not registry_missing_task
), f"Tasks in TASK_REGISTRY but not in TASKS: {registry_missing_task}"
assert (
    not registry_extra_task
), f"Tasks in TASKS but not in TASK_REGISTRY: {registry_extra_task}"

# Import-time consistency check: the dataset names referenced by TASK_REGISTRY
# must be exactly the keys of DATASETS.
# NOTE: these are plain `assert` statements, so they are skipped when Python
# runs with -O.
# Every dataset name listed under any task in TASK_REGISTRY.
registry_datasets = {
    dataset for datasets in TASK_REGISTRY.values() for dataset in datasets
}
# Names referenced by TASK_REGISTRY that have no entry in DATASETS.
registry_missing_dataset = registry_datasets - set(DATASETS)
# Names present in DATASETS that no task in TASK_REGISTRY references.
registry_extra_dataset = set(DATASETS) - registry_datasets
assert (
    not registry_missing_dataset
), f"Datasets referenced in TASK_REGISTRY but not in DATASETS: {registry_missing_dataset}"
assert (
    not registry_extra_dataset
), f"Datasets in DATASETS but not referenced in TASK_REGISTRY: {registry_extra_dataset}"
117 changes: 47 additions & 70 deletions prompting/conversation.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,50 @@
from prompting.tasks import (
Task,
DebuggingTask,
QuestionAnsweringTask,
SummarizationTask,
MathTask,
DateQuestionAnsweringTask,
)
from prompting.tools import (
WikiDataset,
HFCodingDataset,
MathDataset,
WikiDateDataset,
)

import random
from transformers import Pipeline


# Build a task for `task_name` in two steps: first choose the dataset class
# that backs the task, then construct the matching task class around one
# sample obtained from `dataset.next()`.
# NOTE(review): no fallback branch or `return` for either chain is visible
# within this span — confirm how unknown task names are handled and where the
# constructed task is returned.
def create_task(llm_pipeline: Pipeline, task_name: str, create_reference=True) -> Task:
# Task names grouped by the dataset they read from.
wiki_based_tasks = ["summarization", "qa"]
coding_based_tasks = ["debugging"]
# TODO: Abstract dataset classes into common dynamic interface
# Step 1: pick and instantiate the dataset for this task name.
if task_name in wiki_based_tasks:
dataset = WikiDataset()

elif task_name in coding_based_tasks:
dataset = HFCodingDataset()

elif task_name == "math":
dataset = MathDataset()

elif task_name == "date_qa":
dataset = WikiDateDataset()

# Step 2: build the task class for this name; every branch passes the same
# three keyword arguments and differs only in the class constructed.
if task_name == "summarization":
task = SummarizationTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "qa":
task = QuestionAnsweringTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "debugging":
task = DebuggingTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "math":
task = MathTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "date_qa":
task = DateQuestionAnsweringTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

from prompting.tasks import Task, TASKS
from prompting.tools import Selector, DATASETS
from prompting.task_registry import TASK_REGISTRY


def create_task(
    llm_pipeline: Pipeline,
    task_name: str,
    create_reference: bool = True,
    selector: Selector = random.choice,
) -> Task:
    """Create a task from the given task name and LLM pipeline.

    The task class is looked up in TASKS, one of the dataset names registered
    for the task in TASK_REGISTRY is chosen with ``selector``, and the
    corresponding dataset class from DATASETS is instantiated to supply the
    task's context.

    Args:
        llm_pipeline (Pipeline): Pipeline to use for text generation
        task_name (str): Name of the task to create
        create_reference (bool, optional): Generate text for task reference answer upon creation. Defaults to True.
        selector (Selector, optional): Selector function to choose a dataset. Defaults to random.choice.

    Raises:
        ValueError: If task_name is not a valid alias for a task, or if the task is not a subclass of Task
        ValueError: If no datasets are available for the given task
        ValueError: If the dataset for the given task is not found

    Returns:
        Task: Task instance
    """
    task_cls = TASKS.get(task_name)
    # Check for a class first: issubclass() raises TypeError on non-class
    # values, but the documented contract for a bad entry is ValueError.
    if not (isinstance(task_cls, type) and issubclass(task_cls, Task)):
        raise ValueError(f"Task {task_name} not found")

    dataset_choices = TASK_REGISTRY.get(task_name, [])
    if not dataset_choices:
        raise ValueError(f"No datasets available for task {task_name}")

    dataset_name = selector(dataset_choices)
    dataset_cls = DATASETS.get(dataset_name)
    if dataset_cls is None:
        raise ValueError(f"Dataset {dataset_name} not found")

    # DATASETS stores classes; instantiate only once all lookups succeeded.
    dataset = dataset_cls()

    return task_cls(
        llm_pipeline=llm_pipeline,
        context=dataset.next(),
        create_reference=create_reference,
    )
9 changes: 9 additions & 0 deletions prompting/task_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TODO: Expand this to include extra information beyond just the task and dataset names
# Maps each task name to the list of dataset names that task may draw its
# context from. Every value is a list, even when only one dataset is registered.
TASK_REGISTRY = {
    "mock": ["mock"],
    "summarization": ["wiki"],
    "qa": ["wiki"],
    "debugging": ["hf_coding"],
    "math": ["math"],
    "date_qa": ["wiki_date"],
}
1 change: 1 addition & 0 deletions prompting/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


TASKS = {
"mock": Task,
"qa": QuestionAnsweringTask,
"summarization": SummarizationTask,
"date_qa": DateQuestionAnsweringTask,
Expand Down
9 changes: 9 additions & 0 deletions prompting/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,12 @@
MathDataset,
)
from .selector import Selector

# Registry of dataset classes keyed by the short names that TASK_REGISTRY
# values refer to. create_task looks a name up here and calls the class with
# no arguments to obtain a dataset instance.
# The import-time check in prompting/__init__.py requires every key here to be
# referenced by at least one task in TASK_REGISTRY.
DATASETS = {
"mock": MockDataset,
"hf_coding": HFCodingDataset,
"wiki": WikiDataset,
# NOTE(review): disabled entry — presumably left out until a task in
# TASK_REGISTRY references it; confirm before re-enabling.
#"stack_overflow": StackOverflowDataset,
"wiki_date": WikiDateDataset,
"math": MathDataset,
}
26 changes: 26 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from prompting.tasks import TASKS
from prompting.tools import DATASETS
from prompting.task_registry import TASK_REGISTRY


# TODO: Add more detailed tests for the task registry.
bkb2135 marked this conversation as resolved.
Show resolved Hide resolved
def test_task_registry():
    """TASKS and TASK_REGISTRY must define exactly the same task names."""
    registered = set(TASK_REGISTRY)
    implemented = set(TASKS)

    # Registered names with no task class behind them.
    unknown_tasks = registered - implemented
    # Task classes that were never added to the registry.
    unregistered_tasks = implemented - registered

    assert (
        not unknown_tasks
    ), f"Tasks in TASK_REGISTRY but not in TASKS: {unknown_tasks}"
    assert (
        not unregistered_tasks
    ), f"Tasks in TASKS but not in TASK_REGISTRY: {unregistered_tasks}"

def test_task_registry_datasets():
    """Dataset names referenced by TASK_REGISTRY must be exactly the keys of DATASETS."""
    referenced = {
        dataset for datasets in TASK_REGISTRY.values() for dataset in datasets
    }
    available = set(DATASETS)

    # Names a task points at that have no dataset class behind them.
    unknown_datasets = referenced - available
    # Dataset classes that no task in the registry uses.
    unused_datasets = available - referenced

    assert (
        not unknown_datasets
    ), f"Datasets referenced in TASK_REGISTRY but not in DATASETS: {unknown_datasets}"
    assert (
        not unused_datasets
    ), f"Datasets in DATASETS but not referenced in TASK_REGISTRY: {unused_datasets}"
Loading