[MLC-14] server: step toward Retrieval Augmented Generation (RAG) w/ … #8

Merged 2 commits on Feb 28, 2024

@stockeh stockeh commented Feb 28, 2024

…indexing (load, split, store) retriever.

Overview of RAG: https://python.langchain.com/docs/use_cases/question_answering/

Implementation Details

Risks

  • Similarity score: empirical results favor max_marginal_relevance_search over similarity_search.
  • The Gemma embedding layer isn't optimized for text retrieval. Should we use an alternative text embedding model for this?
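
To make the first risk concrete, here is a minimal sketch of how greedy max marginal relevance (MMR) selection differs from plain top-k similarity. It is not the code in this PR: the helper names, the `lambda_mult` weighting, and the toy 2-D vectors are all assumptions for illustration. The idea is that MMR penalizes a candidate for being similar to chunks already selected, so near-duplicate chunks are less likely to fill the result list.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def similarity_select(query, candidates, k):
    """Plain top-k: rank candidates by similarity to the query only."""
    order = sorted(range(len(candidates)),
                   key=lambda i: cosine(query, candidates[i]), reverse=True)
    return order[:k]


def mmr_select(query, candidates, k, lambda_mult=0.5):
    """Greedy MMR: trade query relevance against redundancy with prior picks.

    lambda_mult is a hypothetical weighting parameter: 1.0 reduces to plain
    relevance ranking, 0.0 only rewards diversity.
    """
    selected = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:
        def score(i):
            relevance = cosine(query, candidates[i])
            # Highest similarity to anything already selected (0 if none yet).
            redundancy = max(
                (cosine(candidates[i], candidates[j]) for j in selected),
                default=0.0,
            )
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy
        best = max(remaining, key=score)
        selected.append(best)
        remaining.remove(best)
    return selected


# Toy 2-D "embeddings": items 0 and 1 are near-duplicates, item 2 is distinct.
query = [1.0, 0.0]
chunks = [[1.0, 0.1], [1.0, 0.12], [0.6, -0.8]]
print(similarity_select(query, chunks, k=2))  # [0, 1]
print(mmr_select(query, chunks, k=2))         # [0, 2]
```

With these toy vectors, plain similarity returns both near-duplicates, while MMR keeps the best match and then prefers the distinct chunk. Whether that trade-off helps on real notes still depends on the quality of the embeddings, which is the second risk above.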

@stockeh stockeh requested a review from ParkerSm1th February 28, 2024 15:02
@stockeh stockeh self-assigned this Feb 28, 2024

stockeh commented Feb 28, 2024

Example usage of retriever:

from server.utils import load

from server.retriever.loader import directory_loader
from server.retriever.splitter import RecursiveCharacterTextSplitter
from server.retriever.vectorstore import Chroma, Embeddings


def main():
    # Load the Gemma model and tokenizer used to embed text.
    model, tokenizer = load('mlx-community/quantized-gemma-7b-it')
    # Load: read the documents under the given directory.
    raw_docs = directory_loader(
        '/Users/stock/Library/Mobile Documents/iCloud~md~obsidian/Documents/main')
    print(len(raw_docs), len(raw_docs[0].page_content))
    # Split: break the documents into overlapping chunks and record each
    # chunk's start index in its metadata.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024, chunk_overlap=32, add_start_index=True
    )
    splits = text_splitter.split_documents(raw_docs)
    print(len(splits), len(splits[0].page_content), splits[0].metadata)
    # Store: embed each chunk with the model and index it in Chroma.
    db = Chroma.from_documents(
        documents=splits, embedding=Embeddings(model.model, tokenizer))
    print('-------------------')
    query = "What is a cascade neural network?"
    # Retrieve: MMR search; swap in similarity_search below to compare.
    # docs = db.similarity_search(query)
    docs = db.max_marginal_relevance_search(query)
    print('>', query)
    for doc in docs:
        print(doc.page_content, doc.metadata, sep='\n')
        print('-------------------')


if __name__ == '__main__':
    main()
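
For readers unfamiliar with the splitter arguments used above, the following is a simplified sketch of what chunk_size, chunk_overlap, and add_start_index control. It is an assumption-laden illustration, not the splitter in this PR: split_with_overlap is a hypothetical helper that cuts fixed-size character windows, whereas a recursive character splitter would normally try to break on separators such as paragraphs and lines first. The dictionary layout and the "start_index" key are also assumptions.

```python
def split_with_overlap(text, chunk_size=1024, chunk_overlap=32):
    """Cut text into fixed-size character windows that overlap their neighbor.

    Hypothetical helper for illustration only; it ignores separators.
    """
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append({
            "page_content": text[start:start + chunk_size],
            # Rough analogue of add_start_index=True: remember the offset.
            "metadata": {"start_index": start},
        })
        if start + chunk_size >= len(text):
            break  # this window already reached the end of the text
    return chunks


sample = "abcdefghij" * 3  # 30 characters
for chunk in split_with_overlap(sample, chunk_size=12, chunk_overlap=4):
    print(chunk["metadata"]["start_index"], chunk["page_content"])
```

Each window starts chunk_size - chunk_overlap characters after the previous one, so the last chunk_overlap characters of one chunk reappear at the start of the next. The overlap gives a sentence that straddles a boundary a chance of appearing whole in at least one chunk.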

Review comment on .vscode/settings.json (outdated, resolved)
@ParkerSm1th ParkerSm1th merged commit 47a0c2d into main Feb 28, 2024
1 check passed
@stockeh stockeh deleted the MLC-14 branch March 5, 2024 05:14