Skip to content

Commit

Permalink
feat: move async handling to outer loop (#71)
Browse files Browse the repository at this point in the history
- [ ] I have considered whether this PR needs review, and requested a
review if necessary.

Fixes issue #

# Notes for reviewers
Reviewers can skip X, but should pay attention to Y.
  • Loading branch information
MartinBernstorff authored Mar 23, 2024
1 parent 3784d73 commit 31465a9
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 42 deletions.
1 change: 0 additions & 1 deletion .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ select = [
"COM",
"D417",
"E",
"ERA",
"F",
"I",
"ICN",
Expand Down
18 changes: 11 additions & 7 deletions memorymarker/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import datetime as dt
import os
import time
Expand Down Expand Up @@ -105,15 +106,18 @@ def typer_cli(
typer.echo(f"Received {highlights.count()} new highlights")

typer.echo("Generating questions from highlights...")
questions = BaselinePipeline(
_name="gpt-4-basic",
openai_api_key=os.getenv(
"OPENAI_API_KEY", "No OPENAI_API_KEY environment variable set"
),
model="gpt-4-turbo-preview",
)(highlights)
questions = asyncio.run(
BaselinePipeline(
_name="gpt-4-basic",
openai_api_key=os.getenv(
"OPENAI_API_KEY", "No OPENAI_API_KEY environment variable set"
),
model="gpt-4-turbo-preview",
)(highlights)
)

typer.echo("Writing questions to markdown...")

for question in questions:
write_qa_prompt_to_md(save_dir=output_dir, highlight=question)

Expand Down
2 changes: 1 addition & 1 deletion memorymarker/question_generator/baseline_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def _highlights_to_qa(
response = await asyncio.gather(*questions)
return response

def __call__(
async def __call__(
self, highlights: "Iter[ContextualizedHighlight]"
) -> "Iter[ReasonedHighlight]":
response: Sequence[QAPromptResponseModel] = asyncio.run(
Expand Down
2 changes: 1 addition & 1 deletion memorymarker/question_generator/highlight_to_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class HighlightToQuestion(Protocol):
def __call__(
async def __call__(
self, highlights: "Iter[ContextualizedHighlight]"
) -> "Iter[ReasonedHighlight]":
...
Expand Down
26 changes: 15 additions & 11 deletions memorymarker/question_generator/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence
Expand All @@ -8,7 +9,6 @@
ContextualizedHighlight,
)
from memorymarker.document_providers.omnivore import Omnivore
from memorymarker.question_generator.baseline_pipeline import BaselinePipeline
from memorymarker.question_generator.example_repo_airtable import (
AirtableExampleRepo,
PipelineHighlightIdentity,
Expand Down Expand Up @@ -70,7 +70,7 @@ def _select_highlights_from_omnivore(
return selected_highlights


if __name__ == "__main__":
async def main():
repository = AirtableExampleRepo()
selected_highlights = _select_highlights_from_omnivore(
search_terms={
Expand Down Expand Up @@ -99,16 +99,20 @@ def _select_highlights_from_omnivore(
"OPENAI_API_KEY", "No OPENAI_API_KEY environment variable set"
),
model="gpt-4-turbo-preview",
),
BaselinePipeline(
openai_api_key=os.getenv(
"OPENAI_API_KEY", "No OPENAI_API_KEY environment variable set"
),
model="gpt-4-turbo-preview",
_name="gpt-4-basic",
),
)
# BaselinePipeline(
# openai_api_key=os.getenv(
# "OPENAI_API_KEY", "No OPENAI_API_KEY environment variable set"
# ),
# model="gpt-4-turbo-preview",
# _name="gpt-4-basic",
# ),
],
).filter(lambda pair: pair.__hash__() not in old_example_hashes)

new_responses = Iter(run_pipelines(new_highlights)).flatten()
new_responses = await run_pipelines(new_highlights)
update_repository(new_responses, repository=repository)


if __name__ == "__main__":
asyncio.run(main())
42 changes: 31 additions & 11 deletions memorymarker/question_generator/pipeline_runner.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,45 @@
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Mapping, Sequence

from iterpy.iter import Iter

if TYPE_CHECKING:
from memorymarker.document_providers.contextualized_highlight import (
ContextualizedHighlight,
)
from memorymarker.question_generator.highlight_to_question import (
HighlightToQuestion,
)
from memorymarker.question_generator.main import HighlightWithPipeline
from memorymarker.question_generator.reasoned_highlight import ReasonedHighlight


def run_pipelines(
async def run_pipeline(
pipeline_name: str,
pipelinename2pipeline: Mapping[str, "HighlightToQuestion"],
highlights: Sequence["ContextualizedHighlight"],
) -> Iter["ReasonedHighlight"]:
pipeline = pipelinename2pipeline[pipeline_name]
prompts = pipeline(Iter(highlights))
return await prompts


async def run_pipelines(
pairs: "Iter[HighlightWithPipeline]",
) -> Sequence["ReasonedHighlight"]:
) -> Iter["ReasonedHighlight"]:
pipelinename2pipeline = {pair.pipeline.name: pair.pipeline for pair in pairs}

pipelines_with_highlights = pairs.groupby(lambda _: _.pipeline.name)

examples: Sequence[ReasonedHighlight] = []
for pipeline_name, pair in pipelines_with_highlights:
print(f"Creating examples for {pipeline_name}")
highlights = [pair.highlight for pair in pair]
pipeline = pipelinename2pipeline[pipeline_name]
prompts = pipeline(Iter(highlights))
examples.extend(prompts)
examples: Sequence["ReasonedHighlight"] = []
for pipeline_name, pairs_instance in pipelines_with_highlights:
print(f"Running pipeline {pipeline_name}")
for pair in pairs_instance:
examples.extend(
await run_pipeline(
pipeline_name=pipeline_name,
pipelinename2pipeline=pipelinename2pipeline,
highlights=[pair.highlight],
)
)

return examples
return Iter(examples)
13 changes: 3 additions & 10 deletions memorymarker/question_generator/reasoned_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,9 @@ async def _highlight_to_qa(
)
)

async def _highlights_to_qa(
self, highlights: Iter["ContextualizedHighlight"]
) -> Iter[ReasonedHighlight]:
async def __call__(
self, highlights: "Iter[ContextualizedHighlight]"
) -> "Iter[ReasonedHighlight]":
questions = [self._highlight_to_qa(highlight) for highlight in highlights]
response = await asyncio.gather(*questions)
return Iter(response).flatten()

def __call__(
self, highlights: "Iter[ContextualizedHighlight]"
) -> "Iter[ReasonedHighlight]":
response = asyncio.run(self._highlights_to_qa(highlights))

return response

0 comments on commit 31465a9

Please sign in to comment.