Skip to content

Commit

Permalink
tests: DocumentPromptSource (#302)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff authored Dec 10, 2023
2 parents 741545a + 59d7e6a commit a6aff81
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from collections.abc import Sequence

from ..document_ingesters.document import Document
from .base_document_ingester import BaseDocumentIngester


class FakeDocumentIngester(BaseDocumentIngester):
def __init__(self, documents: Sequence[Document]):
self.documents = documents

def get_documents(self) -> Sequence[Document]:
return self.documents
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from tqdm import tqdm

from .base_document_extractor import BaseDocumentIngester
from .base_document_ingester import BaseDocumentIngester
from .document import Document


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ..prompts.base_prompt import BasePrompt
from .base_prompt_source import BasePromptSource
from .document_ingesters.base_document_extractor import (
from .document_ingesters.base_document_ingester import (
BaseDocumentIngester,
)
from .document_ingesters.document import Document
Expand All @@ -11,8 +11,7 @@
)


# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/296 tests: test DocumentPromptSource
class DocumentPromptSouce(BasePromptSource):
class DocumentPromptSource(BasePromptSource):
def __init__(
self,
document_ingester: BaseDocumentIngester,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path

from .document_ingesters.document import Document
from .document_ingesters.fake_document_ingester import (
FakeDocumentIngester,
)
from .document_prompt_source import DocumentPromptSource
from .prompt_extractors.qa_prompt_extractor import QAPromptExtractor


def test_document_prompt_source():
source = DocumentPromptSource(
document_ingester=FakeDocumentIngester(
[
Document(
"""Q. What is a test even?
A. Nothing""",
Path("test.md"),
)
]
),
prompt_extractors=[
QAPromptExtractor(
question_prefix="Q.", answer_prefix="A."
)
],
)
prompts = source.get_prompts()
assert len(prompts) == 1

0 comments on commit a6aff81

Please sign in to comment.