Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/batch forward #214

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
python-version: ["3.9", "3.10"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
Expand All @@ -25,17 +25,21 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest black
bash install.sh
bash install.sh

- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

- name: Black
run: |
black .
uses: psf/black@stable
with:
options: "--check --verbose"
src: "."

- name: Test with pytest
run: |
# run tests in tests/ dir and only fail if there are failures or errors
Expand Down
2 changes: 1 addition & 1 deletion neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, config=None):
mock=self.config.mock,
)

if abs(1-sum(self.config.neuron.task_p)) > 0.001:
if abs(1 - sum(self.config.neuron.task_p)) > 0.001:
raise ValueError("Task probabilities do not sum to 1.")

# Filter out tasks with 0 probability
Expand Down
10 changes: 5 additions & 5 deletions prompting/base/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def __init__(self, config=None):
if self.config.mock:
self.wallet = bt.MockWallet(config=self.config)
self.subtensor = MockSubtensor(self.config.netuid, wallet=self.wallet)
self.metagraph = MockMetagraph(netuid=self.config.netuid, subtensor=self.subtensor)
self.metagraph = MockMetagraph(
netuid=self.config.netuid, subtensor=self.subtensor
)
else:
self.wallet = bt.wallet(config=self.config)
self.subtensor = bt.subtensor(config=self.config)
Expand All @@ -102,12 +104,10 @@ def __init__(self, config=None):
self.step = 0

@abstractmethod
def forward(self, synapse: bt.Synapse) -> bt.Synapse:
...
def forward(self, synapse: bt.Synapse) -> bt.Synapse: ...

@abstractmethod
def run(self):
...
def run(self): ...

def sync(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion prompting/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ def create_task(
llm_pipeline=llm_pipeline,
context=dataset.next(),
create_reference=create_reference,
)
)
21 changes: 15 additions & 6 deletions prompting/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,22 @@
from prompting.utils.misc import async_log, serialize_exception_to_string
from dataclasses import dataclass


@async_log
async def generate_reference(agent):
async def generate_reference(agent):
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
return result
result = await loop.run_in_executor(
None, agent.task.generate_reference, agent.llm_pipeline
)
return result


@async_log
async def execute_dendrite_call(dendrite_call):
    """Await ``dendrite_call`` and hand back its result unchanged.

    Args:
        dendrite_call: An awaitable to be resolved (presumably a pending
            dendrite query; confirm against the caller).

    Returns:
        Whatever ``dendrite_call`` resolves to.
    """
    return await dendrite_call


@dataclass
class StreamResult:
synapse: StreamPromptingSynapse = None
Expand All @@ -55,7 +60,11 @@ async def process_response(uid: int, async_generator: Awaitable):
"""Process a single response asynchronously."""
try:
chunk = None # Initialize chunk with a default value
async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
async for (
chunk
) in (
async_generator
): # most important loop, as this is where we acquire the final synapse.
bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

if chunk is not None:
Expand Down Expand Up @@ -217,8 +226,8 @@ async def run_step(

log_stream_results(stream_results)

all_synapses_results = [stream_result.synapse for stream_result in stream_results]
all_synapses_results = [stream_result.synapse for stream_result in stream_results]

# Encapsulate the responses in a response event (dataclass)
response_event = DendriteResponseEvent(
responses=all_synapses_results, uids=uids, timeout=timeout
Expand Down
9 changes: 3 additions & 6 deletions prompting/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

class BasePipeline(ABC):
@abstractmethod
def __call__(self, composed_prompt: str, **kwargs: dict) -> Any:
...
def __call__(self, composed_prompt: str, **kwargs: dict) -> Any: ...


class BaseLLM(ABC):
Expand All @@ -29,11 +28,9 @@ def query(
role: str = "user",
disregard_system_prompt: bool = False,
cleaner: CleanerPipeline = None,
) -> str:
...
) -> str: ...

def forward(self, messages: List[Dict[str, str]]):
...
def forward(self, messages: List[Dict[str, str]]): ...

def clean_response(self, cleaner: CleanerPipeline, response: str) -> str:
if cleaner is not None:
Expand Down
3 changes: 1 addition & 2 deletions prompting/rewards/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def __post_init__(self):
class BaseRewardModel(ABC):
@property
@abstractmethod
def name(self) -> str:
...
def name(self) -> str: ...

@abstractmethod
def __init__(self, **kwargs):
Expand Down
33 changes: 26 additions & 7 deletions prompting/task_registry.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
from .tasks import Task, MockTask, SummarizationTask, QuestionAnsweringTask, DebuggingTask, MathTask, DateQuestionAnsweringTask, GenericInstructionTask
from .tools import MockDataset, WikiDataset, HFCodingDataset, StackOverflowDataset, MathDataset, WikiDateDataset, GenericInstructionDataset
from .tasks import (
Task,
MockTask,
SummarizationTask,
QuestionAnsweringTask,
DebuggingTask,
MathTask,
DateQuestionAnsweringTask,
GenericInstructionTask,
)
from .tools import (
MockDataset,
WikiDataset,
HFCodingDataset,
StackOverflowDataset,
MathDataset,
WikiDateDataset,
GenericInstructionDataset,
)

# TODO: Expand this to include extra information beyond just the task and dataset names
summarization_task, summarization_dataset = SummarizationTask.name, [WikiDataset.name]
qa_task, qa_dataset = QuestionAnsweringTask.name, [WikiDataset.name]
#debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name]
# debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name]
math_task, math_dataset = MathTask.name, [MathDataset.name]
date_qa_task, date_qa_dataset = DateQuestionAnsweringTask.name, [WikiDateDataset.name]
generic_instruction_task, generic_instruction_dataset = GenericInstructionTask.name, [GenericInstructionDataset.name]
generic_instruction_task, generic_instruction_dataset = GenericInstructionTask.name, [
GenericInstructionDataset.name
]

TASK_REGISTRY = {
summarization_task: summarization_dataset,
qa_task: qa_dataset,
#debugging_task: debugging_dataset,
# debugging_task: debugging_dataset,
math_task: math_dataset,
date_qa_task: date_qa_dataset,
generic_instruction_task: generic_instruction_dataset
}
generic_instruction_task: generic_instruction_dataset,
}
2 changes: 1 addition & 1 deletion prompting/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
QuestionAnsweringTask.name: QuestionAnsweringTask,
DateQuestionAnsweringTask.name: DateQuestionAnsweringTask,
SummarizationTask.name: SummarizationTask,
#DebuggingTask.name: DebuggingTask,
# DebuggingTask.name: DebuggingTask,
GenericInstructionTask.name: GenericInstructionTask,
MathTask.name: MathTask,
}
2 changes: 1 addition & 1 deletion prompting/tasks/date_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DateQuestionAnsweringTask(Task):
static_reference = True
static_query = True

def __init__(self, llm_pipeline, context, create_reference =True):
def __init__(self, llm_pipeline, context, create_reference=True):
self.context = context

self.query = (
Expand Down
4 changes: 2 additions & 2 deletions prompting/tasks/generic_instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


class GenericInstructionTask(Task):
challenge_type = 'query'
challenge_type = "query"
name = "generic"
desc = "get help on answering a general instruction"
goal = "to get the answer to the following instruction"
Expand All @@ -38,7 +38,7 @@ def __init__(self, llm_pipeline, context, create_reference=True):
self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=context.content)
self.query = self.generate_query(llm_pipeline)

self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(query = self.query)
self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(query=self.query)
if create_reference:
self.reference = self.generate_reference(llm_pipeline)

Expand Down
9 changes: 3 additions & 6 deletions prompting/tasks/mock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from prompting.tasks import Task


@dataclass
class MockTask(Task):
name = "mock"
Expand All @@ -18,12 +19,8 @@ class MockTask(Task):
def __init__(self, llm_pipeline, context, create_reference=True):
self.context = context

self.query = (
"How can I solve the following problem, "
+ context.content
+ "?"
)
self.query = "How can I solve the following problem, " + context.content + "?"
self.reference = "This is the reference answer"
self.topic = context.title
self.subtopic = context.topic
self.tags = context.tags
self.tags = context.tags
1 change: 0 additions & 1 deletion prompting/tasks/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from prompting.tasks import Task



# TODO: introduce criteria for the query and reference answer (length, layout, etc.) and make these arguments

# TODO: Also add a query system prompt and a query prompt template
Expand Down
4 changes: 3 additions & 1 deletion prompting/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

def make_system_prompt():
    """Build the system prompt from ``CHATTENSOR_SYSTEM_PROMPT``.

    Returns:
        The template with its ``date`` field filled by the current local
        date, formatted as ``"%B %d, %Y"`` (for example "January 05, 2024").
    """
    today = time.strftime("%B %d, %Y")
    return CHATTENSOR_SYSTEM_PROMPT.format(date=today)


class TaskEvaluationType(Enum):
REWARD_STACK = "reward"
FILTER_STACK = "filter"
Expand Down Expand Up @@ -108,7 +110,7 @@ def generate_query(self, pipeline: BasePipeline, clean=True) -> str:
if not self.static_query:
bt.logging.info("🤖 Generating query...")
self.query = self.generate(
system=self.query_system_prompt, #Could possibly add the chattensor system prompt to query but I don't think it adds anything
system=self.query_system_prompt, # Could possibly add the chattensor system prompt to query but I don't think it adds anything
prompt=self.query_prompt,
pipeline=pipeline,
clean=clean,
Expand Down
9 changes: 3 additions & 6 deletions prompting/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
from .selector import Selector

DATASETS = {
#HFCodingDataset.name: HFCodingDataset,
# HFCodingDataset.name: HFCodingDataset,
WikiDataset.name: WikiDataset,
#StackOverflowDataset.name: StackOverflowDataset,
# StackOverflowDataset.name: StackOverflowDataset,
MathDataset.name: MathDataset,
WikiDateDataset.name: WikiDateDataset,
GenericInstructionDataset.name: GenericInstructionDataset,
GenericInstructionDataset.name: GenericInstructionDataset,
}



7 changes: 4 additions & 3 deletions prompting/tools/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .context import Context
from .base import Dataset
from .context import Context, BatchContext
from .base import Dataset, BatchDataset
from .code import HFCodingDataset, StackOverflowDataset
from .math import MathDataset
from .mock import MockDataset
from .wiki import WikiDataset, WikiDateDataset
from .generic_instruction import GenericInstructionDataset
from .batch_wiki import BatchWikiDataset
from .generic_instruction import GenericInstructionDataset
Loading
Loading