diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/base_document_extractor.py b/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/base_document_ingester.py similarity index 100% rename from personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/base_document_extractor.py rename to personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/base_document_ingester.py diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/fake_document_ingester.py b/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/fake_document_ingester.py new file mode 100644 index 0000000..ead2e3f --- /dev/null +++ b/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/fake_document_ingester.py @@ -0,0 +1,12 @@ +from collections.abc import Sequence + +from ..document_ingesters.document import Document +from .base_document_ingester import BaseDocumentIngester + + +class FakeDocumentIngester(BaseDocumentIngester): + def __init__(self, documents: Sequence[Document]): + self.documents = documents + + def get_documents(self) -> Sequence[Document]: + return self.documents diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/markdown_document_ingester.py b/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/markdown_document_ingester.py index 994cc43..1a57aa9 100644 --- a/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/markdown_document_ingester.py +++ b/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/markdown_document_ingester.py @@ -3,7 +3,7 @@ from tqdm import tqdm -from .base_document_extractor import BaseDocumentIngester +from .base_document_ingester import BaseDocumentIngester from .document import Document diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/test_markdown_document_extractor.py b/personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/test_markdown_document_ingester.py similarity index 100% rename from personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/test_markdown_document_extractor.py rename to personal_mnemonic_medium/v2/domain/prompt_source/document_ingesters/test_markdown_document_ingester.py diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py b/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py index 3cd4b7a..6711128 100644 --- a/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py +++ b/personal_mnemonic_medium/v2/domain/prompt_source/document_prompt_source.py @@ -2,7 +2,7 @@ from ..prompts.base_prompt import BasePrompt from .base_prompt_source import BasePromptSource -from .document_ingesters.base_document_extractor import ( +from .document_ingesters.base_document_ingester import ( BaseDocumentIngester, ) from .document_ingesters.document import Document @@ -11,8 +11,7 @@ ) -# TODO: https://github.com/MartinBernstorff/personal-mnemonic-medium/issues/296 tests: test DocumentPromptSource -class DocumentPromptSouce(BasePromptSource): +class DocumentPromptSource(BasePromptSource): def __init__( self, document_ingester: BaseDocumentIngester, diff --git a/personal_mnemonic_medium/v2/domain/prompt_source/test_document_prompt_source.py b/personal_mnemonic_medium/v2/domain/prompt_source/test_document_prompt_source.py new file mode 100644 index 0000000..5680e1c --- /dev/null +++ b/personal_mnemonic_medium/v2/domain/prompt_source/test_document_prompt_source.py @@ -0,0 +1,29 @@ +from pathlib import Path + +from .document_ingesters.document import Document +from .document_ingesters.fake_document_ingester import ( + FakeDocumentIngester, +) +from .document_prompt_source import DocumentPromptSource +from .prompt_extractors.qa_prompt_extractor import QAPromptExtractor + + +def test_document_prompt_source(): + source = DocumentPromptSource( + document_ingester=FakeDocumentIngester( + [ + Document( + """Q. What is a test even? +A. Nothing""", + Path("test.md"), + ) + ] + ), + prompt_extractors=[ + QAPromptExtractor( + question_prefix="Q.", answer_prefix="A." + ) + ], + ) + prompts = source.get_prompts() + assert len(prompts) == 1