Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/registry #162

Merged
merged 19 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions prompting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,31 @@
from . import agent
from . import conversation
from . import dendrite

from .llms import hf

from .tasks import TASKS
from .tools import DATASETS
bkb2135 marked this conversation as resolved.
Show resolved Hide resolved
from .task_registry import TASK_REGISTRY

# Assert that the task and dataset registries are mutually consistent:
# every task in TASKS must appear in TASK_REGISTRY (and vice versa), and
# every dataset referenced by TASK_REGISTRY must exist in DATASETS (and
# vice versa). Runs once at import time so a misconfigured registry fails
# loudly instead of surfacing later as a KeyError.
registry_missing_task = set(TASKS.keys()) - set(TASK_REGISTRY.keys())
registry_extra_task = set(TASK_REGISTRY.keys()) - set(TASKS.keys())
assert (
    not registry_missing_task
), f"Missing tasks in TASK_REGISTRY: {registry_missing_task}"
assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}"

# All dataset names referenced by any task in the registry.
registry_datasets = {
    dataset for datasets in TASK_REGISTRY.values() for dataset in datasets
}
registry_missing_dataset = registry_datasets - set(DATASETS.keys())
registry_extra_dataset = set(DATASETS.keys()) - registry_datasets
assert (
    not registry_missing_dataset
), f"Missing datasets in TASK_REGISTRY: {registry_missing_dataset}"
assert (
    not registry_extra_dataset
), f"Extra datasets in TASK_REGISTRY: {registry_extra_dataset}"
117 changes: 47 additions & 70 deletions prompting/conversation.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,50 @@
from prompting.tasks import (
Task,
DebuggingTask,
QuestionAnsweringTask,
SummarizationTask,
MathTask,
DateQuestionAnsweringTask,
)
from prompting.tools import (
WikiDataset,
HFCodingDataset,
MathDataset,
WikiDateDataset,
)

import random
from transformers import Pipeline


def create_task(llm_pipeline: Pipeline, task_name: str, create_reference=True) -> Task:
wiki_based_tasks = ["summarization", "qa"]
coding_based_tasks = ["debugging"]
# TODO: Abstract dataset classes into common dynamic interface
if task_name in wiki_based_tasks:
dataset = WikiDataset()

elif task_name in coding_based_tasks:
dataset = HFCodingDataset()

elif task_name == "math":
dataset = MathDataset()

elif task_name == "date_qa":
dataset = WikiDateDataset()

if task_name == "summarization":
task = SummarizationTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "qa":
task = QuestionAnsweringTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "debugging":
task = DebuggingTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "math":
task = MathTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

elif task_name == "date_qa":
task = DateQuestionAnsweringTask(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)

from prompting.tasks import Task, TASKS
from prompting.tools import Selector, DATASETS
from prompting.task_registry import TASK_REGISTRY


def create_task(
    llm_pipeline: Pipeline,
    task_name: str,
    create_reference: bool = True,
    selector: Selector = random.choice,
) -> Task:
    """Create a task from the given task name and LLM pipeline.

    Args:
        llm_pipeline (Pipeline): Pipeline to use for text generation
        task_name (str): Name of the task to create
        create_reference (bool, optional): Generate text for task reference answer upon creation. Defaults to True.
        selector (Selector, optional): Selector function to choose a dataset. Defaults to random.choice.

    Raises:
        ValueError: If task_name is not a valid alias for a task, or if the task is not a subclass of Task
        ValueError: If no datasets are available for the given task
        ValueError: If the dataset for the given task is not found

    Returns:
        Task: Task instance
    """
    task = TASKS.get(task_name)
    if task is None or not issubclass(task, Task):
        raise ValueError(f"Task {task_name} not found")

    # Bug fix: TASK_REGISTRY.get() returns None for an unknown task name, and
    # the previous len(None) check raised TypeError instead of the documented
    # ValueError. Treat a missing entry the same as an empty dataset list.
    dataset_choices = TASK_REGISTRY.get(task_name) or []
    if not dataset_choices:
        raise ValueError(f"No datasets available for task {task_name}")

    # Pick one dataset name for this task and resolve it to its class.
    dataset_name = selector(dataset_choices)
    dataset_cls = DATASETS.get(dataset_name)
    if dataset_cls is None:
        raise ValueError(f"Dataset {dataset_name} not found")
    dataset = dataset_cls()

    return task(
        llm_pipeline=llm_pipeline,
        context=dataset.next(),
        create_reference=create_reference,
    )
19 changes: 19 additions & 0 deletions prompting/task_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .tasks import Task, MockTask, SummarizationTask, QuestionAnsweringTask, DebuggingTask, MathTask, DateQuestionAnsweringTask
from .tools import MockDataset, WikiDataset, HFCodingDataset, StackOverflowDataset, MathDataset, WikiDateDataset

# TODO: Expand this to include extra information beyond just the task and dataset names
# Maps each task's canonical name to the list of dataset names it may draw
# contexts from. Names come from the class `name` attributes so the registry
# stays in sync with prompting.tasks.TASKS and prompting.tools.DATASETS.
mock_task = MockTask.name
mock_dataset = [MockDataset.name]
summarization_task = SummarizationTask.name
summarization_dataset = [WikiDataset.name]
qa_task = QuestionAnsweringTask.name
qa_dataset = [WikiDataset.name]
debugging_task = DebuggingTask.name
debugging_dataset = [HFCodingDataset.name]
math_task = MathTask.name
math_dataset = [MathDataset.name]
date_qa_task = DateQuestionAnsweringTask.name
date_qa_dataset = [WikiDateDataset.name]

TASK_REGISTRY = {
    mock_task: mock_dataset,
    summarization_task: summarization_dataset,
    qa_task: qa_dataset,
    debugging_task: debugging_dataset,
    math_task: math_dataset,
    date_qa_task: date_qa_dataset,
}
13 changes: 8 additions & 5 deletions prompting/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from .date_qa import DateQuestionAnsweringTask
from .generic_instruction import GenericInstructionTask
from .math import MathTask
from .mock import MockTask


# Registry of all available task classes, keyed by each task's canonical name.
_TASK_CLASSES = [
    MockTask,
    QuestionAnsweringTask,
    DateQuestionAnsweringTask,
    SummarizationTask,
    DebuggingTask,
    # GenericInstructionTask is intentionally left out for now.
    MathTask,
]
TASKS = {task_class.name: task_class for task_class in _TASK_CLASSES}
2 changes: 1 addition & 1 deletion prompting/tasks/date_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@dataclass
class DateQuestionAnsweringTask(Task):
name = "date-based question answering"
name = "date_qa"
desc = "get help answering a specific date-based question"
goal = "to get the answer to the following date-based question"
reward_definition = [
Expand Down
29 changes: 29 additions & 0 deletions prompting/tasks/mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from prompting.tasks import Task

@dataclass
class MockTask(Task):
    # Canonical task name: used as the key in TASKS and TASK_REGISTRY.
    name = "mock"
    desc = "get help solving a math problem"
    goal = "to get the answer to the following math question"

    reward_definition = [
        dict(name="float_diff", weight=1.0),
    ]
    penalty_definition = []

    # Both the query and the reference are fixed strings, so no LLM calls
    # are needed to construct this task.
    static_reference = True
    static_query = True

    def __init__(self, llm_pipeline, context, create_reference=True):
        """Build a mock task around *context* without any text generation."""
        self.context = context

        self.query = f"How can I solve the following problem, {context.content}?"
        self.reference = "This is the reference answer"
        self.topic = context.title
        self.subtopic = context.topic
        self.tags = context.tags
2 changes: 1 addition & 1 deletion prompting/tasks/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

@dataclass
class QuestionAnsweringTask(Task):
name = "question-answering"
name = "qa"
desc = "get help on answering a question"
goal = "to get the answer to the following question"

Expand Down
12 changes: 12 additions & 0 deletions prompting/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,15 @@
MathDataset,
)
from .selector import Selector

# Registry of all available dataset classes, keyed by each dataset's canonical
# name. Keys are taken from the classes' `name` attributes (instead of
# duplicating the strings here) so this mapping cannot drift out of sync with
# TASK_REGISTRY, which also references datasets via their `name` attributes.
DATASETS = {
    MockDataset.name: MockDataset,
    HFCodingDataset.name: HFCodingDataset,
    WikiDataset.name: WikiDataset,
    # StackOverflowDataset.name: StackOverflowDataset,
    WikiDateDataset.name: WikiDateDataset,
    MathDataset.name: MathDataset,
}



bkb2135 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion prompting/tools/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class Dataset(ABC):
"""Base class for datasets."""

name = "dataset"
max_tries: int = 10

@abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions prompting/tools/datasets/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def filter_comments(code, language):

# TODO: why not define the chain_in, chain_out logic in the class itself?
class HFCodingDataset(Dataset):
name = "hf_coding"
def __init__(
self,
dataset_id="codeparrot/github-code",
Expand Down Expand Up @@ -615,6 +616,7 @@ def get_special_contents(self, code, language, remove_comments=True):


class StackOverflowDataset:
name = "stack_overflow"
def __init__(self):
# Stack Overflow API endpoint for a random article
self.url = "https://api.stackexchange.com/2.3/questions"
Expand Down
63 changes: 33 additions & 30 deletions prompting/tools/datasets/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


class MathDataset(Dataset):
name = 'math'
topics_list = mathgenerator.getGenList()

def __init__(self, seed=None):
Expand All @@ -56,36 +57,38 @@ def get(
Dict: _description_
"""
bt.logging.info(f"Getting math problem {name!r}")
info = mathgenerator.generate_context(name, **kwargs)
if info["reward_type"] != "float" or info["topic"] == "computer_science":
return None

math_words = [
"math",
"mathematics",
"mathematical",
"math problem",
"math technique",
]
external_links = []
# construct external links from randomly shuffled trigrams containing 2 words from the problem and 1 random math word
# binary_to_decimal -> ['binary to', 'to decimal']
for bigram in itertools.combinations(info["forward_words"], 2):
words = list(bigram) + [random.choice(math_words)]
# shuffle the words e.g. ['binary', 'decimal', 'math problem'] -> 'decimal binary math problem'
external_links.append(" ".join(random.sample(words, len(words))))

return {
"title": info["topic"], # title of math problem
"topic": info["topic"], # title of problem topic
"subtopic": info["subtopic"], # title of problem subtopic
"content": info["problem"], # problem statement
"internal_links": [info["topic"], info["subtopic"]], # internal links
"external_links": external_links,
"tags": info["forward_words"],
"source": "Mathgenerator",
"extra": {"reward_type": info["reward_type"], "solution": info["solution"]},
}
max_tries = 10
for _ in range(max_tries):
info = mathgenerator.generate_context(name, **kwargs)
if info["reward_type"] != "float" or info["topic"] == "computer_science":
pass
else:
math_words = [
"math",
"mathematics",
"mathematical",
"math problem",
"math technique",
]
external_links = []
# construct external links from randomly shuffled trigrams containing 2 words from the problem and 1 random math word
# binary_to_decimal -> ['binary to', 'to decimal']
for bigram in itertools.combinations(info["forward_words"], 2):
words = list(bigram) + [random.choice(math_words)]
# shuffle the words e.g. ['binary', 'decimal', 'math problem'] -> 'decimal binary math problem'
external_links.append(" ".join(random.sample(words, len(words))))

return {
"title": info["topic"], # title of math problem
"topic": info["topic"], # title of problem topic
"subtopic": info["subtopic"], # title of problem subtopic
"content": info["problem"], # problem statement
"internal_links": [info["topic"], info["subtopic"]], # internal links
"external_links": external_links,
"tags": info["forward_words"],
"source": "Mathgenerator",
"extra": {"reward_type": info["reward_type"], "solution": info["solution"]},
}

def search(
self, name, selector: Selector, include: List = None, exclude: List = None
Expand Down
1 change: 1 addition & 0 deletions prompting/tools/datasets/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


class MockDataset(Dataset):
name = "mock"
def get(self, name, exclude=None, selector=None):
return {
"title": name,
Expand Down
3 changes: 2 additions & 1 deletion prompting/tools/datasets/wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def filter_categories(categories, exclude=None, include=None):

class WikiDataset(Dataset):
"""Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections."""

name = "wiki"
EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links")
EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1")

Expand Down Expand Up @@ -240,6 +240,7 @@ def random(self, pages=10, seed=None, selector: Selector = None, **kwargs) -> Di


class WikiDateDataset(Dataset):
name = "wiki_date"
INCLUDE_HEADERS = ("Events", "Births", "Deaths")
MONTHS = (
"January",
Expand Down
Loading
Loading