diff --git a/CHANGELOG.md b/CHANGELOG.md index a8b52e86..a6d0ae1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,18 +6,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +### Fixed + +## [0.16.0] + ### Added - Added pretty-printing via `PT.pprint` that does NOT depend on Markdown and splits text to adjust to the width of the output terminal. + It is useful in notebooks to add newlines. - Added support annotations for RAGTools (see `?RAGTools.Experimental.annotate_support` for more information) to highlight which parts of the generated answer come from the provided context versus the model's knowledge base. It's useful for transparency and debugging, especially in the context of AI-generated content. You can experience it if you run the output of `airag` through pretty printing (`PT.pprint`). - Added utility `distance_longest_common_subsequence` to find the normalized distance between two strings (or a vector of strings). Always returns a number between 0-1, where 0 means the strings are identical and 1 means they are completely different. It's useful for comparing the similarity between the context provided to the model and the generated answer. - Added a new documentation section "Extra Tools" to highlight key functionality in various modules, eg, the available text utilities, which were previously hard to discover. - Extended documentation FAQ with tips on tackling rate limits and other common issues with OpenAI API. - Extended documentation with all available prompt templates. See section "Prompt Templates" in the documentation. +- Added new RAG interface underneath `airag` in `PromptingTools.RAGTools.Experimental`. Each step now has a dedicated function and a type that can be customized to achieve arbitrary logic (via defining methods for your own types). `airag` is split into two main steps: `retrieve` and `generate!`. You can use them separately or together. See `?airag` for more information. ### Updated - Renamed `split_by_length` text splitter to `recursive_splitter` to make it easier to discover and understand its purpose. `split_by_length` is still available as a deprecated alias. ### Fixed +- Fixed a bug where `LOCAL_SERVER` default value was not getting picked up. Now, it defaults to `http://localhost:8000` if not set in the preferences, which is the address of the server started by Llama.jl. +- Fixed a bug in multi-line code annotation, which was assigning too optimistic scores to the generated code. Now the score of the chunk is the length-weighted score of the "top" source chunk divided by the full length of score tokens (much more robust and demanding). ## [0.15.0] diff --git a/docs/src/examples/building_RAG.md b/docs/src/examples/building_RAG.md index d08b35aa..56487728 100644 --- a/docs/src/examples/building_RAG.md +++ b/docs/src/examples/building_RAG.md @@ -57,7 +57,7 @@ What does it do? - [OPTIONAL] extracts any potential tags/filters from the question and applies them to filter down the potential candidates (use `extract_metadata=true` in `build_index`, you can also provide some filters explicitly via `tag_filter`) - [OPTIONAL] re-ranks the candidate chunks (define and provide your own `rerank_strategy`, eg Cohere ReRank API) - build a context from the closest chunks (use `chunks_window_margin` to tweak if we include preceding and succeeding chunks as well, see `?build_context` for more details) -- generate an answer from the closest chunks (use `return_details=true` to see under the hood and debug your application) +- generate an answer from the closest chunks (use `return_all=true` to see under the hood and debug your application) You should save the index for later to avoid re-embedding / re-extracting the document chunks! @@ -124,7 +124,7 @@ Let's evaluate this QA item with a "judge model" (often GPT-4 is used as a judge ````julia # Note: that we used the same question, but generated a different context and answer via `airag` -msg, ctx = airag(index; evals[1].question, return_details = true); +ctx = airag(index; evals[1].question, return_all = true); # ctx is a RAGContext object that keeps all intermediate states of the RAG pipeline for easy evaluation judged = aiextract(:RAGJudgeAnswerFromContext; ctx.context, @@ -173,7 +173,7 @@ Let's run each question & answer through our eval loop in async (we do it only f ````julia results = asyncmap(evals[1:10]) do qa_item # Generate an answer -- often you want the model_judge to be the highest quality possible, eg, "GPT-4 Turbo" (alias "gpt4t) - msg, ctx = airag(index; qa_item.question, return_details = true, + msg, ctx = airag(index; qa_item.question, return_all = true, top_k = 3, verbose = false, model_judge = "gpt4t") # Evaluate the response # Note: you can log key parameters for easier analysis later diff --git a/docs/src/extra_tools/rag_tools_intro.md b/docs/src/extra_tools/rag_tools_intro.md index d0225a4d..ce1d70e1 100644 --- a/docs/src/extra_tools/rag_tools_intro.md +++ b/docs/src/extra_tools/rag_tools_intro.md @@ -6,6 +6,10 @@ CurrentModule = PromptingTools.Experimental.RAGTools `RAGTools` is an experimental module that provides a set of utilities for building Retrieval-Augmented Generation (RAG) applications, ie, applications that generate answers by combining knowledge of the underlying AI model with the information from the user's knowledge base. +It is designed to be powerful and flexible, allowing you to build RAG applications with minimal effort. Extend any step of the pipeline with your own custom code (see the [RAG Interface](@ref) section), or use the provided defaults to get started quickly. + +Once the API stabilizes (near term), we hope to carve it out into a separate package. + Import the module as follows: ```julia @@ -22,13 +26,213 @@ const RT = PromptingTools.Experimental.RAGTools The main functions to be aware of are: - `build_index` to build a RAG index from a list of documents (type `ChunkIndex`) - `airag` to generate answers using the RAG model on top of the `index` built above -- `annotate_support` to highlight which parts of the RAG answer are supported by the documents in the index vs which are generated by the model + - `retrieve` to retrieve relevant chunks from the index for a given question + - `generate!` to generate an answer from the retrieved chunks +- `annotate_support` to highlight which parts of the RAG answer are supported by the documents in the index vs which are generated by the model, it is applied automatically if you use pretty printing with `pprint` (eg, `pprint(result)`) - `build_qa_evals` to build a set of question-answer pairs for evaluation of the RAG model from your corpus -See example `examples/building_RAG.jl` for an end-to-end example of how to use these tools. - The hope is to provide a modular and easily extensible set of tools for building RAG applications in Julia. Feel free to open an issue or ask in the `#generative-ai` channel in the JuliaLang Slack if you have a specific need. +## Examples + +Let's build an index, we need to provide a starter list of documents: +```julia +sentences = [ + "Find the most comprehensive guide on Julia programming language for beginners published in 2023.", + "Search for the latest advancements in quantum computing using Julia language.", + "How to implement machine learning algorithms in Julia with examples.", + "Looking for performance comparison between Julia, Python, and R for data analysis.", + "Find Julia language tutorials focusing on high-performance scientific computing.", + "Search for the top Julia language packages for data visualization and their documentation.", + "How to set up a Julia development environment on Windows 10.", + "Discover the best practices for parallel computing in Julia.", + "Search for case studies of large-scale data processing using Julia.", + "Find comprehensive resources for mastering metaprogramming in Julia.", + "Looking for articles on the advantages of using Julia for statistical modeling.", + "How to contribute to the Julia open-source community: A step-by-step guide.", + "Find the comparison of numerical accuracy between Julia and MATLAB.", + "Looking for the latest Julia language updates and their impact on AI research.", + "How to efficiently handle big data with Julia: Techniques and libraries.", + "Discover how Julia integrates with other programming languages and tools.", + "Search for Julia-based frameworks for developing web applications.", + "Find tutorials on creating interactive dashboards with Julia.", + "How to use Julia for natural language processing and text analysis.", + "Discover the role of Julia in the future of computational finance and econometrics." +] +``` + +Let's index these "documents": + +```julia +index = build_index(sentences; chunker_kwargs=(; sources=map(i -> "Doc$i", 1:length(sentences)))) +``` + +This would be equivalent to the following `index = build_index(SimpleIndexer(), sentences)` which dispatches to the default implementation of each step via the `SimpleIndexer` struct. We provide these default implementations for the main functions as an optional argument - no need to provide them if you're running the default pipeline. + +Notice that we have provided a `chunker_kwargs` argument to the `build_index` function. These will be kwargs passed to `chunker` step. + +Now let's generate an answer to a question. + +1. Run end-to-end RAG (retrieve + generate!), return `AIMessage` +```julia +question = "What are the best practices for parallel computing in Julia?" + +msg = airag(index; question) # short for airag(RAGConfig(), index; question) +## Output: +## [ Info: Done with RAG. Total cost: \$0.0 +## AIMessage("Some best practices for parallel computing in Julia include us... +``` + +2. Explore what's happening under the hood by changing the return type - `RAGResult` contains all intermediate steps. +```julia +result = airag(index; question, return_all=true) +## RAGResult +## question: String "What are the best practices for parallel computing in Julia?" +## rephrased_questions: Array{String}((1,)) +## answer: SubString{String} +## final_answer: SubString{String} +## context: Array{String}((5,)) +## sources: Array{String}((5,)) +## emb_candidates: CandidateChunks{Int64, Float32} +## tag_candidates: CandidateChunks{Int64, Float32} +## filtered_candidates: CandidateChunks{Int64, Float32} +## reranked_candidates: CandidateChunks{Int64, Float32} +## conversations: Dict{Symbol, Vector{<:PromptingTools.AbstractMessage}} +``` + +You can still get the message from the result, see `result.conversations[:final_answer]` (the dictionary keys correspond to the function names of those steps). + + +3. If you need to customize it, break the pipeline into its sub-steps: retrieve and generate - RAGResult serves as the intermediate result. +```julia +# Retrieve which chunks are relevant to the question +result = retrieve(index, question) +# Generate an answer +result = generate!(index, result) +``` + +You can leverage a pretty-printing system with `pprint` where we automatically annotate the support of the answer by the chunks we provided to the model. +It is configurable and you can select only some of its functions (eg, scores, sources). + +```julia +pprint(result) +``` + +You'll see the following in REPL but with COLOR highlighting in the terminal. + +```plaintext +-------------------- +QUESTION(s) +-------------------- +- What are the best practices for parallel computing in Julia? + +-------------------- +ANSWER +-------------------- +Some of the best practices for parallel computing in Julia include:[1,0.7] +- Using [3,0.4]`@threads` for simple parallelism[1,0.34] +- Utilizing `Distributed` module for more complex parallel tasks[1,0.19] +- Avoiding excessive memory allocation +- Considering task granularity for efficient workload distribution + +-------------------- +SOURCES +-------------------- +1. Doc8 +2. Doc15 +3. Doc5 +4. Doc2 +5. Doc9 +``` + +**How to read the output** +- Color legend: + - No color: High match with the context, can be trusted more + - Blue: Partial match against some words in the context, investigate + - Magenta (Red): No match with the context, fully generated by the model +- Square brackets: The best matching context ID + Match score of the chunk (eg, `[3,0.4]` means the highest support for the sentence is from the context chunk number 3 with a 40% match). + +Want more? + +See `examples/building_RAG.jl` for one more example. + +## RAG Interface + +### System Overview + +This system is designed for information retrieval and response generation, structured in three main phases: +- Preparation, when you create an instance of `AbstractIndex` +- Retrieval, when you surface the top most relevant chunks/items in the `index` and return `AbstractRAGResult`, which contains the references to the chunks (`AbstractCandidateChunks`) +- Generation, when you generate an answer based on the context built from the retrieved chunks, return either `AIMessage` or `AbstractRAGResult` + +The system is designed to be hackable and extensible at almost every entry point. +If you want to customize the behavior of any step, you can do so by defining a new type and defining a new method for the step you're changing, eg, +```julia +struct MyReranker <: AbstractReranker end +RT.rerank(::MyReranker, index, candidates) = ... +``` +And then you'd ask for the `retrive` step to use your custom `MyReranker`, eg, `retrieve(....; reranker = MyReranker())` (or customize the main dispatching `AbstractRetriever` struct). + +The overarching principles are: +- Always dispatch / customize the behavior by defining a new `Struct` and the corresponding method for the existing functions (eg, `rerank` function for the re-ranking step). +- Custom types are provided as the first argument (the high-level functions will work without them as we provide some defaults). +- Custom types do NOT have any internal fields or DATA (with the exception of managing sub-steps of the pipeline like `AbstractRetriever` or `RAGConfig`). +- Additional data should be passed around as keyword arguments (eg, `chunker_kwargs` in `build_index` to pass data to the chunking step). The intention was to have some clearly documented default values in the docstrings of each step + to have the various options all in one place. + +### RAG Diagram + +The main functions are: + +`build_index`: +- signature: `(indexer::AbstractIndexBuilder, files_or_docs::Vector{<:AbstractString}) -> AbstractChunkIndex` +- flow: `get_chunks` -> `get_embeddings` -> `get_tags` -> `build_tags` +- dispatch types: `AbstractIndexBuilder`, `AbstractChunker`, `AbstractEmbedder`, `AbstractTagger` + +`airag`: +- signature: `(cfg::AbstractRAGConfig, index::AbstractChunkIndex; question::AbstractString)` -> `AIMessage` or `AbstractRAGResult` +- flow: `retrieve` -> `generate!` +- dispatch types: `AbstractRAGConfig`, `AbstractRetriever`, `AbstractGenerator` + +`retrieve`: +- signature: `(retriever::AbstractRetriever, index::AbstractChunkIndex, question::AbstractString) -> AbstractRAGResult` +- flow: `rephrase` -> `get_embeddings` -> `find_closest` -> `get_tags` -> `find_tags` -> `rerank` +- dispatch types: `AbstractRAGConfig`, `AbstractRephraser`, `AbstractEmbedder`, `AbstractSimilarityFinder`, `AbstractTagger`, `AbstractTagFilter`, `AbstractReranker` + +`generate!`: +- signature: `(generator::AbstractGenerator, index::AbstractChunkIndex, result::AbstractRAGResult)` -> `AIMessage` or `AbstractRAGResult` +- flow: `build_context!` -> `answer!` -> `refine!` -> `postprocess!` +- dispatch types: `AbstractGenerator`, `AbstractContextBuilder`, `AbstractAnswerer`, `AbstractRefiner`, `AbstractPostprocessor` + +To discover the currently available implementations, use `subtypes` function, eg, `subtypes(AbstractReranker)`. + +### Deepdive + +**Preparation Phase:** +- Begins with `build_index`, which creates a user-defined index type from an abstract chunk index using specified dels and function strategies. +- `get_chunks` then divides the indexed data into manageable pieces based on a chunking strategy. +- `get_embeddings` generates embeddings for each chunk using an embedding strategy to facilitate similarity arches. +- Finally, `get_tags` extracts relevant metadata from each chunk, enabling tag-based filtering (hybrid search index). If there are `tags` available, `build_tags` is called to build the corresponding sparse matrix for filtering with tags. + +**Retrieval Phase:** +- The `retrieve` step is intended to find the most relevant chunks in the `index`. +- `rephrase` is called first, if we want to rephrase the query (methods like `HyDE` can improve retrieval quite a bit)! +- `get_embeddings` generates embeddings for the original + rephrased query +- `find_closest` looks up the most relevant candidates (`CandidateChunks`) using a similarity search strategy. +- `get_tags` extracts the potential tags (can be provided as part of the `airag` call, eg, when we want to use only some small part of the indexed chunks) +- `find_tags` filters the candidates to strictly match _at least one_ of the tags (if provided) +- `rerank` is called to rerank the candidates based on the reranking strategy (ie, to improve the ordering of the chunks in context). + +**Generation Phase:** +- The `generate` step is intended to generate a response based on the retrieved chunks, provided via `AbstractRAGResult` (eg, `RAGResult`). +- `build_context!` constructs the context for response generation based on a context strategy and applies the necessary formatting +- `answer!` generates the response based on the context and the query +- `refine!` is called to refine the response (optional, defaults to passthrough) +- `postprocessing!` is available for any final touches to the response or to potentially save or format the results (eg, automatically save to the disk) + +Note that all generation steps are mutating the `RAGResult` object. + +See more details and corresponding functions and types in `src/Experimental/RAGTools/rag_interface.jl`. + ## References ```@docs; canonical=false diff --git a/examples/building_RAG.jl b/examples/building_RAG.jl index f2b8582e..2b4579c4 100644 --- a/examples/building_RAG.jl +++ b/examples/building_RAG.jl @@ -24,8 +24,8 @@ files = [ joinpath("examples", "data", "database_style_joins.txt"), joinpath("examples", "data", "what_is_dataframes.txt") ] -## Build an index of chunks, embed them, and create a lookup index of metadata/tags for each chunk -index = build_index(files; extract_metadata = false) +## Build an index of chunks and embed them +index = build_index(files) # Let's ask a question ## Embeds the question, finds the closest chunks in the index, and generates an answer from the closest chunks @@ -37,12 +37,17 @@ answer = airag(index; question = "I like dplyr, what is the equivalent in Julia? # - `build_index` will chunk the documents into smaller pieces, embed them into numbers (to be able to judge the similarity of chunks) and, optionally, create a lookup index of metadata/tags for each chunk) # - `index` is the result of this step and it holds your chunks, embeddings, and other metadata! Just show it :) # - `airag` will -# - embed your question -# - find the closest chunks in the index (use parameters `top_k` and `minimum_similarity` to tweak the "relevant" chunks) -# - [OPTIONAL] extracts any potential tags/filters from the question and applies them to filter down the potential candidates (use `extract_metadata=true` in `build_index`, you can also provide some filters explicitly via `tag_filter`) -# - [OPTIONAL] re-ranks the candidate chunks (define and provide your own `rerank_strategy`, eg Cohere ReRank API) -# - build a context from the closest chunks (use `chunks_window_margin` to tweak if we include preceding and succeeding chunks as well, see `?build_context` for more details) -# - generate an answer from the closest chunks (use `return_details=true` to see under the hood and debug your application) +# - retrieve the best chunks from your index (based on the similarity of the question to the chunks) +# - rephrase the question into a more "searchable" form +# - embed your question +# - find the closest chunks in the index (use parameters `top_k` and `minimum_similarity` to tweak the "relevant" chunks) +# - [OPTIONAL] extract any potential tags/filters from the question and applies them to filter down the potential candidates (use `extract_metadata=true` in `build_index`, you can also provide some filters explicitly via `tag_filter`) +# - [OPTIONAL] re-rank the candidate chunks (define and provide your own `rerank_strategy`, eg Cohere ReRank API) +# - generate an answer from the closest chunks (use `return_all=true` to see under the hood and debug your application) +# - build a context from the closest chunks (use `chunks_window_margin` to tweak if we include preceding and succeeding chunks as well, see `?build_context` for more details) +# - answer the question with LLM +# - [OPTIONAL] refine the answer (with the same or new context) +# # You should save the index for later to avoid re-embedding / re-extracting the document chunks! serialize("examples/index.jls", index) @@ -80,13 +85,13 @@ evals[1] # Let's evaluate this QA item with a "judge model" (often GPT-4 is used as a judge). ## Note: that we used the same question, but generated a different context and answer via `airag` -msg, ctx = airag(index; evals[1].question, return_details = true); +result = airag(index; evals[1].question, return_all = true); ## ctx is a RAGContext object that keeps all intermediate states of the RAG pipeline for easy evaluation judged = aiextract(:RAGJudgeAnswerFromContext; - ctx.context, - ctx.question, - ctx.answer, + result.context, + result.question, + result.final_answer, return_type = RT.JudgeAllScores) judged.content ## Dict{Symbol, Any} with 7 entries: @@ -110,11 +115,11 @@ x = run_qa_evals(evals[10], ctx; results = asyncmap(evals[1:10]) do qa_item ## Generate an answer -- often you want the model_judge to be the highest quality possible, eg, "GPT-4 Turbo" (alias "gpt4t) - msg, ctx = airag(index; qa_item.question, return_details = true, + result = airag(index; qa_item.question, return_all = true, top_k = 3, verbose = false, model_judge = "gpt4t") ## Evaluate the response ## Note: you can log key parameters for easier analysis later - run_qa_evals(qa_item, ctx; parameters_dict = Dict(:top_k => 3), verbose = false) + run_qa_evals(qa_item, result; parameters_dict = Dict(:top_k => 3), verbose = false) end ## Note that the "failed" evals can show as "nothing", so make sure to handle them. results = filter(x -> !isnothing(x.answer_score), results); @@ -136,8 +141,8 @@ first(df, 5) # # What would we do next? # - Review your evaluation golden data set and keep only the good items -# - Play with the chunk sizes (max_length in build_index) and see how it affects the quality -# - Explore using metadata/key filters (`extract_metadata=true` in build_index) +# - Play with the chunk sizes (max_length in `build_index.chunker`) and see how it affects the quality +# - Explore using metadata/key filters (`tagger` step in `build_index`) # - Add filtering for semantic similarity (embedding distance) to make sure we don't pick up irrelevant chunks in the context # - Use multiple indices or a hybrid index (add a simple BM25 lookup from TextAnalysis.jl) # - Data processing is the most important step - properly parsed and split text could make wonders diff --git a/ext/RAGToolsExperimentalExt.jl b/ext/RAGToolsExperimentalExt.jl index 9b095a44..dc0492e9 100644 --- a/ext/RAGToolsExperimentalExt.jl +++ b/ext/RAGToolsExperimentalExt.jl @@ -5,14 +5,23 @@ using LinearAlgebra: normalize const PT = PromptingTools using PromptingTools.Experimental.RAGTools +const RT = PromptingTools.Experimental.RAGTools # forward to LinearAlgebra.normalize -PromptingTools.Experimental.RAGTools._normalize(arr::AbstractArray) = normalize(arr) +RT._normalize(arr::AbstractArray) = normalize(arr) -# "Builds a sparse matrix of tags and a vocabulary from the given vector of chunk metadata. Requires SparseArrays.jl to be loaded." -function PromptingTools.Experimental.RAGTools.build_tags(chunk_metadata::Vector{ - Vector{String}, -}) +""" + RT.build_tags( + tagger::RT.AbstractTagger, chunk_metadata::AbstractVector{ + <:AbstractVector{String}, + }) + +Builds a sparse matrix of tags and a vocabulary from the given vector of chunk metadata. +""" +function RT.build_tags( + tagger::RT.AbstractTagger, chunk_metadata::AbstractVector{ + <:AbstractVector{String}, + }) tags_vocab_ = vcat(chunk_metadata...) |> unique |> sort tags_vocab_index = Dict{String, Int}(t => i for (i, t) in enumerate(tags_vocab_)) Is, Js = Int[], Int[] @@ -31,4 +40,4 @@ function PromptingTools.Experimental.RAGTools.build_tags(chunk_metadata::Vector{ return tags_, tags_vocab_ end -end +end # end of module diff --git a/src/Experimental/RAGTools/RAGTools.jl b/src/Experimental/RAGTools/RAGTools.jl index 094eebf1..d259dbfd 100644 --- a/src/Experimental/RAGTools/RAGTools.jl +++ b/src/Experimental/RAGTools/RAGTools.jl @@ -10,12 +10,15 @@ This module is experimental and may change at any time. It is intended to be mov module RAGTools using PromptingTools -using PromptingTools: pprint +using PromptingTools: pprint, AbstractMessage using HTTP, JSON3 using AbstractTrees using AbstractTrees: PreOrderDFS const PT = PromptingTools +# reexport +export pprint + ## export trigrams, trigrams_hashed, text_to_trigrams, text_to_trigrams_hashed ## export STOPWORDS, tokenize, split_into_code_and_sentences include("utils.jl") @@ -25,19 +28,22 @@ include("api_services.jl") include("rag_interface.jl") -export ChunkIndex, CandidateChunks # MultiIndex +export ChunkIndex, CandidateChunks, RAGResult +# export MultiIndex # not ready yet include("types.jl") -export build_index, build_tags +export build_index, get_chunks, get_embeddings, get_tags include("preparation.jl") -export find_closest, find_tags, rerank +export retrieve, SimpleRetriever, AdvancedRetriever +export find_closest, find_tags, rerank, rephrase include("retrieval.jl") -export airag, build_context +export airag, build_context!, generate!, refine!, answer!, postprocess! +export SimpleGenerator, AdvancedGenerator, RAGConfig include("generation.jl") -export annotate_support +export annotate_support, TrigramAnnotater include("annotation.jl") export build_qa_evals, run_qa_evals diff --git a/src/Experimental/RAGTools/annotation.jl b/src/Experimental/RAGTools/annotation.jl index 966cc739..2547ac15 100644 --- a/src/Experimental/RAGTools/annotation.jl +++ b/src/Experimental/RAGTools/annotation.jl @@ -323,21 +323,25 @@ function add_node_metadata!(annotater::TrigramAnnotater, i = 1 source_scores = Dict{Int, Float64}() source_lengths = Dict{Int, Int}() + non_source_length = 0 previous_group_id = children[1].group_id while i <= length(children) child = children[i] # Check if group_id has changed or it's the last child to record source if (child.group_id != previous_group_id) && !isempty(source_scores) # Add a metadata node for the previous group - src, score_sum = maximum(source_scores) - score = score_sum / source_lengths[src] # average score, length weighted + score_sum, src = findmax(source_scores) + # average score weighted by the length of ALL text + # the goal is to show the match of top source across all text, not just the tokens that matched - it could be misleading + # the goal is "how confident are we that this source is the best match for the whole text" + score = score_sum / (sum(values(source_lengths)) + non_source_length) metadata_content = string("[", add_sources ? src : "", add_sources ? "," : "", add_scores ? round(score, digits = 2) : "", "]") ## Check if there is any content, then add it - if length(metadata_content) > 2 + if length(metadata_content) > 3 src_node = AnnotatedNode(; parent = root, group_id = previous_group_id, content = metadata_content) insert!(children, i, src_node) @@ -356,15 +360,19 @@ function add_node_metadata!(annotater::TrigramAnnotater, len = length(child.content) source_scores[src] = get(source_scores, src, 0) + child.score * len source_lengths[src] = get(source_lengths, src, 0) + len + elseif !isnothing(child.score) + ## track the low match tokens without any source allocated + non_source_length += length(child.content) end + # Next round i += 1 end ## Run for the last item if !isempty(source_scores) # Add a metadata node for the previous group - src, score_sum = maximum(source_scores) - score = score_sum / source_lengths[src] # average score, length weighted + score_sum, src = findmax(source_scores) + score = score_sum / (sum(values(source_lengths)) + non_source_length) metadata_content = string("[", add_sources ? src : "", add_sources ? "," : "", @@ -441,6 +449,7 @@ function annotate_support(annotater::TrigramAnnotater, answer::AbstractString, min_source_score::Float64 = 0.25, add_sources::Bool = true, add_scores::Bool = true, kwargs...) + @assert !isempty(context) "Context cannot be empty" ## use hashed trigrams by default (more efficient for larger sequences) if hashed trigram_func = trigrams_hashed @@ -482,13 +491,34 @@ function annotate_support(annotater::TrigramAnnotater, answer::AbstractString, end # Dispatch for RAGResult +""" + annotate_support( + annotater::TrigramAnnotater, result::AbstractRAGResult; min_score::Float64 = 0.5, + skip_trigrams::Bool = true, hashed::Bool = true, + min_source_score::Float64 = 0.25, + add_sources::Bool = true, + add_scores::Bool = true, kwargs...) + +Dispatch for `annotate_support` for `AbstractRAGResult` type. It extracts the `final_answer` and `context` from the `result` and calls `annotate_support` with them. + +See `annotate_support` for more details. + +# Example +```julia +res = RAGResult(; question = "", final_answer = "This is a test.", + context = ["Test context.", "Completely different"]) +annotated_root = annotate_support(annotater, res) +PT.pprint(annotated_root) +``` +""" function annotate_support( annotater::TrigramAnnotater, result::AbstractRAGResult; min_score::Float64 = 0.5, - skip_trigrams::Bool = true, hashed::Bool = false, + skip_trigrams::Bool = true, hashed::Bool = true, min_source_score::Float64 = 0.25, add_sources::Bool = true, add_scores::Bool = true, kwargs...) + final_answer = isnothing(result.final_answer) ? result.answer : result.final_answer return annotate_support( - annotater, result.refined_answer, result.context; min_score, skip_trigrams, + annotater, final_answer, result.context; min_score, skip_trigrams, hashed, result.sources, min_source_score, add_sources, add_scores, kwargs...) end \ No newline at end of file diff --git a/src/Experimental/RAGTools/evaluation.jl b/src/Experimental/RAGTools/evaluation.jl index 1ef396fa..fb69c87a 100644 --- a/src/Experimental/RAGTools/evaluation.jl +++ b/src/Experimental/RAGTools/evaluation.jl @@ -143,15 +143,15 @@ function score_retrieval_rank(orig_context::AbstractString, end """ - run_qa_evals(qa_item::QAEvalItem, ctx::RAGDetails; verbose::Bool = true, + run_qa_evals(qa_item::QAEvalItem, ctx::RAGResult; verbose::Bool = true, parameters_dict::Dict{Symbol, <:Any}, judge_template::Symbol = :RAGJudgeAnswerFromContext, model_judge::AbstractString, api_kwargs::NamedTuple = NamedTuple()) -> QAEvalResult -Evaluates a single `QAEvalItem` using RAG details (`RAGDetails`) and returns a `QAEvalResult` structure. This function assesses the relevance and accuracy of the answers generated in a QA evaluation context. +Evaluates a single `QAEvalItem` using RAG details (`RAGResult`) and returns a `QAEvalResult` structure. This function assesses the relevance and accuracy of the answers generated in a QA evaluation context. # Arguments - `qa_item::QAEvalItem`: The QA evaluation item containing the question and its answer. -- `ctx::RAGDetails`: The context used for generating the QA pair, including the original context and the answers. +- `ctx::RAGResult`: The RAG result used for generating the QA pair, including the original context and the answers. Comes from `airag(...; return_context=true)` - `verbose::Bool`: If `true`, enables verbose logging. Defaults to `true`. - `parameters_dict::Dict{Symbol, Any}`: Track any parameters used for later evaluations. Keys must be Symbols. @@ -173,13 +173,13 @@ Evaluates a single `QAEvalItem` using RAG details (`RAGDetails`) and returns a ` Evaluating a QA pair using a specific context and model: ```julia qa_item = QAEvalItem(question="What is the capital of France?", answer="Paris", context="France is a country in Europe.") -ctx = RAGDetails(source="Wikipedia", context="France is a country in Europe.", answer="Paris") +ctx = RAGResult(source="Wikipedia", context="France is a country in Europe.", answer="Paris") parameters_dict = Dict("param1" => "value1", "param2" => "value2") eval_result = run_qa_evals(qa_item, ctx, parameters_dict=parameters_dict, model_judge="MyAIJudgeModel") ``` """ -function run_qa_evals(qa_item::QAEvalItem, ctx::RAGDetails; +function run_qa_evals(qa_item::QAEvalItem, ctx::RAGResult; verbose::Bool = true, parameters_dict::Dict{Symbol, <:Any} = Dict{Symbol, Any}(), judge_template::Symbol = :RAGJudgeAnswerFromContextShort, model_judge::AbstractString = PT.MODEL_CHAT, @@ -187,13 +187,13 @@ function run_qa_evals(qa_item::QAEvalItem, ctx::RAGDetails; retrieval_score = score_retrieval_hit(qa_item.context, ctx.context) retrieval_rank = score_retrieval_rank(qa_item.context, ctx.context) - # Note we could evaluate if RAGDetails and QAEvalItem are at least using the same sources etc. + # Note we could evaluate if RAGResult and QAEvalItem are at least using the same sources etc. answer_score = try msg = aiextract(judge_template; model = model_judge, verbose, ctx.context, ctx.question, - ctx.answer, + answer = ctx.final_answer, return_type = JudgeAllScores, api_kwargs) final_rating = if msg.content isa AbstractDict && haskey(msg.content, :final_rating) # if return type parsing failed @@ -211,7 +211,7 @@ function run_qa_evals(qa_item::QAEvalItem, ctx::RAGDetails; qa_item.source, qa_item.context, qa_item.question, - ctx.answer, + answer = ctx.final_answer, retrieval_score, retrieval_rank, answer_score, @@ -266,13 +266,13 @@ function run_qa_evals(index::AbstractChunkIndex, qa_items::AbstractVector{<:QAEv # Run evaluations in parallel results = asyncmap(qa_items) do qa_item # Generate an answer -- often you want the model_judge to be the highest quality possible, eg, "GPT-4 Turbo" (alias "gpt4t) - msg, ctx = airag(index; qa_item.question, return_details = true, + ragresult = airag(index; qa_item.question, return_all = true, verbose, api_kwargs, airag_kwargs...) # Evaluate the response # Note: you can log key parameters for easier analysis later run_qa_evals(qa_item, - ctx; + ragresult; parameters_dict, verbose, api_kwargs, diff --git a/src/Experimental/RAGTools/generation.jl b/src/Experimental/RAGTools/generation.jl index afd62623..25e65d5c 100644 --- a/src/Experimental/RAGTools/generation.jl +++ b/src/Experimental/RAGTools/generation.jl @@ -1,14 +1,28 @@ -# stub to be replaced within the package extension -function _normalize end """ - build_context(index::AbstractChunkIndex, reranked_candidates::CandidateChunks; chunks_window_margin::Tuple{Int, Int}) -> Vector{String} + ContextEnumerator <: AbstractContextBuilder -Build context strings for each position in `reranked_candidates` considering a window margin around each position. +Default method for `build_context!` method. It simply enumerates the context snippets around each position in `candidates`. When possibly, it will add surrounding chunks (from the same source). +""" +struct ContextEnumerator <: AbstractContextBuilder end + +""" + build_context(contexter::ContextEnumerator, + index::AbstractChunkIndex, candidates::CandidateChunks; + verbose::Bool = true, + chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...) + + build_context!(contexter::ContextEnumerator, + index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...) + +Build context strings for each position in `candidates` considering a window margin around each position. +If mutating version is used (`build_context!`), it will use `result.reranked_candidates` to update the `result.context` field. # Arguments -- `reranked_candidates::CandidateChunks`: Candidate chunks which contain positions to extract context from. +- `contexter::ContextEnumerator`: The method to use for building the context. Enumerates the snippets. - `index::ChunkIndex`: The index containing chunks and sources. +- `candidates::CandidateChunks`: Candidate chunks which contain positions to extract context from. +- `verbose::Bool`: If `true`, enables verbose logging. - `chunks_window_margin::Tuple{Int, Int}`: A tuple indicating the margin (before, after) around each position to include in the context. Defaults to `(1,1)`, which means 1 preceding and 1 suceeding chunk will be included. With `(0,0)`, only the matching chunks will be included. @@ -19,16 +33,21 @@ Build context strings for each position in `reranked_candidates` considering a w ```julia index = ChunkIndex(...) # Assuming a proper index is defined candidates = CandidateChunks(index.id, [2, 4], [0.1, 0.2]) -context = build_context(index, candidates; chunks_window_margin=(0, 1)) # include only one following chunk for each matching chunk +context = build_context(ContextEnumerator(), index, candidates; chunks_window_margin=(0, 1)) # include only one following chunk for each matching chunk ``` """ -function build_context(index::AbstractChunkIndex, reranked_candidates::CandidateChunks; - chunks_window_margin::Tuple{Int, Int} = (1, 1)) +function build_context(contexter::ContextEnumerator, + index::AbstractChunkIndex, candidates::CandidateChunks; + verbose::Bool = true, + chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...) + ## Checks @assert chunks_window_margin[1] >= 0&&chunks_window_margin[2] >= 0 "Both `chunks_window_margin` values must be non-negative" + context = String[] - for (i, position) in enumerate(reranked_candidates.positions) + for (i, position) in enumerate(candidates.positions) chunks_ = chunks(index)[max(1, position - chunks_window_margin[1]):min(end, position + chunks_window_margin[2])] + ## Check if surrounding chunks are from the same source is_same_source = sources(index)[max(1, position - chunks_window_margin[1]):min(end, position + chunks_window_margin[2])] .== sources(index)[position] push!(context, "$(i). $(join(chunks_[is_same_source], "\n"))") @@ -36,58 +55,398 @@ function build_context(index::AbstractChunkIndex, reranked_candidates::Candidate return context end +function build_context!(contexter::AbstractContextBuilder, + index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...) + throw(ArgumentError("Contexter $(typeof(contexter)) not implemented")) +end + +# Mutating version that dispatches on the result to the underlying implementation +function build_context!(contexter::ContextEnumerator, + index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...) + result.context = build_context(contexter, index, result.reranked_candidates; kwargs...) + return result +end + +## First step: Answerer + """ - airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext; - question::AbstractString, - top_k::Int = 100, top_n::Int = 5, minimum_similarity::AbstractFloat = -1.0, - tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto, - rerank_strategy::RerankingStrategy = Passthrough(), - model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT, - model_metadata::String = PT.MODEL_CHAT, - metadata_template::Symbol = :RAGExtractMetadataShort, - chunks_window_margin::Tuple{Int, Int} = (1, 1), - return_details::Bool = false, verbose::Bool = true, - rerank_kwargs::NamedTuple = NamedTuple(), + SimpleAnswerer <: AbstractAnswerer + +Default method for `answer!` method. Generates an answer using the `aigenerate` function with the provided context and question. +""" +struct SimpleAnswerer <: AbstractAnswerer end + +function answer!( + answerer::AbstractAnswerer, index::AbstractChunkIndex, result::AbstractRAGResult; + kwargs...) + throw(ArgumentError("Answerer $(typeof(answerer)) not implemented")) +end + +""" + answer!( + answerer::SimpleAnswerer, index::AbstractChunkIndex, result::AbstractRAGResult; + model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true, + template::Symbol = :RAGAnswerFromContext, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + +Generates an answer using the `aigenerate` function with the provided `result.context` and `result.question`. + +# Returns +- Mutated `result` with `result.answer` and the full conversation saved in `result.conversations[:answer]` + +# Arguments +- `answerer::SimpleAnswerer`: The method to use for generating the answer. Uses `aigenerate`. +- `index::AbstractChunkIndex`: The index containing chunks and sources. +- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for. +- `model::AbstractString`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`. +- `verbose::Bool`: If `true`, enables verbose logging. +- `template::Symbol`: The template to use for the `aigenerate` function. Defaults to `:RAGAnswerFromContext`. +- `cost_tracker`: An atomic counter to track the cost of the operation. + +""" +function answer!( + answerer::SimpleAnswerer, index::AbstractChunkIndex, result::AbstractRAGResult; + model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true, + template::Symbol = :RAGAnswerFromContext, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + ## Checks + placeholders = only(aitemplates(template)).variables # only one template should be found + @assert (:question in placeholders)&&(:context in placeholders) "Provided RAG Template $(template) is not suitable. It must have placeholders: `question` and `context`." + ## + (; context, question) = result + conv = aigenerate(template; question, + context = join(context, "\n\n"), model, verbose = false, + return_all = true, + kwargs...) + msg = conv[end] + result.answer = strip(msg.content) + result.conversations[:answer] = conv + ## Increment the cost tracker + Threads.atomic_add!(cost_tracker, msg.cost) + verbose && + @info "Done generating the answer. Cost: \$$(round(msg.cost,digits=3))" + + return result +end + +## Refine +""" + NoRefiner <: AbstractRefiner + +Default method for `refine!` method. A passthrough option that returns the `result.answer` without any changes. +""" +struct NoRefiner <: AbstractRefiner end + +""" + SimpleRefiner <: AbstractRefiner + +Refines the answer using the same context previously provided via the provided prompt template. +""" +struct SimpleRefiner <: AbstractRefiner end + +function refine!( + refiner::AbstractRefiner, index::AbstractChunkIndex, result::AbstractRAGResult; + kwargs...) + throw(ArgumentError("Refiner $(typeof(refiner)) not implemented")) +end + +""" + refine!( + refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult; + kwargs...) + +Simple no-op function for `refine`. It simply copies the `result.answer` and `result.conversations[:answer]` without any changes. +""" +function refine!( + refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult; + kwargs...) + result.final_answer = result.answer + if haskey(result.conversations, :answer) + result.conversations[:final_answer] = result.conversations[:answer] + end + return result +end + +""" + refine!( + refiner::SimpleRefiner, index::AbstractChunkIndex, result::AbstractRAGResult; + verbose::Bool = true, + model::AbstractString = PT.MODEL_CHAT, + template::Symbol = :RAGAnswerRefiner, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + +Give model a chance to refine the answer (using the same or different context than previously provided). + +This method uses the same context as the original answer, however, it can be modified to do additional retrieval and use a different context. + +# Returns +- Mutated `result` with `result.final_answer` and the full conversation saved in `result.conversations[:final_answer]` + +# Arguments +- `refiner::SimpleRefiner`: The method to use for refining the answer. Uses `aigenerate`. +- `index::AbstractChunkIndex`: The index containing chunks and sources. +- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for. +- `model::AbstractString`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`. +- `verbose::Bool`: If `true`, enables verbose logging. +- `template::Symbol`: The template to use for the `aigenerate` function. Defaults to `:RAGAnswerRefiner`. +- `cost_tracker`: An atomic counter to track the cost of the operation. +""" +function refine!( + refiner::SimpleRefiner, index::AbstractChunkIndex, result::AbstractRAGResult; + verbose::Bool = true, + model::AbstractString = PT.MODEL_CHAT, + template::Symbol = :RAGAnswerRefiner, + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + ## Checks + placeholders = only(aitemplates(template)).variables # only one template should be found + @assert (:query in placeholders)&&(:answer in placeholders) && + (:context in placeholders) "Provided RAG Template $(template) is not suitable. It must have placeholders: `query`, `answer` and `context`." + ## + (; answer, question, context) = result + conv = aigenerate(template; query = question, + context = join(context, "\n\n"), answer, model, verbose = false, + return_all = true, + kwargs...) + msg = conv[end] + result.final_answer = strip(msg.content) + result.conversations[:final_answer] = conv + + ## Increment the cost + Threads.atomic_add!(cost_tracker, msg.cost) + verbose && + @info "Done refining the answer. Cost: \$$(round(msg.cost,digits=3))" + + return result +end + +""" + NoPostprocessor <: AbstractPostprocessor + +Default method for `postprocess!` method. A passthrough option that returns the `result` without any changes. + +Overload this method to add custom postprocessing steps, eg, logging, saving conversations to disk, etc. +""" +struct NoPostprocessor <: AbstractPostprocessor end + +function postprocess!(postprocessor::AbstractPostprocessor, index::AbstractChunkIndex, + result::AbstractRAGResult; kwargs...) + throw(ArgumentError("Postprocessor $(typeof(postprocessor)) not implemented")) +end + +function postprocess!( + ::NoPostprocessor, index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...) + return result +end + +### Overall types for `generate` +""" + SimpleGenerator <: AbstractGenerator + +Default implementation for `generate`. It simply enumerates context snippets and runs `aigenerate` (no refinement). + +It uses `ContextEnumerator`, `SimpleAnswerer`, `NoRefiner`, and `NoPostprocessor` as default `contexter`, `answerer`, `refiner`, and `postprocessor`. +""" +@kwdef mutable struct SimpleGenerator <: AbstractGenerator + contexter::AbstractContextBuilder = ContextEnumerator() + answerer::AbstractAnswerer = SimpleAnswerer() + refiner::AbstractRefiner = NoRefiner() + postprocessor::AbstractPostprocessor = NoPostprocessor() +end + +""" + AdvancedGenerator <: AbstractGenerator + +Default implementation for `generate!`. It simply enumerates context snippets and runs `aigenerate` (no refinement). + +It uses `ContextEnumerator`, `SimpleAnswerer`, `SimpleRefiner`, and `NoPostprocessor` as default `contexter`, `answerer`, `refiner`, and `postprocessor`. +""" +@kwdef mutable struct AdvancedGenerator <: AbstractGenerator + contexter::AbstractContextBuilder = ContextEnumerator() + answerer::AbstractAnswerer = SimpleAnswerer() + refiner::AbstractRefiner = SimpleRefiner() + postprocessor::AbstractPostprocessor = NoPostprocessor() +end + +""" + generate!( + generator::AbstractGenerator, index::AbstractChunkIndex, result::AbstractRAGResult; + verbose::Integer = 1, api_kwargs::NamedTuple = NamedTuple(), - aiembed_kwargs::NamedTuple = NamedTuple(), - aigenerate_kwargs::NamedTuple = NamedTuple(), - aiextract_kwargs::NamedTuple = NamedTuple(), + contexter::AbstractContextBuilder = generator.contexter, + contexter_kwargs::NamedTuple = NamedTuple(), + answerer::AbstractAnswerer = generator.answerer, + answerer_kwargs::NamedTuple = NamedTuple(), + refiner::AbstractRefiner = generator.refiner, + refiner_kwargs::NamedTuple = NamedTuple(), + postprocessor::AbstractPostprocessor = generator.postprocessor, + postprocessor_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) -Generates a response for a given question using a Retrieval-Augmented Generation (RAG) approach. +Generate the response using the provided `generator` and the `index` and `result`. +It is the second step in the RAG pipeline (after `retrieve`) + +Returns the mutated `result` with the `result.final_answer` and the full conversation saved in `result.conversations[:final_answer]`. + +# Notes +- The default flow is `build_context!` -> `answer!` -> `refine!` -> `postprocess!`. +- `contexter` is the method to use for building the context, eg, simply enumerate the context chunks with `ContextEnumerator`. +- `answerer` is the standard answer generation step with LLMs. +- `refiner` step allows the LLM to critique itself and refine its own answer. +- `postprocessor` step allows for additional processing of the answer, eg, logging, saving conversations, etc. +- All of its sub-routines operate by mutating the `result` object (and adding their part). +- Discover available sub-types for each step with `subtypes(AbstractRefiner)` and similar for other abstract types. + +# Arguments +- `generator::AbstractGenerator`: The `generator` to use for generating the answer. Can be `SimpleGenerator` or `AdvancedGenerator`. +- `index::AbstractChunkIndex`: The index containing chunks and sources. +- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for. +- `verbose::Integer`: If >0, enables verbose logging. +- `api_kwargs::NamedTuple`: API parameters that will be forwarded to ALL of the API calls (`aiembed`, `aigenerate`, and `aiextract`). +- `contexter::AbstractContextBuilder`: The method to use for building the context. Defaults to `generator.contexter`, eg, `ContextEnumerator`. +- `contexter_kwargs::NamedTuple`: API parameters that will be forwarded to the `contexter` call. +- `answerer::AbstractAnswerer`: The method to use for generating the answer. Defaults to `generator.answerer`, eg, `SimpleAnswerer`. +- `answerer_kwargs::NamedTuple`: API parameters that will be forwarded to the `answerer` call. Examples: + - `model`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`. + - `template`: The template to use for the `aigenerate` function. Defaults to `:RAGAnswerFromContext`. +- `refiner::AbstractRefiner`: The method to use for refining the answer. Defaults to `generator.refiner`, eg, `NoRefiner`. +- `refiner_kwargs::NamedTuple`: API parameters that will be forwarded to the `refiner` call. + - `model`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`. + - `template`: The template to use for the `aigenerate` function. Defaults to `:RAGAnswerRefiner`. +- `postprocessor::AbstractPostprocessor`: The method to use for postprocessing the answer. Defaults to `generator.postprocessor`, eg, `NoPostprocessor`. +- `postprocessor_kwargs::NamedTuple`: API parameters that will be forwarded to the `postprocessor` call. +- `cost_tracker`: An atomic counter to track the total cost of the operations. + +See also: `retrieve`, `build_context!`, `ContextEnumerator`, `answer!`, `SimpleAnswerer`, `refine!`, `NoRefiner`, `SimpleRefiner`, `postprocess!`, `NoPostprocessor` + +# Examples +```julia +Assume we already have `index` + +question = "What are the best practices for parallel computing in Julia?" + +# Retrieve the relevant chunks - returns RAGResult +result = retrieve(index, question) + +# Generate the answer using the default generator, mutates the same result +result = generate!(index, result) + +``` +""" +function generate!( + generator::AbstractGenerator, index::AbstractChunkIndex, result::AbstractRAGResult; + verbose::Integer = 1, + api_kwargs::NamedTuple = NamedTuple(), + contexter::AbstractContextBuilder = generator.contexter, + contexter_kwargs::NamedTuple = NamedTuple(), + answerer::AbstractAnswerer = generator.answerer, + answerer_kwargs::NamedTuple = NamedTuple(), + refiner::AbstractRefiner = generator.refiner, + refiner_kwargs::NamedTuple = NamedTuple(), + postprocessor::AbstractPostprocessor = generator.postprocessor, + postprocessor_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + + ## Build the context + contexter_kwargs_ = isempty(api_kwargs) ? contexter_kwargs : + merge(contexter_kwargs, (; api_kwargs)) + result = build_context!(contexter, + index, result; verbose = (verbose > 1), cost_tracker, contexter_kwargs_...) + + ## LLM call to answer + answerer_kwargs_ = isempty(api_kwargs) ? answerer_kwargs : + merge(answerer_kwargs, (; api_kwargs)) + result = answer!( + answerer, index, result; verbose = (verbose > 1), cost_tracker, answerer_kwargs_...) -The function selects relevant chunks from an `ChunkIndex`, optionally filters them based on metadata tags, reranks them, and then uses these chunks to construct a context for generating a response. + ## Refine the answer + refiner_kwargs_ = isempty(api_kwargs) ? refiner_kwargs : + merge(refiner_kwargs, (; api_kwargs)) + result = refine!( + refiner, index, result; verbose = (verbose > 1), cost_tracker, refiner_kwargs_...) + + ## Postprocessing + postprocessor_kwargs_ = isempty(api_kwargs) ? postprocessor_kwargs : + merge(postprocessor_kwargs, (; api_kwargs)) + result = postprocess!(postprocessor, index, result; verbose = (verbose > 1), + cost_tracker, postprocessor_kwargs_...) + + return result # mutated result +end + +# Set default behavior +DEFAULT_GENERATOR = SimpleGenerator() +function generate!(index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...) + return generate!(DEFAULT_GENERATOR, index, result; kwargs...) +end + +### Overarching + +""" + RAGConfig <: AbstractRAGConfig + +Default configuration for RAG. It uses `SimpleIndexer`, `SimpleRetriever`, and `SimpleGenerator` as default components. + +To customize the components, replace corresponding fields for each step of the RAG pipeline (eg, use `subtypes(AbstractIndexBuilder)` to find the available options). +""" +@kwdef mutable struct RAGConfig <: AbstractRAGConfig + indexer::AbstractIndexBuilder = SimpleIndexer() + retriever::AbstractRetriever = SimpleRetriever() + generator::AbstractGenerator = SimpleGenerator() +end + +""" + airag(cfg::AbstractRAGConfig, index::AbstractChunkIndex; + question::AbstractString, + verbose::Integer = 1, return_all::Bool = false, + api_kwargs::NamedTuple = NamedTuple(), + retriever::AbstractRetriever = cfg.retriever, + retriever_kwargs::NamedTuple = NamedTuple(), + generator::AbstractGenerator = cfg.generator, + generator_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0)) + +High-level wrapper for Retrieval-Augmented Generation (RAG), it combines together the `retrieve` and `generate!` steps which you can customize if needed. + +The simplest version first finds the relevant chunks in `index` for the `question` and then sends these chunks to the AI model to help with generating a response to the `question`. + +To customize the components, replace the types (`retriever`, `generator`) of the corresponding step of the RAG pipeline - or go into sub-routines within the steps. +Eg, use `subtypes(AbstractRetriever)` to find the available options. # Arguments +- `cfg::AbstractRAGConfig`: The configuration for the RAG pipeline. Defaults to `RAGConfig()`, where you can swap sub-types to customize the pipeline. - `index::AbstractChunkIndex`: The chunk index to search for relevant text. -- `rag_template::Symbol`: Template for the RAG model, defaults to `:RAGAnswerFromContext`. - `question::AbstractString`: The question to be answered. -- `top_k::Int`: Number of top candidates to retrieve based on embedding similarity. -- `top_n::Int`: Number of candidates to return after reranking. -- `minimum_similarity::AbstractFloat`: Minimum similarity threshold (between -1 and 1) for filtering chunks based on embedding similarity. Defaults to -1.0. -- `tag_filter::Union{Symbol, Vector{String}, Regex}`: Mechanism for filtering chunks based on tags (either automatically detected, specific tags, or a regex pattern). Disabled by setting to `nothing`. -- `rerank_strategy::RerankingStrategy`: Strategy for reranking the retrieved chunks. Defaults to `Passthrough()`. Use `CohereRerank` for better results (requires `COHERE_API_KEY` to be set) -- `model_embedding::String`: Model used for embedding the question, default is `PT.MODEL_EMBEDDING`. -- `model_chat::String`: Model used for generating the final response, default is `PT.MODEL_CHAT`. -- `model_metadata::String`: Model used for extracting metadata, default is `PT.MODEL_CHAT`. -- `metadata_template::Symbol`: Template for the metadata extraction process from the question, defaults to: `:RAGExtractMetadataShort` -- `chunks_window_margin::Tuple{Int,Int}`: The window size around each chunk to consider for context building. See `?build_context` for more information. -- `return_details::Bool`: If `true`, returns the details used for RAG along with the response. -- `verbose::Bool`: If `true`, enables verbose logging. +- `return_all::Bool`: If `true`, returns the details used for RAG along with the response. +- `verbose::Integer`: If `>0`, enables verbose logging. The higher the number, the more nested functions will log. - `api_kwargs`: API parameters that will be forwarded to ALL of the API calls (`aiembed`, `aigenerate`, and `aiextract`). -- `aiembed_kwargs`: API parameters that will be forwarded to the `aiembed` call. If you need to provide `api_kwargs` only to this function, simply add them as a keyword argument, eg, `aiembed_kwargs = (; api_kwargs = (; x=1))`. -- `aigenerate_kwargs`: API parameters that will be forwarded to the `aigenerate` call. If you need to provide `api_kwargs` only to this function, simply add them as a keyword argument, eg, `aigenerate_kwargs = (; api_kwargs = (; temperature=0.3))`. -- `aiextract_kwargs`: API parameters that will be forwarded to the `aiextract` call for the metadata extraction. +- `retriever::AbstractRetriever`: The retriever to use for finding relevant chunks. Defaults to `cfg.retriever`, eg, `SimpleRetriever` (with no question rephrasing). +- `retriever_kwargs::NamedTuple`: API parameters that will be forwarded to the `retriever` call. Examples of important ones: + - `top_k::Int`: Number of top candidates to retrieve based on embedding similarity. + - `top_n::Int`: Number of candidates to return after reranking. + - `tagger::AbstractTagger`: Tagger to use for tagging the chunks. Defaults to `NoTagger()`. + - `tagger_kwargs::NamedTuple`: API parameters that will be forwarded to the `tagger` call. You could provide the explicit tags directly with `PassthroughTagger` and `tagger_kwargs = (; tags = ["tag1", "tag2"])`. +- `generator::AbstractGenerator`: The generator to use for generating the answer. Defaults to `cfg.generator`, eg, `SimpleGenerator`. +- `generator_kwargs::NamedTuple`: API parameters that will be forwarded to the `generator` call. Examples of important ones: + - `answerer_kwargs::NamedTuple`: API parameters that will be forwarded to the `answerer` call. Examples: + - `model`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`. + - `template`: The template to use for the `aigenerate` function. Defaults to `:RAGAnswerFromContext`. + - `refiner::AbstractRefiner`: The method to use for refining the answer. Defaults to `generator.refiner`, eg, `NoRefiner`. + - `refiner_kwargs::NamedTuple`: API parameters that will be forwarded to the `refiner` call. + - `model`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`. + - `template`: The template to use for the `aigenerate` function. Defaults to `:RAGAnswerRefiner`. +- `cost_tracker`: An atomic counter to track the total cost of the operations (if you want to track the cost of multiple pipeline runs - it passed around in the pipeline). # Returns -- If `return_details` is `false`, returns the generated message (`msg`). -- If `return_details` is `true`, returns a tuple of the generated message (`msg`) and the `RAGDetails` for context (`rag_details`). +- If `return_all` is `false`, returns the generated message (`msg`). +- If `return_all` is `true`, returns the detail of the full pipeline in `RAGResult` (see the docs). -# Notes -- The function first finds the closest chunks to the question embedding, then optionally filters these based on tags. After that, it reranks the candidates and builds a context for the RAG model. -- The `tag_filter` can be used to refine the search. If set to `:auto`, it attempts to automatically determine relevant tags (if `index` has them available). -- The `chunks_window_margin` allows including surrounding chunks for richer context, considering they are from the same source. -- The function currently supports only single `ChunkIndex`. +See also `build_index`, `retrieve`, `generate!`, `RAGResult` # Examples @@ -95,15 +454,12 @@ Using `airag` to get a response for a question: ```julia index = build_index(...) # create an index question = "How to make a barplot in Makie.jl?" -msg = airag(index, :RAGAnswerFromContext; question) - -# or simply msg = airag(index; question) ``` -To understand the details of the RAG process, use `return_details=true` +To understand the details of the RAG process, use `return_all=true` ```julia -msg, details = airag(index; question, return_details = true) +msg, details = airag(index; question, return_all = true) # details is a RAGDetails object with all the internal steps of the `airag` function ``` @@ -113,103 +469,80 @@ It also includes annotations of which context was used for each part of the resp PT.pprint(details) ``` -See also `build_index`, `build_context`, `CandidateChunks`, `find_closest`, `find_tags`, `rerank`, `annotate_support` +Example with advanced retrieval (with question rephrasing and reranking (requires `COHERE_API_KEY`). +We will obtain top 100 chunks from embeddings (`top_k`) and top 5 chunks from reranking (`top_n`). +In addition, it will be done with a "custom" locally-hosted model. + +```julia +cfg = RAGConfig(; retriever = AdvancedRetriever()) + +# kwargs will be big and nested, let's prepare them upfront +# we specify "custom" model for each component that calls LLM +kwargs = ( + retriever_kwargs = (; + top_k = 100, + top_n = 5, + rephraser_kwargs = (; + model = "custom"), + embedder_kwargs = (; + model = "custom"), + tagger_kwargs = (; + model = "custom")), + generator_kwargs = (; + answerer_kwargs = (; + model = "custom"), + refiner_kwargs = (; + model = "custom")), + api_kwargs = (; + url = "http://localhost:8080")) + +result = airag(cfg, index, question; kwargs...) +``` """ -function airag(index::AbstractChunkIndex, rag_template::Symbol = :RAGAnswerFromContext; +function airag(cfg::AbstractRAGConfig, index::AbstractChunkIndex; question::AbstractString, - top_k::Int = 100, top_n::Int = 5, minimum_similarity::AbstractFloat = -1.0, - tag_filter::Union{Symbol, Vector{String}, Regex, Nothing} = :auto, - rerank_strategy::RerankingStrategy = Passthrough(), - model_embedding::String = PT.MODEL_EMBEDDING, model_chat::String = PT.MODEL_CHAT, - model_metadata::String = PT.MODEL_CHAT, - metadata_template::Symbol = :RAGExtractMetadataShort, - chunks_window_margin::Tuple{Int, Int} = (1, 1), - return_details::Bool = false, verbose::Bool = true, - rerank_kwargs::NamedTuple = NamedTuple(), + verbose::Integer = 1, return_all::Bool = false, api_kwargs::NamedTuple = NamedTuple(), - aiembed_kwargs::NamedTuple = NamedTuple(), - aigenerate_kwargs::NamedTuple = NamedTuple(), - aiextract_kwargs::NamedTuple = NamedTuple(), - kwargs...) - ## Note: Supports only single ChunkIndex for now - ## Checks - @assert !(tag_filter isa Symbol && tag_filter != :auto) "Only `:auto`, `Vector{String}`, or `Regex` are supported for `tag_filter`" - @assert chunks_window_margin[1] >= 0&&chunks_window_margin[2] >= 0 "Both `chunks_window_margin` values must be non-negative" - placeholders = only(aitemplates(rag_template)).variables # only one template should be found - @assert (:question in placeholders)&&(:context in placeholders) "Provided RAG Template $(rag_template) is not suitable. It must have placeholders: `question` and `context`." - - ## Embedding - joined_kwargs = isempty(api_kwargs) ? aiembed_kwargs : - merge(aiembed_kwargs, (; api_kwargs)) - question_emb = aiembed(question, - _normalize; - model = model_embedding, - verbose, joined_kwargs...).content .|> Float32 # no need for Float64 - emb_candidates = find_closest(index, question_emb; top_k, minimum_similarity) - - tag_candidates = if tag_filter == :auto && !isnothing(tags(index)) && - !isempty(model_metadata) - _check_aiextract_capability(model_metadata) - joined_kwargs = isempty(api_kwargs) ? aiextract_kwargs : - merge(aiextract_kwargs, (; api_kwargs)) - # extract metadata via LLM call - metadata_ = try - msg = aiextract(metadata_template; return_type = MaybeMetadataItems, - text = question, - instructions = "In addition to extracted items, suggest 2-3 filter keywords that could be relevant to answer this question.", - verbose, model = model_metadata, joined_kwargs...) - ## eg, ["software:::pandas", "language:::python", "julia_package:::dataframes"] - ## we split it and take only the keyword, not the category - metadata_extract(msg.content.items) |> - x -> split.(x, ":::") |> x -> getindex.(x, 2) - catch e - String[] - end - find_tags(index, metadata_) - elseif tag_filter isa Union{Vector{String}, Regex} - find_tags(index, tag_filter) - elseif isnothing(tag_filter) - nothing - else - ## not filtering -- use all rows and ignore this - nothing - end + retriever::AbstractRetriever = cfg.retriever, + retriever_kwargs::NamedTuple = NamedTuple(), + generator::AbstractGenerator = cfg.generator, + generator_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0)) - filtered_candidates = isnothing(tag_candidates) ? emb_candidates : - (emb_candidates & tag_candidates) - reranked_candidates = rerank(rerank_strategy, - index, - question, - filtered_candidates; - top_n, - verbose = false, rerank_kwargs...) + ## Retrieve top context + retriever_kwargs_ = isempty(api_kwargs) ? retriever_kwargs : + merge(retriever_kwargs, (; api_kwargs)) + result = retrieve( + retriever, index, question; verbose = verbose - 1, cost_tracker, retriever_kwargs_...) - ## Build the context - context = build_context(index, reranked_candidates; chunks_window_margin) - - ## LLM call - joined_kwargs = isempty(api_kwargs) ? aigenerate_kwargs : - merge(aigenerate_kwargs, (; api_kwargs)) - msg = aigenerate(rag_template; question, - context = join(context, "\n\n"), model = model_chat, verbose, - joined_kwargs...) - - if return_details # for evaluation - rag_details = RAGDetails(; - question, - rephrased_question = [question], - answer = msg.content, - refined_answer = msg.content, - context, - sources = sources(index)[reranked_candidates.positions], - emb_candidates, - tag_candidates, - filtered_candidates, - reranked_candidates) - return msg, rag_details + ## Generate the response + generator_kwargs_ = isempty(api_kwargs) ? generator_kwargs : + merge(generator_kwargs, (; api_kwargs)) + result = generate!(generator, index, result; verbose = verbose - 1, cost_tracker, + generator_kwargs_...) + + verbose > 0 && + @info "Done with RAG. Total cost: \$$(round(cost_tracker[], digits=3))" + + ## Return `RAGResult` or more user-friendly `AIMessage` + output = if return_all + result + elseif haskey(result.conversations, :final_answer) && + !isempty(result.conversations[:final_answer]) + result.conversations[:final_answer][end] + elseif haskey(result.conversations, :answer) && + !isempty(result.conversations[:answer]) + result.conversations[:answer][end] else - return msg + throw(ArgumentError("No conversation found in the result")) end + return output +end + +# Default behavior +const DEFAULT_RAG_CONFIG = RAGConfig() +function airag(index::AbstractChunkIndex; question::AbstractString, kwargs...) + return airag(DEFAULT_RAG_CONFIG, index; question, kwargs...) end # Special method to pretty-print the airag results diff --git a/src/Experimental/RAGTools/preparation.jl b/src/Experimental/RAGTools/preparation.jl index 87dc1fa4..c5226ce2 100644 --- a/src/Experimental/RAGTools/preparation.jl +++ b/src/Experimental/RAGTools/preparation.jl @@ -1,114 +1,179 @@ -### Preparation +## Preparation Stage + +### Chunking Types + +""" + FileChunker <: AbstractChunker + +Chunker when you provide file paths to `get_chunks` functions. + +Ie, the inputs will be validated first (eg, file exists, etc) and then read into memory. + +Set as default chunker in `get_chunks` functions. +""" +struct FileChunker <: AbstractChunker end + +""" + TextChunker <: AbstractChunker + +Chunker when you provide text to `get_chunks` functions. Inputs are directly chunked +""" +struct TextChunker <: AbstractChunker end + +### Embedding Types +""" + BatchEmbedder <: AbstractEmbedder + +Default embedder for `get_embeddings` functions. It passes individual documents to be embedded in chunks to `aiembed`. +""" +struct BatchEmbedder <: AbstractEmbedder end + +### Tagging Types +""" + NoTagger <: AbstractTagger + +No-op tagger for `get_tags` functions. It returns (`nothing`, `nothing`). +""" +struct NoTagger <: AbstractTagger end + +""" + PassthroughTagger <: AbstractTagger + +Tagger for `get_tags` functions, which passes `tags` directly as Vector of Vectors of strings (ie, `tags[i]` is the tags for `docs[i]`). +""" +struct PassthroughTagger <: AbstractTagger end + +""" + OpenTagger <: AbstractTagger + +Tagger for `get_tags` functions, which generates possible tags for each chunk via `aiextract`. +You can customize it via prompt template (default: `:RAGExtractMetadataShort`), but it's quite open-ended (ie, AI decides the possible tags). +""" +struct OpenTagger <: AbstractTagger end + # Types used to extract `tags` from document chunks -@kwdef struct MetadataItem +@kwdef struct Tag value::String category::String end -@kwdef struct MaybeMetadataItems - items::Union{Nothing, Vector{MetadataItem}} +@kwdef struct MaybeTags + items::Union{Nothing, Vector{Tag}} end +### Overall types for build_index """ - metadata_extract(item::MetadataItem) - metadata_extract(items::Vector{MetadataItem}) + SimpleIndexer <: AbstractIndexBuilder -Extracts the metadata item into a string of the form `category:::value` (lowercased and spaces replaced with underscores). +Default implementation for `build_index`. -# Example -```julia -msg = aiextract(:RAGExtractMetadataShort; return_type=MaybeMetadataItems, text="I like package DataFrames", instructions="None.") -metadata = metadata_extract(msg.content.items) -``` +It uses `TextChunker`, `BatchEmbedder`, and `NoTagger` as default chunker, embedder, and tagger. """ -function metadata_extract(item::MetadataItem) - "$(strip(item.category)):::$(strip(item.value))" |> lowercase |> - x -> replace(x, " " => "_") +@kwdef mutable struct SimpleIndexer <: AbstractIndexBuilder + chunker::AbstractChunker = TextChunker() + embedder::AbstractEmbedder = BatchEmbedder() + tagger::AbstractTagger = NoTagger() end -metadata_extract(items::Nothing) = String[] -metadata_extract(items::Vector{MetadataItem}) = metadata_extract.(items) -"Builds a matrix of tags and a vocabulary list. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!" -function build_tags end -# Implementation in ext/RAGToolsExperimentalExt.jl +### Functions -"Build an index for RAG (Retriever-Augmented Generation) applications. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!" -function build_index end +## "Build an index for RAG (Retriever-Augmented Generation) applications. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!" +## function build_index end "Shortcut to LinearAlgebra.normalize. Provided in the package extension `RAGToolsExperimentalExt` (Requires SparseArrays and LinearAlgebra)" function _normalize end """ - get_chunks(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, - sources::Vector{<:AbstractString} = files_or_docs, + load_text(chunker::AbstractChunker, input; + kwargs...) + +Load text from `input` using the provided `chunker` + +Available chunkers: +- `FileChunker`: The function opens each file in `input` and reads its contents. +- `TextChunker`: The function assumes that `input` is a vector of strings to be chunked, you MUST provide corresponding `sources`. +""" +function load_text(chunker::AbstractChunker, input; + kwargs...) + throw(ArgumentError("Not implemented for chunker $(typeof(chunker))")) +end +function load_text(chunker::FileChunker, input::AbstractString; + source::AbstractString = input, kwargs...) + @assert isfile(input) "Path $input does not exist" + return read(input, String), source +end +function load_text(chunker::TextChunker, input::AbstractString; + source::AbstractString = input, kwargs...) + @assert length(source)<=512 "Each `source` should be less than 512 characters long. Detected: $(length(source)) characters. You must provide sources for each text when using `TextChunker`" + return input, source +end + +""" + get_chunks(chunker::AbstractChunker, + files_or_docs::Vector{<:AbstractString}; + sources::AbstractVector{<:AbstractString} = files_or_docs, verbose::Bool = true, - separators = ["\\n\\n", ". ", "\\n"], max_length::Int = 256) + separators = ["\\n\\n", ". ", "\\n", " "], max_length::Int = 256) Chunks the provided `files_or_docs` into chunks of maximum length `max_length` (if possible with provided `separators`). Supports two modes of operation: -- `reader=:files`: The function opens each file in `files_or_docs` and reads its content. -- `reader=:docs`: The function assumes that `files_or_docs` is a vector of strings to be chunked. +- `chunker = FileChunker()`: The function opens each file in `files_or_docs` and reads its contents. +- `chunker = TextChunker()`: The function assumes that `files_or_docs` is a vector of strings to be chunked, you MUST provide corresponding `sources`. # Arguments - `files_or_docs`: A vector of valid file paths OR string documents to be chunked. -- `reader`: A symbol indicating the type of input, can be either `:files` or `:docs`. Default is `:files`. -- `separators`: A list of strings used as separators for splitting the text in each file into chunks. Default is `[\\n\\n", ". ", "\\n"]`. +- `separators`: A list of strings used as separators for splitting the text in each file into chunks. Default is `[\\n\\n", ". ", "\\n", " "]`. + See `recursive_splitter` for more details. - `max_length`: The maximum length of each chunk (if possible with provided separators). Default is 256. - `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs` (for `reader=:files`) """ -function get_chunks(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, - sources::Vector{<:AbstractString} = files_or_docs, +function get_chunks(chunker::AbstractChunker, + files_or_docs::Vector{<:AbstractString}; + sources::AbstractVector{<:AbstractString} = files_or_docs, verbose::Bool = true, - separators = ["\n\n", ". ", "\n"], max_length::Int = 256) + separators = ["\n\n", ". ", "\n", " "], max_length::Int = 256) ## Check that all items must be existing files or strings - @assert reader in [:files, :docs] "Invalid `read` argument. Must be one of [:files, :docs]" - if reader == :files - @assert all(isfile, files_or_docs) "Some paths in `files_or_docs` don't exist (Check: $(join(filter(!isfile,files_or_docs),", "))" - else - @assert sources!=files_or_docs "When `reader=:docs`, vector of `sources` must be provided" - end - @assert isnothing(sources)||(length(sources) == length(files_or_docs)) "Length of `sources` must match length of `files_or_docs`" - @assert maximum(length.(sources))<=512 "Each source must be less than 512 characters long (Detected: $(maximum(length.(sources))))" + @assert (length(sources)==length(files_or_docs)) "Length of `sources` must match length of `files_or_docs`" output_chunks = Vector{SubString{String}}() output_sources = Vector{eltype(sources)}() # Do chunking first for i in eachindex(files_or_docs, sources) - # if reader == :files, we open the files and read them - doc_raw = if reader == :files - fn = files_or_docs[i] - (verbose > 0) && @info "Processing file: $fn" - read(fn, String) - else - files_or_docs[i] - end + doc_raw, source = load_text(chunker, files_or_docs[i]; source = sources[i]) isempty(doc_raw) && continue - # split into chunks, if you want to start simple - just do `split(text,"\n\n")` + # split into chunks by recursively trying the separators provided + # if you want to start simple - just do `split(text,"\n\n")` doc_chunks = PT.recursive_splitter(doc_raw, separators; max_length) .|> strip |> x -> filter(!isempty, x) # skip if no chunks found isempty(doc_chunks) && continue append!(output_chunks, doc_chunks) - append!(output_sources, fill(sources[i], length(doc_chunks))) + append!(output_sources, fill(source, length(doc_chunks))) end return output_chunks, output_sources end +function get_embeddings( + embedder::AbstractEmbedder, docs::AbstractVector{<:AbstractString}; kwargs...) + throw(ArgumentError("Not implemented for embedder $(typeof(embedder))")) +end + """ - get_embeddings(docs::Vector{<:AbstractString}; + get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:AbstractString}; verbose::Bool = true, cost_tracker = Threads.Atomic{Float64}(0.0), target_batch_size_length::Int = 80_000, ntasks::Int = 4 * Threads.nthreads(), kwargs...) + -Embeds a vector of `docs` using the provided model (kwarg `model`). +Embeds a vector of `docs` using the provided model (kwarg `model`) in a batched manner - `BatchEmbedder`. -Tries to batch embedding calls for roughly 80K characters per call (to avoid exceeding the API limit) but reduce network latency. +`BatchEmbedder` tries to batch embedding calls for roughly 80K characters per call (to avoid exceeding the API rate limit) to reduce network latency. # Notes - `docs` are assumed to be already chunked to the reasonable sizes that fit within the embedding context limit. @@ -125,8 +190,9 @@ Tries to batch embedding calls for roughly 80K characters per call (to avoid exc - `ntasks`: The number of tasks to use for asyncmap. Default is 4 * Threads.nthreads(). """ -function get_embeddings(docs::Vector{<:AbstractString}; +function get_embeddings(embedder::BatchEmbedder, docs::AbstractVector{<:AbstractString}; verbose::Bool = true, + model::AbstractString = PT.MODEL_EMBEDDING, cost_tracker = Threads.Atomic{Float64}(0.0), target_batch_size_length::Int = 80_000, ntasks::Int = 4 * Threads.nthreads(), @@ -134,10 +200,9 @@ function get_embeddings(docs::Vector{<:AbstractString}; ## check if extension is available ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt) if isnothing(ext) - error("you need to also import LinearAlgebra and SparseArrays to use this function") + error("You need to also import LinearAlgebra and SparseArrays to use this function") end verbose && @info "Embedding $(length(docs)) documents..." - model = hasproperty(kwargs, :model) ? kwargs.model : PT.MODEL_EMBEDDING # Notice that we embed multiple docs at once, not one by one # OpenAI supports embedding multiple documents to reduce the number of API calls/network latency time # We do batch them just in case the documents are too large (targeting at most 80K characters per call) @@ -148,9 +213,10 @@ function get_embeddings(docs::Vector{<:AbstractString}; msg = aiembed(docs_chunk, # LinearAlgebra.normalize but imported in RAGToolsExperimentalExt _normalize; + model, verbose = false, kwargs...) - Threads.atomic_add!(cost_tracker, PT.call_cost(msg, model)) # track costs + Threads.atomic_add!(cost_tracker, msg.cost) # track costs msg.content end embeddings = hcat(embeddings...) .|> Float32 # flatten, columns are documents @@ -158,100 +224,176 @@ function get_embeddings(docs::Vector{<:AbstractString}; return embeddings end +### Tag Extraction + +function get_tags(tagger::AbstractTagger, docs::AbstractVector{<:AbstractString}; + kwargs...) + throw(ArgumentError("Not implemented for tagger $(typeof(tagger))")) +end + +""" + tags_extract(item::Tag) + tags_extract(tags::Vector{Tag}) + +Extracts the `Tag` item into a string of the form `category:::value` (lowercased and spaces replaced with underscores). + +# Example +```julia +msg = aiextract(:RAGExtractMetadataShort; return_type=MaybeTags, text="I like package DataFrames", instructions="None.") +metadata = tags_extract(msg.content.items) +``` +""" +function tags_extract(item::Tag) + "$(strip(item.category)):::$(strip(item.value))" |> lowercase |> + x -> replace(x, " " => "_") +end +tags_extract(items::Nothing) = String[] +tags_extract(items::Vector{Tag}) = tags_extract.(items) + +""" + get_tags(tagger::NoTagger, docs::AbstractVector{<:AbstractString}; + kwargs...) + +Simple no-op that skips any tagging of the documents +""" +function get_tags(tagger::NoTagger, docs::AbstractVector{<:AbstractString}; + kwargs...) + nothing +end + +""" + get_tags(tagger::PassthroughTagger, docs::AbstractVector{<:AbstractString}; + tags::AbstractVector{<:AbstractVector{<:AbstractString}}, + kwargs...) + +Pass `tags` directly as Vector of Vectors of strings (ie, `tags[i]` is the tags for `docs[i]`). +It then builds the vocabulary from the tags and returns both the tags in matrix form and the vocabulary. """ - get_metadata(docs::Vector{<:AbstractString}; +function get_tags(tagger::PassthroughTagger, docs::AbstractVector{<:AbstractString}; + tags::AbstractVector{<:AbstractVector{<:AbstractString}}, + kwargs...) + @assert length(docs)==length(tags) "Length of `docs` must match length of `tags`" + return tags +end + +""" + get_tags(tagger::OpenTagger, docs::AbstractVector{<:AbstractString}; verbose::Bool = true, cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) -Extracts metadata from a vector of `docs` using the provided model (kwarg `model`). +Extracts "tags" (metadata/keywords) from a vector of `docs` using the provided model (kwarg `model`). # Arguments - `docs`: A vector of strings to be embedded. - `verbose`: A boolean flag for verbose output. Default is `true`. -- `model`: The model to use for metadata extraction. Default is `PT.MODEL_CHAT`. -- `metadata_template`: A template to be used for metadata extraction. Default is `:RAGExtractMetadataShort`. +- `model`: The model to use for tags extraction. Default is `PT.MODEL_CHAT`. +- `template`: A template to be used for tags extraction. Default is `:RAGExtractMetadataShort`. - `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call. - """ -function get_metadata(docs::Vector{<:AbstractString}; +function get_tags(tagger::OpenTagger, docs::AbstractVector{<:AbstractString}; verbose::Bool = true, - metadata_template::Symbol = :RAGExtractMetadataShort, + model::AbstractString = PT.MODEL_CHAT, + template::Symbol = :RAGExtractMetadataShort, cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) - model = hasproperty(kwargs, :model) ? kwargs.model : PT.MODEL_CHAT _check_aiextract_capability(model) + ## check if extension is available + ext = Base.get_extension(PromptingTools, :RAGToolsExperimentalExt) + if isnothing(ext) + error("You need to also import LinearAlgebra and SparseArrays to use this function") + end verbose && @info "Extracting metadata from $(length(docs)) documents..." - metadata = asyncmap(docs) do docs_chunk + tags_extracted = asyncmap(docs) do docs_chunk try - msg = aiextract(metadata_template; - return_type = MaybeMetadataItems, + msg = aiextract(template; + return_type = MaybeTags, text = docs_chunk, instructions = "None.", verbose = false, model, kwargs...) - Threads.atomic_add!(cost_tracker, PT.call_cost(msg, model)) # track costs - items = metadata_extract(msg.content.items) + Threads.atomic_add!(cost_tracker, msg.cost) # track costs + items = tags_extract(msg.content.items) catch String[] end end + verbose && - @info "Done extracting the metadata. Total cost: \$$(round(cost_tracker[],digits=3))" - return metadata + @info "Done extracting the tags. Total cost: \$$(round(cost_tracker[],digits=3))" + + return tags_extracted end """ - build_index(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, - separators = ["\\n\\n", ". ", "\\n"], max_length::Int = 256, - sources::Vector{<:AbstractString} = files_or_docs, + build_tags(tagger::AbstractTagger, chunk_tags::Nothing; kwargs...) + +No-op that skips any tag building, returning `nothing, nothing` + +Otherwise, it would build the sparse matrix and the vocabulary (requires `SparseArrays` and `LinearAlgebra` packages to be loaded). +""" +function build_tags(tagger::AbstractTagger, chunk_tags::Nothing; kwargs...) + nothing, nothing +end + +""" + build_index( + indexer::AbstractIndexBuilder, files_or_docs::Vector{<:AbstractString}; + verbose::Integer = 1, extras::Union{Nothing, AbstractVector} = nothing, - extract_metadata::Bool = false, verbose::Integer = 1, index_id = gensym("ChunkIndex"), - metadata_template::Symbol = :RAGExtractMetadataShort, - model_embedding::String = PT.MODEL_EMBEDDING, - model_metadata::String = PT.MODEL_CHAT, - embedding_kwargs::NamedTuple = NamedTuple(), - metadata_kwargs::NamedTuple = NamedTuple(), + chunker::AbstractChunker = indexer.chunker, + chunker_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = indexer.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = indexer.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cost_tracker = Threads.Atomic{Float64}(0.0)) -Build an index for RAG (Retriever-Augmented Generation) applications from the provided file paths. -The function processes each file, splits its content into chunks, embeds these chunks, -optionally extracts metadata, and then compiles this information into a retrievable index. +Build an INDEX for RAG (Retriever-Augmented Generation) applications from the provided file paths. +INDEX is a object storing the document chunks and their embeddings (and potentially other information). + +The function processes each file or document (depending on `chunker`), splits its content into chunks, embeds these chunks, +optionally extracts metadata, and then combines this information into a retrievable index. + +Define your own methods via `indexer` and its subcomponents (`chunker`, `embedder`, `tagger`). # Arguments -- `files_or_docs`: A vector of valid file paths OR string documents to be indexed (chunked and embedded). -- `reader`: A symbol indicating the type of input, can be either `:files` or `:docs`. Default is `:files`. -- `separators`: A list of strings used as separators for splitting the text in each file into chunks. Default is `[\\n\\n, ". ", "\\n"]`. -- `max_length`: The maximum length of each chunk (if possible with provided separators). Default is 256. -- `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs` (for `reader=:files`) -- `extras`: An optional vector of extra information to be stored with each chunk. Default is `nothing`. -- `extract_metadata`: A boolean flag indicating whether to extract metadata from each chunk (to build filter `tags` in the index). Default is `false`. - Metadata extraction incurs additional cost and requires `model_metadata` and `metadata_template` to be provided. +- `indexer::AbstractIndexBuilder`: The indexing logic to use. Default is `SimpleIndexer()`. +- `files_or_docs`: A vector of valid file paths OR string documents to be indexed (chunked and embedded). Specify which mode to use via `chunker`. - `verbose`: An Integer specifying the verbosity of the logs. Default is `1` (high-level logging). `0` is disabled. -- `metadata_template`: A symbol indicating the template to be used for metadata extraction. Default is `:RAGExtractMetadataShort`. -- `model_embedding`: The model to use for embedding. -- `model_metadata`: The model to use for metadata extraction. -- `api_kwargs`: Parameters to be provided to the API endpoint. Shared across all API calls. -- `embedding_kwargs`: Parameters to be provided to the `get_embedding` function. Useful to change the batch sizes (`target_batch_size_length`) or reduce asyncmap tasks (`ntasks`). -- `metadata_kwargs`: Parameters to be provided to the `get_metadata` function. +- `extras`: An optional vector of extra information to be stored with each chunk. Default is `nothing`. +- `index_id`: A unique identifier for the index. Default is a generated symbol. +- `chunker`: The chunker logic to use for splitting the documents. Default is `TextChunker()`. +- `chunker_kwargs`: Parameters to be provided to the `get_chunks` function. Useful to change the `separators` or `max_length`. + - `sources`: A vector of strings indicating the source of each chunk. Default is equal to `files_or_docs`. +- `embedder`: The embedder logic to use for embedding the chunks. Default is `BatchEmbedder()`. +- `embedder_kwargs`: Parameters to be provided to the `get_embeddings` function. Useful to change the `target_batch_size_length` or reduce asyncmap tasks `ntasks`. + - `model`: The model to use for embedding. Default is `PT.MODEL_EMBEDDING`. +- `tagger`: The tagger logic to use for extracting tags from the chunks. Default is `NoTagger()`, ie, skip tag extraction. There are also `PassthroughTagger` and `OpenTagger`. +- `tagger_kwargs`: Parameters to be provided to the `get_tags` function. + - `model`: The model to use for tags extraction. Default is `PT.MODEL_CHAT`. + - `template`: A template to be used for tags extraction. Default is `:RAGExtractMetadataShort`. + - `tags`: A vector of vectors of strings directly providing the tags for each chunk. Applicable for `tagger::PasstroughTagger`. +- `api_kwargs`: Parameters to be provided to the API endpoint. Shared across all API calls if provided. +- `cost_tracker`: A `Threads.Atomic{Float64}` object to track the total cost of the API calls. Useful to pass the total cost to the parent call. # Returns - `ChunkIndex`: An object containing the compiled index of chunks, embeddings, tags, vocabulary, and sources. -See also: `MultiIndex`, `CandidateChunks`, `find_closest`, `find_tags`, `rerank`, `airag` +See also: `ChunkIndex`, `get_chunks`, `get_embeddings`, `get_tags`, `CandidateChunks`, `find_closest`, `find_tags`, `rerank`, `retrieve`, `generate!`, `airag` # Examples ```julia -# Assuming `test_files` is a vector of file paths -index = build_index(test_files; max_length=10, extract_metadata=true) +# Default is loading a vector of strings and chunking them (`TextChunker()`) +index = build_index(SimpleIndexer(), texts; chunker_kwargs = (; max_length=10)) -# Another example with metadata extraction and verbose output (`reader=:files` is implicit) -index = build_index(["file1.txt", "file2.txt"]; - separators=[". "], - extract_metadata=true, - verbose=true) +# Another example with tags extraction, splitting only sentences and verbose output +# Assuming `test_files` is a vector of file paths +indexer = SimpleIndexer(chunker=FileChunker(), tagger=OpenTagger()) +index = build_index(indexer, test_files; + chunker_kwargs(; separators=[". "]), verbose=true) ``` # Notes @@ -261,53 +403,47 @@ index = build_index(["file1.txt", "file2.txt"]; Some providers cannot handle large batch sizes (eg, Databricks). """ -function build_index(files_or_docs::Vector{<:AbstractString}; reader::Symbol = :files, - separators = ["\n\n", ". ", "\n"], max_length::Int = 256, - sources::Vector{<:AbstractString} = files_or_docs, +function build_index( + indexer::AbstractIndexBuilder, files_or_docs::Vector{<:AbstractString}; + verbose::Integer = 1, extras::Union{Nothing, AbstractVector} = nothing, - extract_metadata::Bool = false, verbose::Integer = 1, index_id = gensym("ChunkIndex"), - metadata_template::Symbol = :RAGExtractMetadataShort, - model_embedding::String = PT.MODEL_EMBEDDING, - model_metadata::String = PT.MODEL_CHAT, - embedding_kwargs::NamedTuple = NamedTuple(), - metadata_kwargs::NamedTuple = NamedTuple(), + chunker::AbstractChunker = indexer.chunker, + chunker_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = indexer.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = indexer.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cost_tracker = Threads.Atomic{Float64}(0.0)) ## Split into chunks - output_chunks, output_sources = get_chunks(files_or_docs; - reader, sources, separators, max_length) + chunks, sources = get_chunks(chunker, files_or_docs; + chunker_kwargs...) ## Embed chunks - embeddings = get_embeddings(output_chunks; + embeddings = get_embeddings(embedder, chunks; verbose = (verbose > 1), cost_tracker, - model = model_embedding, - api_kwargs, embedding_kwargs...) - - ## Extract metadata - tags, tags_vocab = if extract_metadata - output_metadata = get_metadata(output_chunks; - verbose = (verbose > 1), - cost_tracker, - model = model_metadata, - metadata_template, - api_kwargs, metadata_kwargs...) - # Requires SparseArrays.jl to be loaded - build_tags(output_metadata) - else - nothing, nothing - end - ## Create metadata tag array and associated vocabulary + api_kwargs, embedder_kwargs...) + + ## Extract tags + tags_extracted = get_tags(tagger, chunks; + verbose = (verbose > 1), + cost_tracker, + api_kwargs, tagger_kwargs...) + # Build the sparse matrix and the vocabulary + tags, tags_vocab = build_tags(tagger, tags_extracted) + (verbose > 0) && @info "Index built! (cost: \$$(round(cost_tracker[], digits=3)))" - index = ChunkIndex(; - id = index_id, - embeddings, - tags, tags_vocab, - chunks = output_chunks, - sources = output_sources, - extras) + index = ChunkIndex(; id = index_id, embeddings, tags, tags_vocab, + chunks, sources, extras) return index end + +# Default dispatch +const DEFAULT_INDEXER = SimpleIndexer() +function build_index(files_or_docs::Vector{<:AbstractString}; kwargs...) + build_index(DEFAULT_INDEXER, files_or_docs; kwargs...) +end \ No newline at end of file diff --git a/src/Experimental/RAGTools/rag_interface.jl b/src/Experimental/RAGTools/rag_interface.jl index 62cecf0a..7b8b2312 100644 --- a/src/Experimental/RAGTools/rag_interface.jl +++ b/src/Experimental/RAGTools/rag_interface.jl @@ -1,87 +1,181 @@ ############################ -### THIS IS WORK IN PROGRESS +### RAG Interface Write-up ############################ # This is the outline of the current RAG interface. # -### System Overview -# -# This system is designed for information retrieval and response generation, structured in three main phases: -# - Preparation, when you create an instance of `AbstractIndex` -# - Retrieval, when you surface the top most relevant chunks/items in the `index` -# - Generation, when you generate an answer based on the retrieved chunks -# -# The system is designed to be hackable and extensible at almost every entry point. -# You just need to define the corresponding concrete type (struct XYZ <: AbstractXYZ end) and the corresponding method. -# Then you pass the instance of this new type via kwargs, eg, `annotater=TrigramAnnotater()` -# -### RAG Diagram -# -# build_index -# signature: () -> AbstractChunkIndex -# flow: get_chunks -> get_embeddings -> get_tags -# -# airag: -# signature: () -> AIMessage or (AIMessage, RAGDetails) -# flow: retrieve -> generate -# -# retrieve: -# signature: () -> RAGContext -# flow: rephrase -> aiembed -> find_closest -> find_tags -> rerank -# -# generate: -# signature: () -> AIMessage or (AIMessage, RAGContext) -# flow: format_context -> aigenerate -> refine -# -### Deepdive -# -# Preparation Phase: -# - Begins with `build_index`, which creates a user-defined index type from an abstract chunk index using specified models and function strategies. -# - `get_chunks` then divides the indexed data into manageable pieces based on a chunking strategy. -# - `get_embeddings` generates embeddings for each chunk using an embedding strategy to facilitate similarity searches. -# - Finally, `get_metadata` extracts relevant metadata from each chunk using a metadata strategy, enabling metadata-based filtering. -# -# Generation/E2E Phase: -# - Starts with `airag`, initiating the response generation process using a custom index to potentially alter the high-level structure. -# - The `retrieve` step employs a retrieval strategy to fetch relevant data, which can be modified by a rephrase strategy for better query matching. -# - `aiembed` generates embeddings for the rephrased query, which are used in `find_closest` to identify the most relevant chunks using a similarity strategy. -# - Optional tag filtering (`aiextract + find_exact`) can be applied before candidates are re-ranked using a reranking strategy. -# - `format_context` constructs the context for response generation based on a context strategy, leading to the `aigenerate` step that produces the final answer. -# - The process concludes with `Refine`, applying a refine strategy for any final adjustments or re-evaluation. -# +## # System Overview +## +## This system is designed for information retrieval and response generation, structured in three main phases: +## - Preparation, when you create an instance of `AbstractIndex` +## - Retrieval, when you surface the top most relevant chunks/items in the `index` and return `AbstractRAGResult`, which contains the references to the chunks (`AbstractCandidateChunks`) +## - Generation, when you generate an answer based on the context built from the retrieved chunks, return either `AIMessage` or `AbstractRAGResult` +## +## The system is designed to be hackable and extensible at almost every entry point. +## If you want to customize the behavior of any step, you can do so by defining a new type and defining a new method for the step you're changing, eg, +## ```julia +## struct MyReranker <: AbstractReranker end +## RT.rerank(::MyReranker, index, candidates) = ... +## ``` +## And then you'd ask for the `retrive` step to use your custom `MyReranker`, eg, `retrieve(....; reranker = MyReranker())` (or customize the main dispatching `AbstractRetriever` struct). +## +## # RAG Diagram +## +## The main functions are: +## +## `build_index`: +## - signature: `(indexer::AbstractIndexBuilder, files_or_docs::Vector{<:AbstractString}) -> AbstractChunkIndex` +## - flow: `get_chunks` -> `get_embeddings` -> `get_tags` -> `build_tags` +## - dispatch types: `AbstractIndexBuilder`, `AbstractChunker`, `AbstractEmbedder`, `AbstractTagger` +## +## `airag`: +## - signature: `(cfg::AbstractRAGConfig, index::AbstractChunkIndex; question::AbstractString)` -> `AIMessage` or `AbstractRAGResult` +## - flow: `retrieve` -> `generate!` +## - dispatch types: `AbstractRAGConfig`, `AbstractRetriever`, `AbstractGenerator` +## +## `retrieve`: +## - signature: `(retriever::AbstractRetriever, index::AbstractChunkIndex, question::AbstractString) -> AbstractRAGResult` +## - flow: `rephrase` -> `get_embeddings` -> `find_closest` -> `get_tags` -> `find_tags` -> `rerank` +## - dispatch types: `AbstractRAGConfig`, `AbstractRephraser`, `AbstractEmbedder`, `AbstractSimilarityFinder`, `AbstractTagger`, `AbstractTagFilter`, `AbstractReranker` +## +## `generate!`: +## - signature: `(generator::AbstractGenerator, index::AbstractChunkIndex, result::AbstractRAGResult)` -> `AIMessage` or `AbstractRAGResult` +## - flow: `build_context!` -> `answer!` -> `refine!` -> `postprocess!` +## - dispatch types: `AbstractGenerator`, `AbstractContextBuilder`, `AbstractAnswerer`, `AbstractRefiner`, `AbstractPostprocessor` +## +## To discover the currently available implementations, use `subtypes` function, eg, `subtypes(AbstractReranker)`. +## +## # Deepdive +## +## **Preparation Phase:** +## - Begins with `build_index`, which creates a user-defined index type from an abstract chunk index using specified dels and function strategies. +## - `get_chunks` then divides the indexed data into manageable pieces based on a chunking strategy. +## - `get_embeddings` generates embeddings for each chunk using an embedding strategy to facilitate similarity arches. +## - Finally, `get_tags` extracts relevant metadata from each chunk, enabling tag-based filtering (hybrid search index). If there are `tags` available, `build_tags` is called to build the corresponding sparse matrix for filtering with tags. + +## **Retrieval Phase:** +## - The `retrieve` step is intended to find the most relevant chunks in the `index`. +## - `rephrase` is called first, if we want to rephrase the query (methods like `HyDE` can improve retrieval quite a bit)! +## - `get_embeddings` generates embeddings for the original + rephrased query +## - `find_closest` looks up the most relevant candidates (`CandidateChunks`) using a similarity search strategy. +## - `get_tags` extracts the potential tags (can be provided as part of the `airag` call, eg, when we want to use only some small part of the indexed chunks) +## - `find_tags` filters the candidates to strictly match _at least one_ of the tags (if provided) +## - `rerank` is called to rerank the candidates based on the reranking strategy (ie, to improve the ordering of the chunks in context). + +## **Generation Phase:** +## - The `generate` step is intended to generate a response based on the retrieved chunks, provided via `AbstractRAGResult` (eg, `RAGResult`). +## - `build_context!` constructs the context for response generation based on a context strategy and applies the necessary formatting +## - `answer!` generates the response based on the context and the query +## - `refine!` is called to refine the response (optional, defaults to passthrough) +## - `postprocessing!` is available for any final touches to the response or to potentially save or format the results (eg, automatically save to the disk) + +## Note that all generation steps are mutating the `RAGResult` object. -### Types ############################ -### NOT READY!!! +### TYPES ############################ -# -# Defines three key types for RAG: ChunkIndex, MultiIndex, and CandidateChunks -# In addition, RAGContext is defined for debugging purposes +# Defines the main abstract types used in our RAG system. + +# ## Overarching + +# Dispatch type for airag +abstract type AbstractRAGConfig end + +# supertype for RAGDetails, return_type for retrieve and generate (and optionally airag) +abstract type AbstractRAGResult end # ## Preparation Stage +# Main supertype for all customizations of the indexing process +abstract type AbstractIndexingMethod end + +""" + AbstractIndexBuilder + +Abstract type for building an index with `build_index` (use to change the process / return type of `build_index`). + +# Required Fields +- `chunker::AbstractChunker`: the chunking method, dispatching `get_chunks` +- `embedder::AbstractEmbedder`: the embedding method, dispatching `get_embeddings` +- `tagger::AbstractTagger`: the tagging method, dispatching `get_tags` +""" +abstract type AbstractIndexBuilder <: AbstractIndexingMethod end + +# For get_chunks function +abstract type AbstractChunker <: AbstractIndexingMethod end +# For get_embeddings function +abstract type AbstractEmbedder <: AbstractIndexingMethod end +# For get_tags function +abstract type AbstractTagger <: AbstractIndexingMethod end + +### Index itself - return type of `build_index` abstract type AbstractDocumentIndex end + +""" + AbstractMultiIndex <: AbstractDocumentIndex + +Experimental abstract type for storing multiple document indexes. Not yet implemented. +""" abstract type AbstractMultiIndex <: AbstractDocumentIndex end + +""" + AbstractChunkIndex <: AbstractDocumentIndex + +Main abstract type for storing document chunks and their embeddings. It also stores tags and sources for each chunk. + +# Required Fields +- `id::Symbol`: unique identifier of each index (to ensure we're using the right index with `CandidateChunks`) +- `chunks::Vector{<:AbstractString}`: underlying document chunks / snippets +- `embeddings::Union{Nothing, Matrix{<:Real}}`: for semantic search +- `tags::Union{Nothing, AbstractMatrix{<:Bool}}`: for exact search, filtering, etc. This is often a sparse matrix indicating which chunks have the given `tag` (see `tag_vocab` for the position lookup) +- `tags_vocab::Union{Nothing, Vector{<:AbstractString}}`: vocabulary for the `tags` matrix (each column in `tags` is one item in `tags_vocab` and rows are the chunks) +- `sources::Vector{<:AbstractString}`: sources of the chunks +- `extras::Union{Nothing, AbstractVector}`: additional data, eg, metadata, source code, etc. +""" abstract type AbstractChunkIndex <: AbstractDocumentIndex end +# ## Retrieval stage + +""" + AbstractCandidateChunks + +Abstract type for storing candidate chunks, ie, references to items in a `AbstractChunkIndex`. + +Return type from `find_closest` and `find_tags` functions. + +# Required Fields +- `index_id::Symbol`: the id of the index from which the candidates are drawn +- `positions::Vector{Int}`: the positions of the candidates in the index +- `scores::Vector{Float32}`: the similarity scores of the candidates from the query (higher is better) +""" abstract type AbstractCandidateChunks end -# supertype for RAGDetails -abstract type AbstractRAGResult end +# Main supertype for retrieval customizations +abstract type AbstractRetrievalMethod end -# ## Retrieval stage # Main dispatch type for `retrieve` -abstract type AbstractRetrievalMethod end +""" + AbstractRetriever <: AbstractRetrievalMethod + +Abstract type for retrieving chunks from an index with `retrieve` (use to change the process / return type of `retrieve`). + +# Required Fields +- `rephraser::AbstractRephraser`: the rephrasing method, dispatching `rephrase` +- `finder::AbstractSimilarityFinder`: the similarity search method, dispatching `find_closest` +- `filter::AbstractTagFilter`: the tag matching method, dispatching `find_tags` +- `reranker::AbstractReranker`: the reranking method, dispatching `rerank` +""" +abstract type AbstractRetriever <: AbstractRetrievalMethod end # Main dispatch type for `rephrase` abstract type AbstractRephraser <: AbstractRetrievalMethod end # Main dispatch type for `find_closest` -abstract type AbstractSimilaritySearch <: AbstractRetrievalMethod end +abstract type AbstractSimilarityFinder <: AbstractRetrievalMethod end # Main dispatch type for `find_tags` -abstract type AbstractTagMatch <: AbstractRetrievalMethod end +abstract type AbstractTagFilter <: AbstractRetrievalMethod end # Main dispatch type for `rerank` abstract type AbstractReranker <: AbstractRetrievalMethod end @@ -89,16 +183,66 @@ abstract type AbstractReranker <: AbstractRetrievalMethod end # ## Generation stage abstract type AbstractGenerationMethod end -# Main dispatch type for: `format_context` -abstract type AbstractContextFormater <: AbstractGenerationMethod end +# Main dispatch type for: `generate!` +""" + AbstractGenerator <: AbstractGenerationMethod + +Abstract type for generating an answer with `generate!` (use to change the process / return type of `generate`). + +# Required Fields +- `contexter::AbstractContextBuilder`: the context building method, dispatching `build_context! +- `answerer::AbstractAnswerer`: the answer generation method, dispatching `answer!` +- `refiner::AbstractRefiner`: the answer refining method, dispatching `refine!` +- `postprocessor::AbstractPostprocessor`: the postprocessing method, dispatching `postprocess!` +""" +abstract type AbstractGenerator <: AbstractGenerationMethod end + +# Main dispatch type for: `build_context!` +abstract type AbstractContextBuilder <: AbstractGenerationMethod end + +# Main dispatch type for: `answer!` +abstract type AbstractAnswerer <: AbstractGenerationMethod end -# Main dispatch type for: `refine` +# Main dispatch type for: `refine!` abstract type AbstractRefiner <: AbstractGenerationMethod end +# Main dispatch type for: `postprocess!` +abstract type AbstractPostprocessor <: AbstractGenerationMethod end + # ## Exploration/Display stage # Supertype for annotaters, dispatch for `annotate_support` abstract type AbstractAnnotater end abstract type AbstractAnnotatedNode end -abstract type AbstractAnnotationStyler end \ No newline at end of file +abstract type AbstractAnnotationStyler end + +############################ +### FUNCTIONS +############################ + +# ## Main Functions + +# Builds the index from provided data, dispatch via `indexer::AbstractIndexer`. +function build_index end +function get_chunks end +function get_embeddings end +function get_tags end +# Sub-routing of get_tags, extended in ext/RAGToolsExperimentalExt.jl +"Builds a matrix of tags and a vocabulary list. REQUIRES SparseArrays and LinearAlgebra packages to be loaded!!" +function build_tags end + +# Retrieval stage -> ultimately returns `RAGResult` +function retrieve end +function rephrase end +function find_closest end +function find_tags end +function rerank end + +# Generation stage -> returns mutated `RAGResult` +function generate! end +function build_context! end +function build_context end +function answer! end +function refine! end +function postprocess! end \ No newline at end of file diff --git a/src/Experimental/RAGTools/retrieval.jl b/src/Experimental/RAGTools/retrieval.jl index 06ef38c5..924b3e94 100644 --- a/src/Experimental/RAGTools/retrieval.jl +++ b/src/Experimental/RAGTools/retrieval.jl @@ -1,37 +1,195 @@ +### Types for Retrieval + +""" + NoRephraser <: AbstractRephraser + +No-op implementation for `rephrase`, which simply passes the question through. +""" +struct NoRephraser <: AbstractRephraser end + +""" + SimpleRephraser <: AbstractRephraser + +Rephraser implemented using the provided AI Template (eg, `...`) and standard chat model. +""" +struct SimpleRephraser <: AbstractRephraser end + +""" + HyDERephraser <: AbstractRephraser + +Rephraser implemented using the provided AI Template (eg, `...`) and standard chat model. + +It uses a prompt-based rephrasing method called HyDE (Hypothetical Document Embedding), where instead of looking for an embedding of the question, +we look for the documents most similar to a synthetic passage that _would be_ a good answer to our question. + +Reference: [Arxiv paper](https://arxiv.org/abs/2212.10496). +""" +struct HyDERephraser <: AbstractRephraser end + +""" + CosineSimilarity <: AbstractSimilarityFinder + +Finds the closest chunks to a query embedding by measuring the cosine similarity between the query and the chunks' embeddings. +""" +struct CosineSimilarity <: AbstractSimilarityFinder end + """ - find_closest(emb::AbstractMatrix{<:Real}, + NoTagFilter <: AbstractTagFilter + + +No-op implementation for `find_tags`, which simply returns all chunks. +""" +struct NoTagFilter <: AbstractTagFilter end + +""" + AnyTagFilter <: AbstractTagFilter + +Finds the chunks that have ANY OF the specified tag(s). +""" +struct AnyTagFilter <: AbstractTagFilter end + +### Functions +function rephrase(rephraser::AbstractRephraser, question::AbstractString; kwargs...) + throw(ArgumentError("Not implemented yet for type $(typeof(rephraser))")) +end + +""" + rephrase(rephraser::NoRephraser, question::AbstractString; kwargs...) + +No-op, simple passthrough. +""" +function rephrase(rephraser::NoRephraser, question::AbstractString; kwargs...) + return [question] +end + +""" + rephrase(rephraser::SimpleRephraser, question::AbstractString; + verbose::Bool = true, + model::String = PT.MODEL_CHAT, template::Symbol = :RAGQueryOptimizer, + cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) + +Rephrases the `question` using the provided rephraser `template`. + +Returns both the original and the rephrased question. + +# Arguments +- `rephraser`: Type that dictates the logic of rephrasing step. +- `question`: The question to be rephrased. +- `model`: The model to use for rephrasing. Default is `PT.MODEL_CHAT`. +- `template`: The rephrasing template to use. Default is `:RAGQueryOptimizer`. Find more with `aitemplates("rephrase")`. +- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `true`. +""" +function rephrase(rephraser::SimpleRephraser, question::AbstractString; + verbose::Bool = true, + model::String = PT.MODEL_CHAT, template::Symbol = :RAGQueryOptimizer, + cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) + ## checks + placeholders = only(aitemplates(template)).variables # only one template should be found + @assert (:query in placeholders) "Provided RAG Template $(template) is not suitable. It must have a placeholder: `query`." + + msg = aigenerate(template; query = question, verbose, model, kwargs...) + Threads.atomic_add!(cost_tracker, msg.cost) + new_question = strip(msg.content) + return [question, new_question] +end + +""" + rephrase(rephraser::SimpleRephraser, question::AbstractString; + verbose::Bool = true, + model::String = PT.MODEL_CHAT, template::Symbol = :RAGQueryHyDE, + cost_tracker = Threads.Atomic{Float64}(0.0)) + +Rephrases the `question` using the provided rephraser `template = RAGQueryHyDE`. + +Special flavor of rephrasing using HyDE (Hypothetical Document Embedding) method, +which aims to find the documents most similar to a synthetic passage that _would be_ a good answer to our question. + +Returns both the original and the rephrased question. + +# Arguments +- `rephraser`: Type that dictates the logic of rephrasing step. +- `question`: The question to be rephrased. +- `model`: The model to use for rephrasing. Default is `PT.MODEL_CHAT`. +- `template`: The rephrasing template to use. Default is `:RAGQueryHyDE`. Find more with `aitemplates("rephrase")`. +- `verbose`: A boolean flag indicating whether to print verbose logging. Default is `true`. +""" +function rephrase(rephraser::HyDERephraser, question::AbstractString; + verbose::Bool = true, + model::String = PT.MODEL_CHAT, template::Symbol = :RAGQueryHyDE, + cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) + rephrase(SimpleRephraser(), question; verbose, model, template, cost_tracker, kwargs...) +end + +# General fallback +function find_closest( + finder::AbstractSimilarityFinder, emb::AbstractMatrix{<:Real}, + query_emb::AbstractVector{<:Real}; kwargs...) + throw(ArgumentError("Not implemented yet for type $(typeof(finder))")) +end + +""" + find_closest(finder::CosineSimilarity, emb::AbstractMatrix{<:Real}, query_emb::AbstractVector{<:Real}; - top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0) + top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...) -Finds the indices of chunks (represented by embeddings in `emb`) that are closest (cosine similarity) to query embedding (`query_emb`). +Finds the indices of chunks (represented by embeddings in `emb`) that are closest (in cosine similarity for `CosineSimilarity()`) to query embedding (`query_emb`). + +`finder` is the logic used for the similarity search. Default is `CosineSimilarity`. If `minimum_similarity` is provided, only indices with similarity greater than or equal to it are returned. Similarity can be between -1 and 1 (-1 = completely opposite, 1 = exactly the same). Returns only `top_k` closest indices. """ -function find_closest(emb::AbstractMatrix{<:Real}, +function find_closest( + finder::CosineSimilarity, emb::AbstractMatrix{<:Real}, query_emb::AbstractVector{<:Real}; - top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0) + top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0, kwargs...) # emb is an embedding matrix where the first dimension is the embedding dimension - distances = query_emb' * emb |> vec - positions = distances |> sortperm |> reverse |> x -> first(x, top_k) + scores = query_emb' * emb |> vec + positions = scores |> sortperm |> reverse |> x -> first(x, top_k) if minimum_similarity > -1.0 - mask = distances[positions] .>= minimum_similarity + mask = scores[positions] .>= minimum_similarity positions = positions[mask] end - return positions, distances[positions] + return positions, scores[positions] end -function find_closest(index::AbstractChunkIndex, + +""" + find_closest( + finder::AbstractSimilarityFinder, index::AbstractChunkIndex, query_emb::AbstractVector{<:Real}; - top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0) - isnothing(embeddings(index)) && CandidateChunks(; index_id = index.id) - positions, distances = find_closest(embeddings(index), + top_k::Int = 100, kwargs...) + +Finds the indices of chunks (represented by embeddings in `index`) that are closest to query embedding (`query_emb`). + +Returns only `top_k` closest indices. +""" +function find_closest( + finder::AbstractSimilarityFinder, index::AbstractChunkIndex, + query_emb::AbstractVector{<:Real}; + top_k::Int = 100, kwargs...) + isnothing(embeddings(index)) && return CandidateChunks(; index_id = index.id) + positions, scores = find_closest(finder, embeddings(index), query_emb; - top_k, - minimum_similarity) - return CandidateChunks(index.id, positions, Float32.(distances)) + top_k, kwargs...) + return CandidateChunks(index.id, positions, Float32.(scores)) end + +# Dispatch to find scores for multiple embeddings +function find_closest( + finder::AbstractSimilarityFinder, index::AbstractChunkIndex, + query_emb::AbstractMatrix{<:Real}; + top_k::Int = 100, kwargs...) + isnothing(embeddings(index)) && CandidateChunks(; index_id = index.id) + ## reduce top_k since we have more than one query + top_k_ = top_k ÷ size(query_emb, 2) + ## simply vcat together (gets sorted from the highest similarity to the lowest) + mapreduce( + c -> find_closest(finder, index, c; top_k = top_k_, kwargs...), vcat, eachcol(query_emb)) +end + +## TODO: Implement for MultiIndex ## function find_closest(index::AbstractMultiIndex, ## query_emb::AbstractVector{<:Real}; ## top_k::Int = 100, minimum_similarity::AbstractFloat = -1.0) @@ -52,8 +210,25 @@ end ## all_distances[top_k_order]) ## end -function find_tags(index::AbstractChunkIndex, - tag::Union{AbstractString, Regex}) +### TAG Filtering + +function find_tags(::AbstractTagFilter, index::AbstractChunkIndex, + tag::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: + Union{AbstractString, Regex}} + throw(ArgumentError("Not implemented yet for type $(typeof(filter))")) +end + +""" + find_tags(method::AnyTagFilter, index::AbstractChunkIndex, + tag::Union{AbstractString, Regex}; kwargs...) + + find_tags(method::AnyTagFilter, index::AbstractChunkIndex, + tags::Vector{T}; kwargs...) where {T <: Union{AbstractString, Regex}} + +Finds the indices of chunks (represented by tags in `index`) that have ANY OF the specified `tag` or `tags`. +""" +function find_tags(method::AnyTagFilter, index::AbstractChunkIndex, + tag::Union{AbstractString, Regex}; kwargs...) isnothing(tags(index)) && CandidateChunks(; index_id = index.id) tag_idx = if tag isa AbstractString findall(tags_vocab(index) .== tag) @@ -65,70 +240,108 @@ function find_tags(index::AbstractChunkIndex, x -> getindex.(x, 1) |> unique return CandidateChunks(index.id, match_row_idx, ones(Float32, length(match_row_idx))) end -function find_tags(index::AbstractChunkIndex, - tags::Vector{<:AbstractString}) - pos = [find_tags(index, tag).positions for tag in tags] |> + +# Method for multiple tags +function find_tags(method::AnyTagFilter, index::AbstractChunkIndex, + tags::Vector{T}; kwargs...) where {T <: Union{AbstractString, Regex}} + pos = [find_tags(method, index, tag).positions for tag in tags] |> Base.Splat(vcat) |> unique |> x -> convert(Vector{Int}, x) return CandidateChunks(index.id, pos, ones(Float32, length(pos))) end -# Assuming the rerank and strategy definitions are in the Main module or relevant module -abstract type RerankingStrategy end - -struct Passthrough <: RerankingStrategy end -struct CohereRerank <: RerankingStrategy end +""" + find_tags(method::NoTagFilter, index::AbstractChunkIndex, + tags; kwargs...) -function rerank(strategy::Passthrough, - index, - question, - candidate_chunks; - top_n::Integer = length(candidate_chunks), - kwargs...) - # Since this is a Passthrough strategy, it returns the candidate_chunks unchanged - return first(candidate_chunks, top_n) +Returns all chunks in the index, ie, no filtering. +""" +function find_tags(method::NoTagFilter, index::AbstractChunkIndex, + tags::Union{T, AbstractVector{<:T}}; kwargs...) where {T <: + Union{ + AbstractString, Regex}} + return CandidateChunks( + index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks))) +end +function find_tags(method::NoTagFilter, index::AbstractChunkIndex, + tags::Nothing; kwargs...) + return CandidateChunks( + index.id, collect(1:length(index.chunks)), zeros(Float32, length(index.chunks))) end -function rerank(strategy::CohereRerank, - index::AbstractDocumentIndex, args...; kwargs...) +### Reranking + +""" + NoReranker <: AbstractReranker + +No-op implementation for `rerank`, which simply passes the candidate chunks through. +""" +struct NoReranker <: AbstractReranker end + +""" + CohereReranker <: AbstractReranker + +Rerank strategy using the Cohere Rerank API. Requires an API key. +""" +struct CohereReranker <: AbstractReranker end + +function rerank(reranker::AbstractReranker, + index::AbstractDocumentIndex, question::AbstractString, candidates::AbstractCandidateChunks; kwargs...) throw(ArgumentError("Not implemented yet")) end +function rerank(reranker::NoReranker, + index::AbstractChunkIndex, + question::AbstractString, + candidates::AbstractCandidateChunks; + top_n::Integer = length(candidates), + kwargs...) + # Since this is almost a passthrough strategy, it returns the candidate_chunks unchanged + # but it truncates to `top_n` if necessary + return first(candidates, top_n) +end + """ - rerank(strategy::CohereRerank, index::AbstractChunkIndex, question, - candidate_chunks; + rerank( + reranker::CohereReranker, index::AbstractChunkIndex, question::AbstractString, + candidates::AbstractCandidateChunks; verbose::Bool = false, api_key::AbstractString = PT.COHERE_API_KEY, - top_n::Integer = length(candidate_chunks.distances), + top_n::Integer = length(candidates.scores), model::AbstractString = "rerank-english-v2.0", return_documents::Bool = false, + cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) + Re-ranks a list of candidate chunks using the Cohere Rerank API. See https://cohere.com/rerank for more details. # Arguments -- `query`: The query to be used for the search. -- `documents`: A vector of documents to be reranked. - The total max chunks (`length of documents * max_chunks_per_doc`) must be less than 10000. We recommend less than 1000 documents for optimal performance. +- `reranker`: Using Cohere API +- `index`: The index that holds the underlying chunks to be re-ranked. +- `question`: The query to be used for the search. +- `candidates`: The candidate chunks to be re-ranked. - `top_n`: The number of most relevant documents to return. Default is `length(documents)`. - `model`: The model to use for reranking. Default is `rerank-english-v2.0`. - `return_documents`: A boolean flag indicating whether to return the reranked documents in the response. Default is `false`. -- `max_chunks_per_doc`: The maximum number of chunks to use per document. Default is `10`. - `verbose`: A boolean flag indicating whether to print verbose logging. Default is `false`. +- `cost_tracker`: An atomic counter to track the cost of the retrieval. Default is `Threads.Atomic{Float64}(0.0)`. Not currently tracked (cost unclear). """ -function rerank(strategy::CohereRerank, index::AbstractChunkIndex, question, - candidate_chunks; +function rerank( + reranker::CohereReranker, index::AbstractChunkIndex, question::AbstractString, + candidates::AbstractCandidateChunks; verbose::Bool = false, api_key::AbstractString = PT.COHERE_API_KEY, - top_n::Integer = length(candidate_chunks.distances), + top_n::Integer = length(candidates.scores), model::AbstractString = "rerank-english-v2.0", return_documents::Bool = false, + cost_tracker = Threads.Atomic{Float64}(0.0), kwargs...) @assert top_n>0 "top_n must be a positive integer." - @assert index.id==candidate_chunks.index_id "The index id of the index and candidate_chunks must match." + @assert index.id==candidates.index_id "The index id of the index and `candidates` must match." ## Call the API - documents = index[candidate_chunks, :chunks] + documents = index[candidates, :chunks] verbose && @info "Calling Cohere Rerank API with $(length(documents)) candidate chunks..." r = cohere_api(; @@ -143,11 +356,11 @@ function rerank(strategy::CohereRerank, index::AbstractChunkIndex, question, ## Unwrap re-ranked positions positions = Vector{Int}(undef, length(r.response[:results])) - distances = Vector{Float32}(undef, length(r.response[:results])) + scores = Vector{Float32}(undef, length(r.response[:results])) for i in eachindex(r.response[:results]) doc = r.response[:results][i] - positions[i] = candidate_chunks.positions[doc[:index] + 1] - distances[i] = doc[:relevance_score] + positions[i] = candidates.positions[doc[:index] + 1] + scores[i] = doc[:relevance_score] end ## Check the cost @@ -161,5 +374,250 @@ function rerank(strategy::CohereRerank, index::AbstractChunkIndex, question, end verbose && @info "Reranking done. $search_units_str" - return CandidateChunks(index.id, positions, distances) + return CandidateChunks(index.id, positions, scores) +end + +### Overall types for `retrieve` +""" + SimpleRetriever <: AbstractRetriever + +Default implementation for `retrieve`. It does a simple similarity search via `CosineSimilarity` and returns the results. + +Make sure to use consistent `embedder` and `tagger` with the Preparation Stage (`build_index`)! + +# Fields +- `rephraser::AbstractRephraser`: the rephrasing method, dispatching `rephrase` - uses `NoRephraser` +- `embedder::AbstractEmbedder`: the embedding method, dispatching `get_embeddings` (see Preparation Stage for more details) - uses `BatchEmbedder` +- `finder::AbstractSimilarityFinder`: the similarity search method, dispatching `find_closest` - uses `CosineSimilarity` +- `tagger::AbstractTagger`: the tag generating method, dispatching `get_tags` (see Preparation Stage for more details) - uses `NoTagger` +- `filter::AbstractTagFilter`: the tag matching method, dispatching `find_tags` - uses `NoTagFilter` +- `reranker::AbstractReranker`: the reranking method, dispatching `rerank` - uses `NoReranker` +""" +@kwdef mutable struct SimpleRetriever <: AbstractRetriever + rephraser::AbstractRephraser = NoRephraser() + embedder::AbstractEmbedder = BatchEmbedder() + finder::AbstractSimilarityFinder = CosineSimilarity() + tagger::AbstractTagger = NoTagger() + filter::AbstractTagFilter = NoTagFilter() + reranker::AbstractReranker = NoReranker() +end + +""" + AdvancedRetriever <: AbstractRetriever + +Dispatch for `retrieve` with advanced retrieval methods to improve result quality. +Compared to SimpleRetriever, it adds rephrasing the query and reranking the results. + +# Fields +- `rephraser::AbstractRephraser`: the rephrasing method, dispatching `rephrase` - uses `HyDERephraser` +- `embedder::AbstractEmbedder`: the embedding method, dispatching `get_embeddings` (see Preparation Stage for more details) - uses `BatchEmbedder` +- `finder::AbstractSimilarityFinder`: the similarity search method, dispatching `find_closest` - uses `CosineSimilarity` +- `tagger::AbstractTagger`: the tag generating method, dispatching `get_tags` (see Preparation Stage for more details) - uses `NoTagger` +- `filter::AbstractTagFilter`: the tag matching method, dispatching `find_tags` - uses `NoTagFilter` +- `reranker::AbstractReranker`: the reranking method, dispatching `rerank` - uses `CohereReranker` +""" +@kwdef mutable struct AdvancedRetriever <: AbstractRetriever + rephraser::AbstractRephraser = HyDERephraser() + embedder::AbstractEmbedder = BatchEmbedder() + finder::AbstractSimilarityFinder = CosineSimilarity() + tagger::AbstractTagger = NoTagger() + filter::AbstractTagFilter = NoTagFilter() + reranker::AbstractReranker = CohereReranker() +end + +""" + retrieve(retriever::AbstractRetriever, + index::AbstractChunkIndex, + question::AbstractString; + verbose::Integer = 1, + top_k::Integer = 100, + top_n::Integer = 5, + api_kwargs::NamedTuple = NamedTuple(), + rephraser::AbstractRephraser = retriever.rephraser, + rephraser_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = retriever.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + finder::AbstractSimilarityFinder = retriever.finder, + finder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = retriever.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), + filter::AbstractTagFilter = retriever.filter, + filter_kwargs::NamedTuple = NamedTuple(), + reranker::AbstractReranker = retriever.reranker, + reranker_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + +Retrieves the most relevant chunks from the index for the given question and returns them in the `RAGResult` object. + +This is the main entry point for the retrieval stage of the RAG pipeline. It is often followed by `generate!` step. + +Notes: +- The default flow is `build_context!` -> `answer!` -> `refine!` -> `postprocess!`. + +The arguments correspond to the steps of the retrieval process (rephrasing, embedding, finding similar docs, tagging, filtering by tags, reranking). +You can customize each step by providing a new custom type that dispatches the corresponding function, + eg, create your own type `struct MyReranker<:AbstractReranker end` and define the custom method for it `rerank(::MyReranker,...) = ...`. + +Note: Discover available retrieval sub-types for each step with `subtypes(AbstractRephraser)` and similar for other abstract types. + +If you're using locally-hosted models, you can pass the `api_kwargs` with the `url` field set to the model's URL and make sure to provide corresponding + `model` kwargs to `rephraser`, `embedder`, and `tagger` to use the custom models (they make AI calls). + +# Arguments +- `retriever`: The retrieval method to use. Default is `SimpleRetriever` but could be `AdvancedRetriever` for more advanced retrieval. +- `index`: The index that holds the chunks and sources to be retrieved from. +- `question`: The question to be used for the retrieval. +- `verbose`: If `>0`, it prints out verbose logging. Default is `1`. If you set it to `2`, it will print out logs for each sub-function. +- `top_k`: The TOTAL number of closest chunks to return from `find_closest`. Default is `100`. + If there are multiple rephrased questions, the number of chunks per each item will be `top_k ÷ number_of_rephrased_questions`. +- `top_n`: The TOTAL number of most relevant chunks to return for the context (from `rerank` step). Default is `5`. +- `api_kwargs`: Additional keyword arguments to be passed to the API calls (shared by all `ai*` calls). +- `rephraser`: Transform the question into one or more questions. Default is `retriever.rephraser`. +- `rephraser_kwargs`: Additional keyword arguments to be passed to the rephraser. + - `model`: The model to use for rephrasing. Default is `PT.MODEL_CHAT`. + - `template`: The rephrasing template to use. Default is `:RAGQueryOptimizer` or `:RAGQueryHyDE` (depending on the `rephraser` selected). +- `embedder`: The embedding method to use. Default is `retriever.embedder`. +- `embedder_kwargs`: Additional keyword arguments to be passed to the embedder. +- `finder`: The similarity search method to use. Default is `retriever.finder`, often `CosineSimilarity`. +- `finder_kwargs`: Additional keyword arguments to be passed to the similarity finder. +- `tagger`: The tag generating method to use. Default is `retriever.tagger`. +- `tagger_kwargs`: Additional keyword arguments to be passed to the tagger. Noteworthy arguments: + - `tags`: Directly provide the tags to use for filtering (can be String, Regex, or Vector{String}). Useful for `tagger = PassthroughTagger`. +- `filter`: The tag matching method to use. Default is `retriever.filter`. +- `filter_kwargs`: Additional keyword arguments to be passed to the tag filter. +- `reranker`: The reranking method to use. Default is `retriever.reranker`. +- `reranker_kwargs`: Additional keyword arguments to be passed to the reranker. + - `model`: The model to use for reranking. Default is `rerank-english-v2.0` if you use `reranker = CohereReranker()`. +- `cost_tracker`: An atomic counter to track the cost of the retrieval. Default is `Threads.Atomic{Float64}(0.0)`. + +See also: `SimpleRetriever`, `AdvancedRetriever`, `build_index`, `rephrase`, `get_embeddings`, `find_closest`, `get_tags`, `find_tags`, `rerank`, `RAGResult`. + +# Examples + +Find the 5 most relevant chunks from the index for the given question. +```julia +# assumes you have an existing index `index` +retriever = SimpleRetriever() + +result = retrieve(retriever, + index, + "What is the capital of France?", + top_n = 5) + +# or use the default retriever (same as above) +result = retrieve(retriever, + index, + "What is the capital of France?", + top_n = 5) +``` + +Apply more advanced retrieval with question rephrasing and reranking (requires `COHERE_API_KEY`). +We will obtain top 100 chunks from embeddings (`top_k`) and top 5 chunks from reranking (`top_n`). + +```julia +retriever = AdvancedRetriever() + +result = retrieve(retriever, index, question; top_k=100, top_n=5) +``` + +You can use the `retriever` to customize your retrieval strategy or directly change the strategy types in the `retrieve` kwargs! + +Example of using locally-hosted model hosted on `localhost:8080`: +```julia +retriever = SimpleRetriever() +result = retrieve(retriever, index, question; + rephraser_kwargs = (; model = "custom"), + embedder_kwargs = (; model = "custom"), + tagger_kwargs = (; model = "custom"), api_kwargs = (; + url = "http://localhost:8080")) +``` +""" +function retrieve(retriever::AbstractRetriever, + index::AbstractChunkIndex, + question::AbstractString; + verbose::Integer = 1, + top_k::Integer = 100, + top_n::Integer = 5, + api_kwargs::NamedTuple = NamedTuple(), + rephraser::AbstractRephraser = retriever.rephraser, + rephraser_kwargs::NamedTuple = NamedTuple(), + embedder::AbstractEmbedder = retriever.embedder, + embedder_kwargs::NamedTuple = NamedTuple(), + finder::AbstractSimilarityFinder = retriever.finder, + finder_kwargs::NamedTuple = NamedTuple(), + tagger::AbstractTagger = retriever.tagger, + tagger_kwargs::NamedTuple = NamedTuple(), + filter::AbstractTagFilter = retriever.filter, + filter_kwargs::NamedTuple = NamedTuple(), + reranker::AbstractReranker = retriever.reranker, + reranker_kwargs::NamedTuple = NamedTuple(), + cost_tracker = Threads.Atomic{Float64}(0.0), + kwargs...) + ## Rephrase into one or more questions + rephraser_kwargs_ = isempty(api_kwargs) ? rephraser_kwargs : + merge(rephraser_kwargs, (; api_kwargs)) + rephrased_questions = rephrase( + rephraser, question; verbose = (verbose > 1), cost_tracker, rephraser_kwargs_...) + + ## Embed one or more rephrased questions + embedder_kwargs_ = isempty(api_kwargs) ? embedder_kwargs : + merge(embedder_kwargs, (; api_kwargs)) + embeddings = get_embeddings(embedder, rephrased_questions; + verbose = (verbose > 1), cost_tracker, embedder_kwargs_...) + + finder_kwargs_ = isempty(api_kwargs) ? finder_kwargs : + merge(finder_kwargs, (; api_kwargs)) + emb_candidates = find_closest(finder, index, embeddings; + verbose = (verbose > 1), top_k, finder_kwargs_...) + + ## Tagging - if you provide them explicitly, use tagger `PassthroughTagger` and `tagger_kwargs = (;tags = ...)` + tagger_kwargs_ = isempty(api_kwargs) ? tagger_kwargs : + merge(tagger_kwargs, (; api_kwargs)) + tags = get_tags(tagger, rephrased_questions; verbose = (verbose > 1), + cost_tracker, tagger_kwargs_...) + + filter_kwargs_ = isempty(api_kwargs) ? filter_kwargs : + merge(filter_kwargs, (; api_kwargs)) + tag_candidates = find_tags( + filter, index, tags; verbose = (verbose > 1), filter_kwargs_...) + + ## Combine the two sets of candidates, looks for intersection (hard filter)! + filtered_candidates = isnothing(tag_candidates) ? emb_candidates : + (emb_candidates & tag_candidates) + ## TODO: Future implementation should be to apply tag filtering BEFORE the find_closest, + ## but that requires implementing `view(::Index,...)` to provide only a subset of the embeddings to the subsequent functionality. + ## Also, find_closest is so fast & cheap that it doesn't matter at current scale/maturity of the use cases + + ## Reranking + reranker_kwargs_ = isempty(api_kwargs) ? reranker_kwargs : + merge(reranker_kwargs, (; api_kwargs)) + reranked_candidates = rerank(reranker, index, question, filtered_candidates; + top_n, verbose = (verbose > 1), cost_tracker, reranker_kwargs_...) + + verbose > 0 && + @info "Retrieval done. Identified $(length(reranked_candidates.positions)) chunks, total cost: \$$(cost_tracker[])." + + ## Return + result = RAGResult(; + question, + answer = nothing, + rephrased_questions, + final_answer = nothing, + context = chunks(index)[reranked_candidates.positions], + sources = sources(index)[reranked_candidates.positions], + emb_candidates, + tag_candidates, + filtered_candidates, + reranked_candidates) + + return result end + +# Set default behavior +DEFAULT_RETRIEVER = SimpleRetriever() +function retrieve(index::AbstractChunkIndex, question::AbstractString; + kwargs...) + return retrieve(DEFAULT_RETRIEVER, index, question; + kwargs...) +end \ No newline at end of file diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index 2cdbc16d..2b4b4959 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -68,7 +68,10 @@ function Base.vcat(i1::ChunkIndex, i2::ChunkIndex) sources = vcat(i1.sources, i2.sources)) end -"Composite index that stores multiple ChunkIndex objects and their embeddings" +## TODO: implement a view(), make sure to recover the output positions correctly +## TODO: fields: parent, positions + +"Composite index that stores multiple ChunkIndex objects and their embeddings. It's not yet fully implemented." @kwdef struct MultiIndex <: AbstractMultiIndex id::Symbol = gensym("MultiIndex") indexes::Vector{<:AbstractChunkIndex} @@ -90,22 +93,83 @@ function Base.var"=="(i1::MultiIndex, i2::MultiIndex) return true end +""" + CandidateChunks + +A struct for storing references to chunks in the given index (identified by `index_id`) called `positions` and `scores` holding the strength of similarity (=1 is the highest, most similar). +It's the result of the retrieval stage of RAG. + +# Fields +- `index_id::Symbol`: the id of the index from which the candidates are drawn +- `positions::Vector{Int}`: the positions of the candidates in the index (ie, `5` refers to the 5th chunk in the index - `chunks(index)[5]`) +- `scores::Vector{Float32}`: the similarity scores of the candidates from the query (higher is better) +""" @kwdef struct CandidateChunks{TP <: Union{Integer, AbstractCandidateChunks}, TD <: Real} <: AbstractCandidateChunks index_id::Symbol ## if TP is Int, then positions are indices into the index ## if TP is CandidateChunks, then positions are indices into the positions of the child index in MultiIndex positions::Vector{TP} = Int[] - distances::Vector{TD} = Float32[] + scores::Vector{TD} = Float32[] end +## TODO: disabled nested CandidateChunks for now +## TODO: create MultiCandidateChunks for use with MultiIndex, they will have extra field subindex_id + Base.length(cc::CandidateChunks) = length(cc.positions) function Base.first(cc::CandidateChunks, k::Integer) - CandidateChunks(cc.index_id, first(cc.positions, k), first(cc.distances, k)) + sorted_idxs = sortperm(cc.scores, rev = true) |> x -> first(x, k) + CandidateChunks(cc.index_id, cc.positions[sorted_idxs], cc.scores[sorted_idxs]) +end +function Base.copy(cc::CandidateChunks{TP, TD}) where {TP <: Integer, TD <: Real} + CandidateChunks{TP, TD}(cc.index_id, copy(cc.positions), copy(cc.scores)) +end +function Base.isempty(cc::CandidateChunks) + isempty(cc.positions) +end +function Base.var"=="(cc1::CandidateChunks, cc2::CandidateChunks) + all( + getfield(cc1, f) == getfield(cc2, f) for f in fieldnames(CandidateChunks)) +end + +# join and sort two candidate chunks +function Base.vcat(cc1::AbstractCandidateChunks, cc2::AbstractCandidateChunks) + throw(ArgumentError("Not implemented for type $(typeof(cc1)) and $(typeof(cc2))")) end -# combine/intersect two candidate chunks. average the score if available +function Base.vcat(cc1::CandidateChunks{TP1, TD1}, + cc2::CandidateChunks{TP2, TD2}) where { + TP1 <: Integer, TP2 <: Integer, TD1 <: Real, TD2 <: Real} + ## Check validity + cc1.index_id != cc2.index_id && + throw(ArgumentError("Index ids must match (provided: $(cc1.index_id) and $(cc2.index_id))")) + + positions = vcat(cc1.positions, cc2.positions) + # operates on maximum similarity principle, ie, take the max similarity + scores = if !isempty(cc1.scores) && !isempty(cc2.scores) + vcat(cc1.scores, cc2.scores) + else + TD1[] + end + + if !isempty(scores) + ## Get sorted by maximum similarity (scores are similarity) + sorted_idxs = sortperm(scores, rev = true) + positions_sorted = @view(positions[sorted_idxs]) + ## get the positions of unique elements + unique_idxs = unique(i -> positions_sorted[i], eachindex(positions_sorted)) + positions = positions_sorted[unique_idxs] + ## apply the sorting and then the filtering + scores = @view(scores[sorted_idxs])[unique_idxs] + else + positions = unique(positions) + end + + CandidateChunks(cc1.index_id, positions, scores) +end + +# combine/intersect two candidate chunks. take the maximum of the score if available function Base.var"&"(cc1::AbstractCandidateChunks, cc2::AbstractCandidateChunks) - throw(ArgumentError("Not implemented")) + throw(ArgumentError("Not implemented for type $(typeof(cc1)) and $(typeof(cc2))")) end function Base.var"&"(cc1::CandidateChunks{TP1, TD1}, cc2::CandidateChunks{TP2, TD2}) where @@ -114,12 +178,38 @@ function Base.var"&"(cc1::CandidateChunks{TP1, TD1}, cc1.index_id != cc2.index_id && return CandidateChunks(; index_id = cc1.index_id) positions = intersect(cc1.positions, cc2.positions) - distances = if !isempty(cc1.distances) && !isempty(cc2.distances) - (cc1.distances[positions] .+ cc2.distances[positions]) ./ 2 + + scores = if !isempty(cc1.scores) && !isempty(cc2.scores) + valid_scores = fill(TD1(-1), length(positions)) + # identify maximum scores from each CC + # scan the first CC + for i in eachindex(cc1.positions, cc1.scores) + pos = cc1.positions[i] + idx = findfirst(==(pos), positions) + if !isnothing(idx) + valid_scores[idx] = max(valid_scores[idx], cc1.scores[i]) + end + end + # scan the second CC + for i in eachindex(cc2.positions, cc2.scores) + pos = cc2.positions[i] + idx = findfirst(==(pos), positions) + if !isnothing(idx) + valid_scores[idx] = max(valid_scores[idx], cc2.scores[i]) + end + end + valid_scores else - Float32[] + TD1[] end - CandidateChunks(cc1.index_id, positions, distances) + ## Sort by maximum similarity + if !isempty(scores) + sorted_idxs = sortperm(scores, rev = true) + positions = positions[sorted_idxs] + scores = scores[sorted_idxs] + end + + CandidateChunks(cc1.index_id, positions, scores) end function Base.getindex(ci::AbstractDocumentIndex, @@ -184,30 +274,65 @@ function Base.getindex(mi::MultiIndex, end """ - RAGDetails + RAGResult A struct for debugging RAG answers. It contains the question, answer, context, and the candidate chunks at each step of the RAG pipeline. + +Think of the flow as `question` -> `rephrased_questions` -> `answer` -> `final_answer` with the context and candidate chunks helping along the way. + +# Fields +- `question::AbstractString`: the original question +- `rephrased_questions::Vector{<:AbstractString}`: a vector of rephrased questions (eg, HyDe, Multihop, etc.) +- `answer::AbstractString`: the generated answer +- `final_answer::AbstractString`: the refined final answer (eg, after CorrectiveRAG), also considered the FINAL answer (it must be always available) +- `context::Vector{<:AbstractString}`: the context used for retrieval (ie, the vector of chunks and their surrounding window if applicable) +- `sources::Vector{<:AbstractString}`: the sources of the context (for the original matched chunks) +- `emb_candidates::CandidateChunks`: the candidate chunks from the embedding index (from `find_closest`) +- `tag_candidates::Union{Nothing, CandidateChunks}`: the candidate chunks from the tag index (from `find_tags`) +- `filtered_candidates::CandidateChunks`: the filtered candidate chunks (intersection of `emb_candidates` and `tag_candidates`) +- `reranked_candidates::CandidateChunks`: the reranked candidate chunks (from `rerank`) +- `conversations::Dict{Symbol,Vector{<:AbstractMessage}}`: the conversation history for AI steps of the RAG pipeline, use keys that correspond to the function names, eg, `:answer` or `:refine` + +See also: `pprint` (pretty printing), `annotate_support` (for annotating the answer) """ -@kwdef mutable struct RAGDetails <: AbstractRAGResult +@kwdef mutable struct RAGResult <: AbstractRAGResult question::AbstractString - rephrased_question::AbstractVector{<:AbstractString} - answer::AbstractString - refined_answer::AbstractString - context::Vector{<:AbstractString} - sources::Vector{<:AbstractString} - emb_candidates::CandidateChunks - tag_candidates::Union{Nothing, CandidateChunks} - filtered_candidates::CandidateChunks - reranked_candidates::CandidateChunks + rephrased_questions::AbstractVector{<:AbstractString} = [question] + answer::Union{Nothing, AbstractString} = nothing + final_answer::Union{Nothing, AbstractString} = nothing + context::Vector{<:AbstractString} = String[] + sources::Vector{<:AbstractString} = String[] + emb_candidates::CandidateChunks = CandidateChunks( + index_id = :NOTINDEX, positions = Int[], scores = Float32[]) + tag_candidates::Union{Nothing, CandidateChunks} = CandidateChunks( + index_id = :NOTINDEX, positions = Int[], scores = Float32[]) + filtered_candidates::CandidateChunks = CandidateChunks( + index_id = :NOTINDEX, positions = Int[], scores = Float32[]) + reranked_candidates::CandidateChunks = CandidateChunks( + index_id = :NOTINDEX, positions = Int[], scores = Float32[]) + conversations::Dict{Symbol, Vector{<:AbstractMessage}} = Dict{ + Symbol, Vector{<:AbstractMessage}}() end # Simplification of the RAGDetails struct -function RAGDetails( - question, answer, context; sources = ["Source $i" for i in 1:length(context)]) - return RAGDetails(question, [question], answer, answer, context, sources, - CandidateChunks(index_id = :emb, positions = Int[], distances = Float32[]), - nothing, - CandidateChunks(index_id = :emb, positions = Int[], distances = Float32[]), - CandidateChunks(index_id = :emb, positions = Int[], distances = Float32[])) +## function RAGResult( +## question::AbstractString, answer::AbstractString, context::Vector{<:AbstractString}; +## sources = ["Source $i" for i in 1:length(context)]) +## return RAGResult(question, [question], answer, answer, context, sources, +## CandidateChunks(index_id = :index, positions = Int[], scores = Float32[]), +## nothing, +## CandidateChunks(index_id = :index, positions = Int[], scores = Float32[]), +## CandidateChunks(index_id = :index, positions = Int[], scores = Float32[]), +## Dict{Symbol, Vector{<:AbstractMessage}}()) +## end + +function Base.var"=="(r1::T, r2::T) where {T <: AbstractRAGResult} + all(f -> getfield(r1, f) == getfield(r2, f), + fieldnames(T)) +end +function Base.copy(r::T) where {T <: AbstractRAGResult} + T([deepcopy(getfield(r, f)) + + for f in fieldnames(T)]...) end # Structured show method for easier reading (each kwarg on a new line) @@ -217,21 +342,43 @@ function Base.show(io::IO, end # Pretty print +# TODO: add more customizations, eg, context itself +""" + PT.pprint( + io::IO, r::AbstractRAGResult; add_context::Bool = false, + text_width::Int = displaysize(io)[2], annotater_kwargs...) + +Pretty print the RAG result `r` to the given `io` stream. + +If `add_context` is `true`, the context will be printed as well. The `text_width` parameter can be used to control the width of the output. + +You can provide additional keyword arguments to the annotater, eg, `add_sources`, `add_scores`, `min_score`, etc. See `annotate_support` for more details. +""" function PT.pprint( - io::IO, r::AbstractRAGResult; text_width::Int = displaysize(io)[2]) - if !isempty(r.rephrased_question) - content = PT.wrap_string("- " * join(r.rephrased_question, "\n- "), text_width) + io::IO, r::AbstractRAGResult; add_context::Bool = false, + text_width::Int = displaysize(io)[2], annotater_kwargs...) + if !isempty(r.rephrased_questions) + content = PT.wrap_string("- " * join(r.rephrased_questions, "\n- "), text_width) print(io, "-"^20, "\n") printstyled(io, "QUESTION(s)", color = :blue, bold = true) print(io, "\n", "-"^20, "\n") print(io, content, "\n\n") end - if !isempty(r.refined_answer) + if !isempty(r.final_answer) annotater = TrigramAnnotater() - root = annotate_support(annotater, r) + root = annotate_support(annotater, r; annotater_kwargs...) print(io, "-"^20, "\n") printstyled(io, "ANSWER", color = :blue, bold = true) print(io, "\n", "-"^20, "\n") pprint(io, root; text_width) end + if add_context + print(io, "-"^20, "\n") + printstyled(io, "CONTEXT", color = :blue, bold = true) + print(io, "\n", "-"^20, "\n") + for (i, ctx) in enumerate(r.context) + print(io, "$(i). ", PT.wrap_string(ctx, text_width)) + print(io, "\n", "-"^20, "\n") + end + end end \ No newline at end of file diff --git a/src/Experimental/RAGTools/utils.jl b/src/Experimental/RAGTools/utils.jl index e7a9884a..afb7897f 100644 --- a/src/Experimental/RAGTools/utils.jl +++ b/src/Experimental/RAGTools/utils.jl @@ -44,7 +44,8 @@ Tokenizes provided `input` by spaces, special characters or Julia symbols (eg, ` Unlike other tokenizers, it aims to lossless - ie, keep both the separated text and the separators. """ function tokenize(input::Union{String, SubString{String}}) - pattern = r"(\s+|=>|\(;|,|\.|\(|\)|\{|\}|\[|\]|;|:|\+|-|\*|/|<|>|=|&|\||!|@|#|\$|%|\^|~|`|\"|'|\w+)" + # specific to Julia language pattern, eg, capture macros (@xyz) or common operators (=>) + pattern = r"(\s+|=>|\(;|,|\.|\(|\)|\{|\}|\[|\]|;|:|\+|-|\*|/|<|>|=|&|\||!|@\w+|@|#|\$|%|\^|~|`|\"|'|\w+)" SubString{String}[m.match for m in eachmatch(pattern, input)] end @@ -204,7 +205,7 @@ function split_into_code_and_sentences(input::Union{String, SubString{String}}) pattern = r"(```[\s\S]+?```)|(`[^`]*?`)|([^`]+)" ## Patterns for sentences: newline, tab, bullet, enumerate list, sentence, any left out characters - sentence_pattern = r"(\n|\t|^\s*[*+-]\s*|^\s*\d+\.\s+|[^\n\t*+\-.!?]+[\n\t*+\-.!?]*|[*+\-.!?])"ms + sentence_pattern = r"(\n|\t|^\s*[*+-]\s*|^\s*\d+\.\s+|[^\n\t\.!?]+[\.!?]*|[*+\-\.!?])"ms # Initialize an empty array to store the split sentences sentences = SubString{String}[] diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl index 42e01800..77cde854 100644 --- a/src/PromptingTools.jl +++ b/src/PromptingTools.jl @@ -1,5 +1,6 @@ module PromptingTools +import AbstractTrees using Base64: base64encode using Logging using OpenAI @@ -28,6 +29,7 @@ const RESERVED_KWARGS = [ # export replace_words, recursive_splitter, split_by_length, call_cost, auth_header # for debugging only # export length_longest_common_subsequence, distance_longest_common_subsequence +# export pprint include("utils.jl") export aigenerate, aiembed, aiclassify, aiextract, aiscan, aiimage diff --git a/src/serialization.jl b/src/serialization.jl index a21bcd6d..93455561 100644 --- a/src/serialization.jl +++ b/src/serialization.jl @@ -1,5 +1,14 @@ ## Loading / Saving -"Saves provided messaging template (`messages`) to `io_or_file`. Automatically adds metadata based on provided keyword arguments." +""" + save_template(io_or_file::Union{IO, AbstractString}, + messages::AbstractVector{<:AbstractChatMessage}; + content::AbstractString = "Template Metadata", + description::AbstractString = "", + version::AbstractString = "1", + source::AbstractString = "") + +Saves provided messaging template (`messages`) to `io_or_file`. Automatically adds metadata based on provided keyword arguments. +""" function save_template(io_or_file::Union{IO, AbstractString}, messages::AbstractVector{<:AbstractChatMessage}; content::AbstractString = "Template Metadata", @@ -13,7 +22,11 @@ function save_template(io_or_file::Union{IO, AbstractString}, # save template to IO or file JSON3.write(io_or_file, [metadata_msg, messages...]) end -"Loads messaging template from `io_or_file` and returns tuple of template messages and metadata." +""" + load_template(io_or_file::Union{IO, AbstractString}) + +Loads messaging template from `io_or_file` and returns tuple of template messages and metadata. +""" function load_template(io_or_file::Union{IO, AbstractString}) messages = JSON3.read(io_or_file, Vector{AbstractChatMessage}) template, metadata = AbstractChatMessage[], MetadataMessage[] @@ -29,12 +42,21 @@ function load_template(io_or_file::Union{IO, AbstractString}) end ## Variants without metadata: -"Saves provided conversation (`messages`) to `io_or_file`. If you need to add some metadata, see `save_template`." +""" + save_conversation(io_or_file::Union{IO, AbstractString}, + messages::AbstractVector{<:AbstractMessage}) + +Saves provided conversation (`messages`) to `io_or_file`. If you need to add some metadata, see `save_template`. +""" function save_conversation(io_or_file::Union{IO, AbstractString}, messages::AbstractVector{<:AbstractMessage}) JSON3.write(io_or_file, messages) end -"Loads a conversation (`messages`) from `io_or_file`" +""" + load_conversation(io_or_file::Union{IO, AbstractString}) + +Loads a conversation (`messages`) from `io_or_file` +""" function load_conversation(io_or_file::Union{IO, AbstractString}) messages = JSON3.read(io_or_file, Vector{AbstractMessage}) end diff --git a/src/templates.jl b/src/templates.jl index 016b30cf..399b41a1 100644 --- a/src/templates.jl +++ b/src/templates.jl @@ -398,6 +398,7 @@ If `load_as` is provided, it registers the template in the `TEMPLATE_STORE` and # Examples Let's generate a quick template for a simple conversation (only one placeholder: name) + ```julia # first system message, then user message (or use kwargs) tpl=PT.create_template("You must speak like a pirate", "Say hi to {{name}}") @@ -422,6 +423,7 @@ PT.save_template("templates/GreatingPirate.json", tpl; version="1.0") # optional It will be saved and accessed under its basename, ie, `GreatingPirate`. Now you can load it like all the other templates (provide the template directory): + ```julia PT.load_templates!("templates") # it will remember the folder after the first run # Note: If you save it again, overwrite it, etc., you need to explicitly reload all templates again! @@ -450,13 +452,14 @@ aigenerate(:GreatingPirate; name="Jack Sparrow") ``` If you do not need to save this template as a file, but you want to make it accessible in the template store for all `ai*` functions, you can use the `load_as` (= template name) keyword argument: + ```julia # this will not only create the template, but also register it for immediate use tpl=PT.create_template("You must speak like a pirate", "Say hi to {{name}}"; load_as="GreatingPirate") # you can now use it like any other template aiextract(:GreatingPirate; name="Jack Sparrow") -```` +``` """ function create_template( system::AbstractString, diff --git a/src/user_preferences.jl b/src/user_preferences.jl index f4fa597c..893555bf 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -159,7 +159,7 @@ _temp = get(ENV, "FIREWORKS_API_KEY", "") const FIREWORKS_API_KEY::String = @load_preference("FIREWORKS_API_KEY", default=_temp); -_temp = get(ENV, "LOCAL_SERVER", "") +_temp = get(ENV, "LOCAL_SERVER", "http://localhost:10897/v1") ## Address of the local server const LOCAL_SERVER::String = @load_preference("LOCAL_SERVER", default=_temp); @@ -448,7 +448,12 @@ registry = Dict{String, ModelSpec}( LocalServerOpenAISchema(), 0.0, 0.0, - "Local server, eg, powered by [Llama.jl](https://github.com/marcom/Llama.jl). Model is specified when instantiating the server itself."), + "Local server, eg, powered by [Llama.jl](https://github.com/marcom/Llama.jl). Model is specified when instantiating the server itself. It will be automatically pointed to the address in `LOCAL_SERVER`."), + "custom" => ModelSpec("custom", + LocalServerOpenAISchema(), + 0.0, + 0.0, + "Send a generic request to a custom server. Make sure to explicitly define the `api_kwargs = (; url = ...)` when calling the model."), "gemini-pro" => ModelSpec("gemini-pro", GoogleSchema(), 0.0, #unknown, expected 1.25e-7 diff --git a/src/utils.jl b/src/utils.jl index 1029c3ea..b47e318b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -608,8 +608,8 @@ function pprint(io::IO, anything::Any; text_width::Int = displaysize(io)[2]) end function pprint(anything::Any; - text_width = displaysize(stdout)[2]) - pprint(stdout, anything; text_width) + text_width = displaysize(stdout)[2], kwargs...) + pprint(stdout, anything; text_width, kwargs...) end """ diff --git a/templates/RAG/RAGAnswerFromContext.json b/templates/RAG/basic-rag/RAGAnswerFromContext.json similarity index 100% rename from templates/RAG/RAGAnswerFromContext.json rename to templates/RAG/basic-rag/RAGAnswerFromContext.json diff --git a/templates/RAG/RAGCreateQAFromContext.json b/templates/RAG/evaluation/RAGCreateQAFromContext.json similarity index 100% rename from templates/RAG/RAGCreateQAFromContext.json rename to templates/RAG/evaluation/RAGCreateQAFromContext.json diff --git a/templates/RAG/RAGJudgeAnswerFromContext.json b/templates/RAG/evaluation/RAGJudgeAnswerFromContext.json similarity index 100% rename from templates/RAG/RAGJudgeAnswerFromContext.json rename to templates/RAG/evaluation/RAGJudgeAnswerFromContext.json diff --git a/templates/RAG/RAGJudgeAnswerFromContextShort.json b/templates/RAG/evaluation/RAGJudgeAnswerFromContextShort.json similarity index 100% rename from templates/RAG/RAGJudgeAnswerFromContextShort.json rename to templates/RAG/evaluation/RAGJudgeAnswerFromContextShort.json diff --git a/templates/RAG/RAGExtractMetadataLong.json b/templates/RAG/metadata/RAGExtractMetadataLong.json similarity index 100% rename from templates/RAG/RAGExtractMetadataLong.json rename to templates/RAG/metadata/RAGExtractMetadataLong.json diff --git a/templates/RAG/RAGExtractMetadataShort.json b/templates/RAG/metadata/RAGExtractMetadataShort.json similarity index 100% rename from templates/RAG/RAGExtractMetadataShort.json rename to templates/RAG/metadata/RAGExtractMetadataShort.json diff --git a/templates/RAG/query-transformations/RAGJuliaQueryHyDE.json b/templates/RAG/query-transformations/RAGJuliaQueryHyDE.json new file mode 100644 index 00000000..b4d059f6 --- /dev/null +++ b/templates/RAG/query-transformations/RAGJuliaQueryHyDE.json @@ -0,0 +1 @@ +[{"content":"Template Metadata","description":"For Julia-specific RAG applications (rephrase step), inspired by the HyDE approach where it generates a hypothetical passage that answers the provided user query to improve the matched results. This explicitly requires and optimizes for Julia-specific questions. Placeholders: `query`","version":"1.0","source":"","_type":"metadatamessage"},{"content":"You're an world-class AI assistant specialized in Julia language questions.\n\nYour task is to generate a BRIEF and SUCCINCT hypothetical passage from Julia language ecosystem documentation that answers the provided query.\n\nQuery: {{query}}","variables":["query"],"_type":"systemmessage"},{"content":"Write a hypothetical snippet with 20-30 words that would be the perfect answer to the query. Try to include as many key details as possible. \n\nPassage: ","variables":[],"_type":"usermessage"}] \ No newline at end of file diff --git a/templates/RAG/query-transformations/RAGQueryHyDE.json b/templates/RAG/query-transformations/RAGQueryHyDE.json new file mode 100644 index 00000000..7feab50c --- /dev/null +++ b/templates/RAG/query-transformations/RAGQueryHyDE.json @@ -0,0 +1 @@ +[{"content":"Template Metadata","description":"For RAG applications (rephrase step), inspired by the HyDE paper where it generates a hypothetical passage that answers the provided user query to improve the matched results. Placeholders: `query`","version":"1.0","source":"Adapted from [LlamaIndex](https://github.com/run-llama/llama_index/blob/78af3400ad485e15862c06f0c4972dc3067f880c/llama-index-core/llama_index/core/prompts/default_prompts.py#L351)","_type":"metadatamessage"},{"content":"You are a world-class search expert specializing in query transformations.\n\nYour task is to write a hypothetical passage that would answer the below question in the most effective way possible.\n\nIt must have 20-30 words and be directly aligned with the intended search objective.\nTry to include as many key details as possible.","variables":[],"_type":"systemmessage"},{"content":"Query: {{query}}\n\nPassage: ","variables":["query"],"_type":"usermessage"}] \ No newline at end of file diff --git a/templates/RAG/query-transformations/RAGQueryOptimizer.json b/templates/RAG/query-transformations/RAGQueryOptimizer.json new file mode 100644 index 00000000..f6731eeb --- /dev/null +++ b/templates/RAG/query-transformations/RAGQueryOptimizer.json @@ -0,0 +1 @@ +[{"content":"Template Metadata","description":"For RAG applications (rephrase step), it rephrases the original query to attract more diverse set of potential search results. Placeholders: `query`","version":"1.0","source":"Adapted from [LlamaIndex](https://github.com/run-llama/llama_index/blob/78af3400ad485e15862c06f0c4972dc3067f880c/llama-index-packs/llama-index-packs-corrective-rag/llama_index/packs/corrective_rag/base.py#L11)","_type":"metadatamessage"},{"content":"You are a world-class search expert specializing in query rephrasing.\nYour task is to refine the provided query to ensure it is highly effective for retrieving relevant search results.\nAnalyze the given input to grasp the core semantic intent or meaning.\n","variables":[],"_type":"systemmessage"},{"content":"Original Query: {{query}}\n\nYour goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective.\nRespond with the optimized query only.\n\nOptimized query: ","variables":["query"],"_type":"usermessage"}] \ No newline at end of file diff --git a/templates/RAG/query-transformations/RAGQuerySimplifier.json b/templates/RAG/query-transformations/RAGQuerySimplifier.json new file mode 100644 index 00000000..bf0e54b4 --- /dev/null +++ b/templates/RAG/query-transformations/RAGQuerySimplifier.json @@ -0,0 +1 @@ +[{"content":"Template Metadata","description":"For RAG applications (rephrase step), it rephrases the original query by stripping unnecessary details to improve the matched results. Placeholders: `query`","version":"1.0","source":"Adapted from [Langchain](https://python.langchain.com/docs/integrations/retrievers/re_phrase)","_type":"metadatamessage"},{"content":"You are an assistant tasked with taking a natural language query from a user and converting it into a query for a vectorstore. \nIn this process, you strip out information that is not relevant for the retrieval task.","variables":[],"_type":"systemmessage"},{"content":"Here is the user query: {{query}}\n\nRephrased query: ","variables":["query"],"_type":"usermessage"}] \ No newline at end of file diff --git a/templates/RAG/refinement/RAGAnswerRefiner.json b/templates/RAG/refinement/RAGAnswerRefiner.json new file mode 100644 index 00000000..4c112759 --- /dev/null +++ b/templates/RAG/refinement/RAGAnswerRefiner.json @@ -0,0 +1 @@ +[{"content":"Template Metadata","description":"For RAG applications (refine step), gives model the ability to refine its answer based on some additional context etc.. The hope is that it better answers the original query. Placeholders: `query`, `answer`, `context`","version":"1.0","source":"Adapted from [LlamaIndex](https://github.com/run-llama/llama_index/blob/78af3400ad485e15862c06f0c4972dc3067f880c/llama-index-core/llama_index/core/prompts/default_prompts.py#L81)","_type":"metadatamessage"},{"content":"Act as a world-class AI assistant with access to the latest knowledge via Context Information.\n\nYour task is to refine an existing answer if it's needed.\n\nThe original query is as follows: \n{{query}}\n\nThe AI model has provided the following answer:\n{{answer}}\n\n**Instructions:**\n- Given the new context, refine the original answer to better answer the query.\n- If the context isn't useful, return the original answer.\n- If you don't know the answer, just say that you don't know, don't try to make up an answer.\n- Be brief and concise.\n- Provide the refined answer only and nothing else.\n\n","variables":["query","answer"],"_type":"systemmessage"},{"content":"We have the opportunity to refine the previous answer (only if needed) with some more context below.\n\n**Context Information:**\n-----------------\n{{context}}\n-----------------\n\nGiven the new context, refine the original answer to better answer the query.\nIf the context isn't useful, return the original answer. \nProvide the refined answer only and nothing else.\n\nRefined Answer: ","variables":["context"],"_type":"usermessage"}] \ No newline at end of file diff --git a/test/Experimental/RAGTools/annotation.jl b/test/Experimental/RAGTools/annotation.jl index bcd745e6..8ab7e43b 100644 --- a/test/Experimental/RAGTools/annotation.jl +++ b/test/Experimental/RAGTools/annotation.jl @@ -5,7 +5,7 @@ using PromptingTools.Experimental.RAGTools: AnnotatedNode, AbstractAnnotater, HTMLStyler, pprint using PromptingTools.Experimental.RAGTools: trigram_support!, add_node_metadata!, - annotate_support, RAGDetails, text_to_trigrams + annotate_support, RAGResult, text_to_trigrams @testset "AnnotatedNode" begin # Test node creation with default values @@ -328,10 +328,16 @@ end @test occursin("\nSOURCES\n", output) @test occursin("1. Source 1", output) + # Catch empty context + answer = "This is a test answer." + @test_throws AssertionError annotated_root=annotate_support( + annotater, answer, String[]) + ## RAG Details dispatch answer = "This is a test answer." - r = RAGDetails( - "?", answer, context; sources = ["Source 1", "Source 2", "Source 3"]) + r = RAGResult(; + question = "?", final_answer = answer, context, sources = [ + "Source 1", "Source 2", "Source 3"]) annotated_root = annotate_support(annotater, r) io = IOBuffer() pprint(io, annotated_root) diff --git a/test/Experimental/RAGTools/evaluation.jl b/test/Experimental/RAGTools/evaluation.jl index 4a8693c0..94539bcc 100644 --- a/test/Experimental/RAGTools/evaluation.jl +++ b/test/Experimental/RAGTools/evaluation.jl @@ -1,7 +1,7 @@ using PromptingTools.Experimental.RAGTools: QAItem, QAEvalItem, QAEvalResult using PromptingTools.Experimental.RAGTools: score_retrieval_hit, score_retrieval_rank using PromptingTools.Experimental.RAGTools: build_qa_evals, run_qa_evals, chunks, sources -using PromptingTools.Experimental.RAGTools: JudgeAllScores, MetadataItem, MaybeMetadataItems +using PromptingTools.Experimental.RAGTools: JudgeAllScores, Tag, MaybeTags @testset "QAEvalItem" begin empty_qa = QAEvalItem() @@ -75,7 +75,7 @@ end @testset "build_qa_evals" begin # test with a mock server - PORT = rand(10000:40000) + PORT = rand(10005:40001) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) @@ -107,8 +107,8 @@ end :choices => [ Dict(:finish_reason => "stop", :message => Dict(:tool_calls => [ - Dict(:function => Dict(:arguments => JSON3.write(MaybeMetadataItems([ - MetadataItem("yes", "category") + Dict(:function => Dict(:arguments => JSON3.write(MaybeTags([ + Tag("yes", "category") ]))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), @@ -176,19 +176,24 @@ end String[]; qa_template = :BlankSystemUser) # Test run_qa_evals on 1 item - msg, ctx = airag(index; question = qa_evals[1].question, model_embedding = "mock-emb", - model_chat = "mock-gen", - model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"), - tag_filter = :auto, - extract_metadata = false, verbose = false, - return_details = true) - - result = run_qa_evals(qa_evals[1], ctx; + airag_kwargs = (; + retriever_kwargs = (; + tagger_kwargs = (; model = "mock-gen", tag = ["yes"]), embedder_kwargs = (; + model = "mock-emb")), + generator_kwargs = (; + answerer_kwargs = (; model = "mock-gen"), embedder_kwargs = (; + model = "mock-emb"))) + result = airag(RAGConfig(), index; question = qa_evals[1].question, + airag_kwargs..., + api_kwargs = (; url = "http://localhost:$(PORT)"), + return_all = true) + + result = run_qa_evals(qa_evals[1], result; model_judge = "mock-judge", api_kwargs = (; url = "http://localhost:$(PORT)"), parameters_dict = Dict(:key1 => "value1", :key2 => 2)) @test result.retrieval_score == 1.0 - @test result.retrieval_rank == 1 + @test result.retrieval_rank == 2 @test result.answer_score == 5 @test result.parameters == Dict(:key1 => "value1", :key2 => 2) @@ -196,17 +201,14 @@ end # results = run_qa_evals(index, qa_evals; model_judge = "mock-judge", # api_kwargs = (; url = "http://localhost:$(PORT)")) results = run_qa_evals(index, qa_evals; - airag_kwargs = (; - model_embedding = "mock-emb", - model_chat = "mock-gen", - model_metadata = "mock-meta"), + airag_kwargs, qa_evals_kwargs = (; model_judge = "mock-judge"), api_kwargs = (; url = "http://localhost:$(PORT)"), parameters_dict = Dict(:key1 => "value1", :key2 => 2)) @test length(results) == length(qa_evals) @test all(getproperty.(results, :retrieval_score) .== 1.0) - @test all(getproperty.(results, :retrieval_rank) .== 1) + @test all(x -> x.retrieval_rank in [1, 2], results) @test all(getproperty.(results, :answer_score) .== 5) @test all(getproperty.(results, :parameters) .== Ref(Dict(:key1 => "value1", :key2 => 2))) diff --git a/test/Experimental/RAGTools/generation.jl b/test/Experimental/RAGTools/generation.jl index 4cdc06c5..8c038b40 100644 --- a/test/Experimental/RAGTools/generation.jl +++ b/test/Experimental/RAGTools/generation.jl @@ -1,8 +1,16 @@ using PromptingTools.Experimental.RAGTools: ChunkIndex, - CandidateChunks, build_context, airag -using PromptingTools.Experimental.RAGTools: MaybeMetadataItems, MetadataItem + CandidateChunks, build_context, build_context! +using PromptingTools.Experimental.RAGTools: MaybeTags, Tag, ContextEnumerator, + AbstractContextBuilder +using PromptingTools.Experimental.RAGTools: SimpleAnswerer, AbstractAnswerer, answer!, + NoRefiner, SimpleRefiner, AbstractRefiner, + refine! +using PromptingTools.Experimental.RAGTools: NoPostprocessor, AbstractPostprocessor, + postprocess!, SimpleGenerator, + AdvancedGenerator, generate!, airag, RAGConfig, + RAGResult -@testset "build_context" begin +@testset "build_context!" begin index = ChunkIndex(; sources = [".", ".", "."], chunks = ["a", "b", "c"], @@ -12,26 +20,184 @@ using PromptingTools.Experimental.RAGTools: MaybeMetadataItems, MetadataItem candidates = CandidateChunks(index.id, [1, 2], [0.1, 0.2]) # Standard Case - context = build_context(index, candidates) + contexter = ContextEnumerator() + context = build_context(contexter, index, candidates) expected_output = ["1. a\nb", "2. a\nb\nc"] @test context == expected_output # No Surrounding Chunks - context = build_context(index, candidates; chunks_window_margin = (0, 0)) + context = build_context(contexter, index, candidates; chunks_window_margin = (0, 0)) expected_output = ["1. a", "2. b"] @test context == expected_output # Wrong inputs - @test_throws AssertionError build_context(index, + @test_throws AssertionError build_context(contexter, index, candidates; chunks_window_margin = (-1, 0)) + + # From result/index + question = "why?" + result = RAGResult(; + question, rephrased_questions = [question], emb_candidates = candidates, + tag_candidates = candidates, filtered_candidates = candidates, reranked_candidates = candidates, + context = String[], sources = String[]) + build_context!(contexter, index, result) + expected_output = ["1. a\nb", + "2. a\nb\nc"] + @test result.context == expected_output + + # Unknown type + struct RandomContextEnumerator123 <: AbstractContextBuilder end + @test_throws ArgumentError build_context!( + RandomContextEnumerator123(), index, result) +end + +@testset "answer!" begin + # Setup + index = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"], + embeddings = ones(Float32, 2, 2)) + + question = "why?" + cc1 = CandidateChunks(index_id = :TestChunkIndex1) + + result = RAGResult(; question, rephrased_questions = [question], emb_candidates = cc1, + tag_candidates = cc1, filtered_candidates = cc1, reranked_candidates = cc1, + context = String["a", "b"], sources = String[]) + + # Test refine with SimpleAnswerer + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "answer"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + + output = answer!( + SimpleAnswerer(), index, result; model = "mock-gen") + @test result.answer == "answer" + @test result.conversations[:answer][end].content == "answer" + + # with unknown rephraser + struct UnknownAnswerer123 <: AbstractAnswerer end + @test_throws ArgumentError answer!(UnknownAnswerer123(), index, result) +end + +@testset "refine!" begin + # Setup + index = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"], + embeddings = ones(Float32, 2, 2)) + + question = "why?" + cc1 = CandidateChunks(index_id = :TestChunkIndex1) + + # Test refine with NoRefiner, simple passthrough + result = RAGResult(; question, rephrased_questions = [question], emb_candidates = cc1, + tag_candidates = cc1, filtered_candidates = cc1, reranked_candidates = cc1, + context = String[], sources = String[], answer = "ABC", + conversations = Dict(:answer => [PT.UserMessage("MESSAGE")])) + + result = refine!(NoRefiner(), index, result) + @test result.final_answer == "ABC" + @test result.conversations[:final_answer] == [PT.UserMessage("MESSAGE")] + + # Test refine with SimpleRefiner + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "new answer"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + result = RAGResult(; question, rephrased_questions = [question], emb_candidates = cc1, + tag_candidates = cc1, filtered_candidates = cc1, reranked_candidates = cc1, + context = String[], sources = String[], answer = "ABC", + conversations = Dict(:answer => [PT.UserMessage("MESSAGE")])) + + output = refine!( + SimpleRefiner(), index, result; model = "mock-gen") + @test result.final_answer == "new answer" + @test result.conversations[:final_answer][end].content == "new answer" + + # with unknown rephraser + struct UnknownRefiner123 <: AbstractRefiner end + @test_throws ArgumentError refine!(UnknownRefiner123(), index, result) +end + +@testset "postprocess!" begin + question = "why?" + cc1 = CandidateChunks(index_id = :TestChunkIndex1) + result = RAGResult(; question, rephrased_questions = [question], emb_candidates = cc1, + tag_candidates = cc1, filtered_candidates = cc1, reranked_candidates = cc1, + context = String[], sources = String[]) + index = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"], + embeddings = ones(Float32, 2, 2)) + + # passthrough + @test postprocess!(NoPostprocessor(), index, result) == result + # Unknown type + struct RandomPostprocessor123 <: AbstractPostprocessor end + @test_throws ArgumentError postprocess!(RandomPostprocessor123(), index, result) +end + +@testset "generate!" begin + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "answer"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + + index = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"], + embeddings = ones(Float32, 2, 2)) + + question = "why?" + cc1 = CandidateChunks(index_id = :TestChunkIndex1) + + result = RAGResult(; question, rephrased_questions = [question], emb_candidates = cc1, + tag_candidates = cc1, filtered_candidates = cc1, reranked_candidates = cc1, + context = String["a", "b"], sources = String[]) + + # SimpleGenerator - no refinement + output = generate!(SimpleGenerator(), index, result; + answerer_kwargs = (; model = "mock-gen")) + @test output.answer == "answer" + @test output.final_answer == "answer" + + # with defaults + output = generate!(index, result; + answerer_kwargs = (; model = "mock-gen")) + @test output.answer == "answer" + @test output.final_answer == "answer" + + # Test with refinement - AdvancedGenerator + output = generate!(AdvancedGenerator(), index, result; + answerer_kwargs = (; model = "mock-gen"), + refiner_kwargs = (; model = "mock-gen")) + @test output.answer == "answer" + @test output.final_answer == "answer" end @testset "airag" begin # test with a mock server - PORT = rand(20000:40000) + PORT = rand(20010:40001) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) @@ -61,8 +227,8 @@ end :choices => [ Dict(:finish_reason => "stop", :message => Dict(:tool_calls => [ - Dict(:function => Dict(:arguments => JSON3.write(MaybeMetadataItems([ - MetadataItem("yes", "category") + Dict(:function => Dict(:arguments => JSON3.write(MaybeTags([ + Tag("yes", "category") ]))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), @@ -86,69 +252,53 @@ end model = "mock-emb", api_kwargs = (; url = "http://localhost:$(PORT)")) @test question_emb.content == ones(128) - metadata_msg = aiextract(:RAGExtractMetadataShort; return_type = MaybeMetadataItems, + metadata_msg = aiextract(:RAGExtractMetadataShort; return_type = MaybeTags, text = "x", model = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)")) - @test metadata_msg.content.items == [MetadataItem("yes", "category")] + @test metadata_msg.content.items == [Tag("yes", "category")] answer_msg = aigenerate(:RAGAnswerFromContext; question = "Time?", context = "XYZ", model = "mock-gen", api_kwargs = (; url = "http://localhost:$(PORT)")) @test occursin("Time?", answer_msg.content) - ## E2E - msg = airag(index; question = "Time?", model_embedding = "mock-emb", - model_chat = "mock-gen", - model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"), - tag_filter = ["yes"], - return_details = false) + ## E2E - default type + msg = airag(index; question = "Time?", + retriever_kwargs = (; + tagger_kwargs = (; model = "mock-gen", tag = ["yes"]), embedder_kwargs = (; + model = "mock-emb")), + generator_kwargs = (; + answerer_kwargs = (; model = "mock-gen"), embedder_kwargs = (; + model = "mock-emb")), + api_kwargs = (; url = "http://localhost:$(PORT)"), + return_all = false) @test occursin("Time?", msg.content) - # test kwargs passing - api_kwargs = (; url = "http://localhost:$(PORT)") - msg = airag(index; question = "Time?", model_embedding = "mock-emb", - model_chat = "mock-gen", - model_metadata = "mock-meta", - tag_filter = ["yes"], - return_details = false, aiembed_kwargs = (; api_kwargs), - aigenerate_kwargs = (; api_kwargs), aiextract_kwargs = (; api_kwargs)) + ## E2E - with type + msg = airag(RAGConfig(), index; question = "Time?", + retriever_kwargs = (; + tagger_kwargs = (; model = "mock-gen", tag = ["yes"]), embedder_kwargs = (; + model = "mock-emb")), + generator_kwargs = (; + answerer_kwargs = (; model = "mock-gen"), embedder_kwargs = (; + model = "mock-emb")), + api_kwargs = (; url = "http://localhost:$(PORT)"), + return_all = false) @test occursin("Time?", msg.content) - ## Test different kwargs - msg, details = airag(index; question = "Time?", model_embedding = "mock-emb", - model_chat = "mock-gen", - model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"), - tag_filter = :auto, - extract_metadata = false, verbose = false, - return_details = true) - @test details.context == ["1. a\nb\nc", "2. a\nb"] - @test details.emb_candidates.positions == [3, 2, 1] - @test details.emb_candidates.distances == zeros(3) - @test details.tag_candidates.positions == [1, 2] - @test details.tag_candidates.distances == ones(2) - @test details.filtered_candidates.positions == [2, 1] #re-sort - @test details.filtered_candidates.distances == 0.5ones(2) - @test details.reranked_candidates.positions == [2, 1] # no change - @test details.reranked_candidates.distances == 0.5ones(2) # no change - - ## Not tag filter - msg, details = airag(index; question = "Time?", model_embedding = "mock-emb", - model_chat = "mock-gen", - model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"), - tag_filter = nothing, - return_details = true) - @test details.context == ["1. b\nc", "2. a\nb\nc", "3. a\nb"] - @test details.emb_candidates.positions == [3, 2, 1] - @test details.emb_candidates.distances == zeros(3) - @test details.tag_candidates == nothing - @test details.filtered_candidates.positions == [3, 2, 1] #re-sort - @test details.reranked_candidates.positions == [3, 2, 1] # no change + ## Return RAG result + result = airag(RAGConfig(), index; question = "Time?", + retriever_kwargs = (; + tagger_kwargs = (; model = "mock-gen", tag = ["yes"]), embedder_kwargs = (; + model = "mock-emb")), + generator_kwargs = (; + answerer_kwargs = (; model = "mock-gen"), embedder_kwargs = (; + model = "mock-emb")), + api_kwargs = (; url = "http://localhost:$(PORT)"), + return_all = true) + @test occursin("Time?", result.answer) + @test occursin("Time?", result.final_answer) ## Pretty printing - result = airag(index; question = "Time?", model_embedding = "mock-emb", - model_chat = "mock-gen", - model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)"), - tag_filter = nothing, - return_details = true) io = IOBuffer() PT.pprint(io, result) result_str = String(take!(io)) diff --git a/test/Experimental/RAGTools/preparation.jl b/test/Experimental/RAGTools/preparation.jl index dc14b087..5d8020cc 100644 --- a/test/Experimental/RAGTools/preparation.jl +++ b/test/Experimental/RAGTools/preparation.jl @@ -1,43 +1,144 @@ -using PromptingTools.Experimental.RAGTools: metadata_extract, MetadataItem -using PromptingTools.Experimental.RAGTools: MaybeMetadataItems, build_tags, build_index +using PromptingTools.Experimental.RAGTools: load_text, FileChunker, TextChunker, + BatchEmbedder, + NoTagger, PassthroughTagger, OpenTagger +using PromptingTools.Experimental.RAGTools: AbstractTagger, AbstractChunker, + AbstractEmbedder, AbstractIndexBuilder +using PromptingTools.Experimental.RAGTools: tags_extract, Tag, MaybeTags +using PromptingTools.Experimental.RAGTools: build_tags, build_index, SimpleIndexer, + get_tags, get_chunks, get_embeddings +using PromptingTools.Experimental.RAGTools: build_tags, build_index +using PromptingTools: TestEchoOpenAISchema -@testset "metadata_extract" begin - # MetadataItem Structure - item = MetadataItem("value", "category") +@testset "load_text" begin + # from file + fp, io = mktemp() + write(io, "text") + close(io) + @test load_text(FileChunker(), fp) == ("text", fp) + @test_throws AssertionError load_text(FileChunker(), "nonexistent" * fp) + + # from provided text + @test load_text(TextChunker(), "text"; source = "POMA") == ("text", "POMA") + @test_throws AssertionError load_text(TextChunker(), "text"; source = "a"^520) # catch long doc - cant be a source + + # unknown chunker + struct RandomChunker123 <: AbstractChunker end + @test_throws ArgumentError load_text(RandomChunker123(), "text") +end + +@testset "get_chunks" begin + ochunks, osources = get_chunks( + TextChunker(), ["text1", "text2"]; max_length = 10, sources = ["doc1", "doc2"]) + @test ochunks == ["text1", "text2"] + @test osources == ["doc1", "doc2"] + + # Mismatch in source length + @test_throws AssertionError get_chunks( + TextChunker(), ["text1", "text2"]; max_length = 10, sources = ["doc1"]) + # too long to be a source + @test_throws AssertionError get_chunks( + TextChunker(), ["text1", "text2"]; max_length = 10, sources = ["a"^520, "b"^520]) + + # FileChunker + fp, io = mktemp() + write(io, "text") + close(io) + fp2, io = mktemp() + write(io, "text2") + close(io) + ochunks, osources = get_chunks( + FileChunker(), [fp, fp2]; max_length = 10) + @test ochunks == ["text", "text2"] + @test osources == [fp, fp2] +end + +@testset "get_embeddings" begin + # corresponds to OpenAI API v1 + response1 = Dict(:data => [Dict(:embedding => ones(128, 2))], + :usage => Dict(:total_tokens => 2, :prompt_tokens => 2, :completion_tokens => 0)) + schema = TestEchoOpenAISchema(; response = response1, status = 200) + PT.register_model!(; name = "mock-emb", schema) + + docs = ["Hello World", "Hello World"] + output = get_embeddings(BatchEmbedder(), docs; model = "mock-emb") + @test size(output) == (128, 2) + + # Unknown type + struct RandomEmbedder123 <: AbstractEmbedder end + @test_throws ArgumentError get_embeddings( + RandomEmbedder123(), ["text1", "text2"]) +end + +@testset "tags_extract" begin + # Tag Structure + item = Tag("value", "category") @test item.value == "value" @test item.category == "category" - # MaybeMetadataItems Structure - items = MaybeMetadataItems([ - MetadataItem("value1", "category1"), - MetadataItem("value2", "category2"), + # MaybeTags Structure + items = MaybeTags([ + Tag("value1", "category1"), + Tag("value2", "category2") ]) @test length(items.items) == 2 @test items.items[1].value == "value1" @test items.items[1].category == "category1" - empty_items = MaybeMetadataItems(nothing) - @test isempty(metadata_extract(empty_items.items)) + empty_items = MaybeTags(nothing) + @test isempty(tags_extract(empty_items.items)) # Metadata Extraction Function - single_item = MetadataItem("DataFrames", "Julia Package") + single_item = Tag("DataFrames", "Julia Package") multiple_items = [ - MetadataItem("pandas", "Software"), - MetadataItem("Python", "Language"), - MetadataItem("DataFrames", "Julia Package"), + Tag("pandas", "Software"), + Tag("Python", "Language"), + Tag("DataFrames", "Julia Package") ] - @test metadata_extract(single_item) == "julia_package:::dataframes" - @test metadata_extract(multiple_items) == + @test tags_extract(single_item) == "julia_package:::dataframes" + @test tags_extract(multiple_items) == ["software:::pandas", "language:::python", "julia_package:::dataframes"] - @test metadata_extract(nothing) == String[] + @test tags_extract(nothing) == String[] +end + +@testset "get_tags" begin + # Unknown Tagger + struct RandomTagger123 <: AbstractTagger end + @test_throws ArgumentError get_tags(RandomTagger123(), String[]) + + # NoTagger + @test get_tags(NoTagger(), String[]) == nothing + + # PassthroughTagger + tags_ = [["tag1"], ["tag2"]] + @test get_tags(PassthroughTagger(), ["doc1", "docs2"]; tags = tags_) == tags_ + @test_throws AssertionError get_tags( + PassthroughTagger(), ["doc1", "docs2"]; tags = [["tag1"]]) # length mismatch + + # OpenTagger - mock server + response = Dict( + :choices => [ + Dict(:finish_reason => "stop", + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(MaybeTags([ + Tag("yes", "categoryx") + ]))))]))], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response = response, status = 200) + PT.register_model!(; name = "mock-meta", schema) + tags_ = get_tags(OpenTagger(), String["Say yes"]; model = "mock-meta") + @test tags_ == [["categoryx:::yes"]] end @testset "build_tags" begin + ## empty tags + @test build_tags(NoTagger(), nothing) == (nothing, nothing) + + tagger = OpenTagger() # Single Tag chunk_metadata = [["tag1"]] - tags_, tags_vocab_ = build_tags(chunk_metadata) + tags_, tags_vocab_ = build_tags(tagger, chunk_metadata) @test length(tags_vocab_) == 1 @test tags_vocab_ == ["tag1"] @@ -46,7 +147,7 @@ end # Multiple Tags with Repetition chunk_metadata = [["tag1", "tag2"], ["tag2", "tag3"]] - tags_, tags_vocab_ = build_tags(chunk_metadata) + tags_, tags_vocab_ = build_tags(tagger, chunk_metadata) @test length(tags_vocab_) == 3 @test tags_vocab_ == ["tag1", "tag2", "tag3"] @@ -55,14 +156,14 @@ end # Empty Metadata chunk_metadata = [String[]] - tags_, tags_vocab_ = build_tags(chunk_metadata) + tags_, tags_vocab_ = build_tags(tagger, chunk_metadata) @test isempty(tags_vocab_) @test size(tags_) == (1, 0) # Mixed Empty and Non-Empty Metadata chunk_metadata = [["tag1"], String[], ["tag2", "tag3"]] - tags_, tags_vocab_ = build_tags(chunk_metadata) + tags_, tags_vocab_ = build_tags(tagger, chunk_metadata) @test length(tags_vocab_) == 3 @test tags_vocab_ == ["tag1", "tag2", "tag3"] @@ -72,37 +173,40 @@ end @testset "build_index" begin # test with a mock server - PORT = rand(9000:11000) + PORT = rand(9000:31000) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) - PT.register_model!(; name = "mock-get", schema = PT.CustomOpenAISchema()) + PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) echo_server = HTTP.serve!(PORT; verbose = -1) do req content = JSON3.read(req.body) if content[:model] == "mock-gen" user_msg = last(content[:messages]) - response = Dict(:choices => [ - Dict(:message => user_msg, :finish_reason => "stop"), + response = Dict( + :choices => [ + Dict(:message => user_msg, :finish_reason => "stop") ], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), :completion_tokens => 0)) elseif content[:model] == "mock-emb" - response = Dict(:data => [Dict(:embedding => ones(Float32, 128)) - for i in 1:length(content[:input])], + response = Dict( + :data => [Dict(:embedding => ones(Float32, 128)) + for i in 1:length(content[:input])], :usage => Dict(:total_tokens => length(content[:input]), :prompt_tokens => length(content[:input]), :completion_tokens => 0)) elseif content[:model] == "mock-meta" user_msg = last(content[:messages]) - response = Dict(:choices => [ + response = Dict( + :choices => [ Dict(:finish_reason => "stop", - :message => Dict(:tool_calls => [ - Dict(:function => Dict(:arguments => JSON3.write(MaybeMetadataItems([ - MetadataItem("yes", "category"), - ]))))]))], + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(MaybeTags([ + Tag("yes", "category") + ]))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), @@ -114,32 +218,64 @@ end end text = "This is a long text that will be split into chunks.\n\n It will be split by the separator. And also by the separator '\n'." + + ## Default - file reader tmp, _ = mktemp() write(tmp, text) mini_files = [tmp, tmp] - index = build_index(mini_files; max_length = 10, extract_metadata = true, - model_embedding = "mock-emb", - model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)")) - @test index.embeddings == hcat(fill(normalize(ones(Float32, 128)), 8)...) - @test index.chunks[1:4] == index.chunks[5:8] - @test index.sources == fill(tmp, 8) - @test index.tags == ones(8, 1) - @test index.tags_vocab == ["category:::yes"] + indexer = SimpleIndexer() + index = build_index( + indexer, mini_files; chunker = FileChunker(), chunker_kwargs = (; max_length = 10), + embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test index.embeddings == + hcat(fill(normalize(ones(Float32, 128)), length(index.chunks))...) + @test index.chunks[begin:(length(index.chunks) ÷ 2)] == + index.chunks[((length(index.chunks) ÷ 2) + 1):end] + @test index.sources == fill(tmp, length(index.chunks)) + @test index.tags == nothing + @test index.tags_vocab == nothing - ## Test docs reader - index = build_index([text, text]; reader = :docs, sources = ["x", "x"], max_length = 10, - extract_metadata = true, - model_embedding = "mock-emb", - model_metadata = "mock-meta", api_kwargs = (; url = "http://localhost:$(PORT)")) - @test index.embeddings == hcat(fill(normalize(ones(Float32, 128)), 8)...) - @test index.chunks[1:4] == index.chunks[5:8] - @test index.sources == fill("x", 8) - @test index.tags == ones(8, 1) + ## With metadata + indexer = SimpleIndexer(; chunker = FileChunker(), tagger = OpenTagger()) + index = build_index(indexer, mini_files; chunker_kwargs = (; max_length = 10), + embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test index.tags == ones(30, 1) @test index.tags_vocab == ["category:::yes"] - # Assertion if sources is missing - @test_throws AssertionError build_index([text, text]; reader = :docs) + ## Test docs reader - customize via kwarg + indexer = SimpleIndexer() + index = build_index(indexer, [text, text]; chunker = TextChunker(), + chunker_kwargs = (; + sources = ["x", "x"], max_length = 10), + embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test index.embeddings == + hcat(fill(normalize(ones(Float32, 128)), length(index.chunks))...) + @test index.chunks[begin:(length(index.chunks) ÷ 2)] == + index.chunks[((length(index.chunks) ÷ 2) + 1):end] + @test index.sources == fill("x", length(index.chunks)) + @test index.tags == nothing + @test index.tags_vocab == nothing + # Test default behavior - text chunker + index = build_index([text, text]; + chunker_kwargs = (; + sources = ["x", "x"], max_length = 10), + embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test index.embeddings == + hcat(fill(normalize(ones(Float32, 128)), length(index.chunks))...) + @test index.chunks[begin:(length(index.chunks) ÷ 2)] == + index.chunks[((length(index.chunks) ÷ 2) + 1):end] + @test index.sources == fill("x", length(index.chunks)) + @test index.tags == nothing + @test index.tags_vocab == nothing # clean up close(echo_server) end diff --git a/test/Experimental/RAGTools/retrieval.jl b/test/Experimental/RAGTools/retrieval.jl index 4328397f..e987c70f 100644 --- a/test/Experimental/RAGTools/retrieval.jl +++ b/test/Experimental/RAGTools/retrieval.jl @@ -1,27 +1,62 @@ -using PromptingTools.Experimental.RAGTools: find_closest, find_tags -using PromptingTools.Experimental.RAGTools: Passthrough, rerank, CohereRerank +using PromptingTools.Experimental.RAGTools: ContextEnumerator, NoRephraser, SimpleRephraser, + HyDERephraser, + CosineSimilarity, NoTagFilter, AnyTagFilter, + SimpleRetriever, AdvancedRetriever +using PromptingTools.Experimental.RAGTools: AbstractRephraser, AbstractTagFilter, + AbstractSimilarityFinder, AbstractReranker +using PromptingTools.Experimental.RAGTools: find_closest, find_tags, rerank, rephrase, + retrieve +using PromptingTools.Experimental.RAGTools: NoReranker, CohereReranker + +@testset "rephrase" begin + # Test rephrase with NoRephraser, simple passthrough + @test rephrase(NoRephraser(), "test") == ["test"] + + # Test rephrase with SimpleRephraser + response = Dict( + :choices => [ + Dict(:message => Dict(:content => "new question"), :finish_reason => "stop") + ], + :usage => Dict(:total_tokens => 3, + :prompt_tokens => 2, + :completion_tokens => 1)) + schema = TestEchoOpenAISchema(; response, status = 200) + PT.register_model!(; name = "mock-gen", schema) + output = rephrase( + SimpleRephraser(), "old question", model = "mock-gen") + @test output == ["old question", "new question"] + + output = rephrase( + HyDERephraser(), "old question", model = "mock-gen") + @test output == ["old question", "new question"] + + # with unknown rephraser + struct UnknownRephraser123 <: AbstractRephraser end + @test_throws ArgumentError rephrase(UnknownRephraser123(), "test question") +end @testset "find_closest" begin + finder = CosineSimilarity() test_embeddings = [1.0 2.0 -1.0; 3.0 4.0 -3.0; 5.0 6.0 -6.0] |> x -> mapreduce(normalize, hcat, eachcol(x)) query_embedding = [0.1, 0.35, 0.5] |> normalize - positions, distances = find_closest(test_embeddings, query_embedding, top_k = 2) + positions, distances = find_closest(finder, test_embeddings, query_embedding, top_k = 2) # The query vector should be closer to the first embedding @test positions == [1, 2] @test isapprox(distances, [0.9975694083904584 - 0.9939123761133188], atol = 1e-3) + 0.9939123761133188], atol = 1e-3) # Test when top_k is more than available embeddings - positions, _ = find_closest(test_embeddings, query_embedding, top_k = 5) + positions, _ = find_closest(finder, test_embeddings, query_embedding, top_k = 5) @test length(positions) == size(test_embeddings, 2) # Test with minimum_similarity - positions, _ = find_closest(test_embeddings, query_embedding, top_k = 5, + positions, _ = find_closest(finder, test_embeddings, query_embedding, top_k = 5, minimum_similarity = 0.995) @test length(positions) == 1 # Test behavior with edge values (top_k == 0) - @test find_closest(test_embeddings, query_embedding, top_k = 0) == ([], []) + @test find_closest(finder, test_embeddings, query_embedding, top_k = 0) == ([], []) ## Test with ChunkIndex embeddings1 = ones(Float32, 2, 2) @@ -35,16 +70,44 @@ using PromptingTools.Experimental.RAGTools: Passthrough, rerank, CohereRerank chunks = ["chunk1", "chunk2"], sources = ["source1", "source2"], embeddings = ones(Float32, 2, 2)) - mi = MultiIndex(id = :multi, indexes = [ci1, ci2]) + ci3 = ChunkIndex(id = :TestChunkIndex3, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"], + embeddings = nothing) ## find_closest with ChunkIndex query_emb = [0.5, 0.5] # Example query embedding vector - result = find_closest(ci1, query_emb) + result = find_closest(finder, ci1, query_emb) @test result isa CandidateChunks @test result.positions == [1, 2] - @test all(1.0 .>= result.distances .>= -1.0) # Assuming default minimum_similarity + @test all(1.0 .>= result.scores .>= -1.0) # Assuming default minimum_similarity + + ## empty index + query_emb = [0.5, 0.5] # Example query embedding vector + result = find_closest(finder, ci3, query_emb) + @test isempty(result) + + ## Unknown type + struct RandomSimilarityFinder123 <: AbstractSimilarityFinder end + @test_throws ArgumentError find_closest( + RandomSimilarityFinder123(), ones(5, 5), ones(5)) + + ## find_closest with multiple embeddings + query_emb = [0.5 0.5; 0.5 1.0] |> x -> mapreduce(normalize, hcat, eachcol(x)) + result = find_closest(finder, ci1, query_emb; top_k = 2) + @test result.positions == [1, 2] + @test isapprox(result.scores, [1.0, 0.965], atol = 1e-2) + + # bad top_k -- too low, leads to 0 results + result = find_closest(finder, ci1, query_emb; top_k = 1) + @test isempty(result) + # but it works in general, because 1/1 = 1 is a valid top_k + result = find_closest(finder, ci1, query_emb[:, 1]; top_k = 1) + @test result.positions == [1] + @test result.scores == [1.0] ## find_closest with MultiIndex + ## mi = MultiIndex(id = :multi, indexes = [ci1, ci2]) ## query_emb = [0.5, 0.5] # Example query embedding vector ## result = find_closest(mi, query_emb) ## @test result isa CandidateChunks @@ -53,6 +116,7 @@ using PromptingTools.Experimental.RAGTools: Passthrough, rerank, CohereRerank end @testset "find_tags" begin + tagger = AnyTagFilter() test_embeddings = [1.0 2.0; 3.0 4.0; 5.0 6.0] |> x -> mapreduce(normalize, hcat, eachcol(x)) query_embedding = [0.1, 0.35, 0.5] |> normalize @@ -66,56 +130,199 @@ end tags_vocab = test_tags_vocab) # Test for finding the correct positions of a specific tag - @test find_tags(index, "julia").positions == [1] - @test find_tags(index, "julia").distances == [1.0] + @test find_tags(tagger, index, "julia").positions == [1] + @test find_tags(tagger, index, "julia").scores == [1.0] # Test for no tag found // not in vocab - @test find_tags(index, "python").positions |> isempty - @test find_tags(index, "java").positions |> isempty + @test find_tags(tagger, index, "python").positions |> isempty + @test find_tags(tagger, index, "java").positions |> isempty # Test with regex matching - @test find_tags(index, r"^j").positions == [1, 2] + @test find_tags(tagger, index, r"^j").positions == [1, 2] # Test with multiple tags in vocab - @test find_tags(index, ["python", "jr", "x"]).positions == [2] + @test find_tags(tagger, index, ["python", "jr", "x"]).positions == [2] + + # No filter tag -- give everything + cc = find_tags(NoTagFilter(), index, "julia") + @test cc.positions == [1, 2] + @test cc.scores == [0.0, 0.0] + + cc = find_tags(NoTagFilter(), index, nothing) + @test cc.positions == [1, 2] + @test cc.scores == [0.0, 0.0] + + # Unknown type + struct RandomTagFilter123 <: AbstractTagFilter end + @test_throws ArgumentError find_tags(RandomTagFilter123(), index, "hello") + @test_throws ArgumentError find_tags(RandomTagFilter123(), index, ["hello"]) end @testset "rerank" begin # Mock data for testing - index = "mock_index" + ci1 = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2"], + sources = ["source1", "source2"]) question = "mock_question" - candidate_chunks = ["chunk1", "chunk2", "chunk3"] + cc1 = CandidateChunks(index_id = :TestChunkIndex1, + positions = [1, 2], + scores = [0.3, 0.4]) # Passthrough Strategy - strategy = Passthrough() - @test rerank(strategy, index, question, candidate_chunks) == - candidate_chunks + ranker = NoReranker() + reranked = rerank(ranker, ci1, question, cc1) + @test reranked.positions == [2, 1] # gets resorted by score + @test reranked.scores == [0.4, 0.3] + + reranked = rerank(ranker, ci1, question, cc1; top_n = 1) + @test reranked.positions == [2] # gets resorted by score + @test reranked.scores == [0.4] # Cohere assertion - ci1 = ChunkIndex(id = :TestChunkIndex1, - chunks = ["chunk1", "chunk2"], - sources = ["source1", "source2"]) ci2 = ChunkIndex(id = :TestChunkIndex2, chunks = ["chunk1", "chunk2"], sources = ["source1", "source2"]) mi = MultiIndex(; id = :multi, indexes = [ci1, ci2]) - @test_throws ArgumentError rerank(CohereRerank(), + @test_throws ArgumentError rerank(CohereReranker(), mi, question, - candidate_chunks) + cc1) # Bad top_n - @test_throws AssertionError rerank(CohereRerank(), + @test_throws AssertionError rerank(CohereReranker(), ci1, question, - candidate_chunks; top_n = 0) + cc1; top_n = 0) # Bad index_id cc2 = CandidateChunks(index_id = :TestChunkIndex2, positions = [1, 2], - distances = [0.3, 0.4]) - @test_throws AssertionError rerank(CohereRerank(), + scores = [0.3, 0.4]) + @test_throws AssertionError rerank(CohereReranker(), ci1, question, cc2; top_n = 1) + + ## Unknown type + struct RandomReranker123 <: AbstractReranker end + @test_throws ArgumentError rerank(RandomReranker123(), ci1, "hello", cc2) + + ## TODO: add testing of Cohere reranker API call -- not done yet end + +@testset "retrieve" begin + # test with a mock server + PORT = rand(20000:40000) + PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) + PT.register_model!(; name = "mock-emb2", schema = PT.CustomOpenAISchema()) + PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) + PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) + + echo_server = HTTP.serve!(PORT; verbose = -1) do req + content = JSON3.read(req.body) + + if content[:model] == "mock-gen" + user_msg = last(content[:messages]) + response = Dict( + :choices => [ + Dict(:message => user_msg, :finish_reason => "stop") + ], + :model => content[:model], + :usage => Dict(:total_tokens => length(user_msg[:content]), + :prompt_tokens => length(user_msg[:content]), + :completion_tokens => 0)) + elseif content[:model] == "mock-emb" + response = Dict(:data => [Dict(:embedding => ones(Float32, 10))], + :usage => Dict(:total_tokens => length(content[:input]), + :prompt_tokens => length(content[:input]), + :completion_tokens => 0)) + elseif content[:model] == "mock-emb2" + response = Dict( + :data => [Dict(:embedding => ones(Float32, 10)), + Dict(:embedding => ones(Float32, 10))], + :usage => Dict(:total_tokens => length(content[:input]), + :prompt_tokens => length(content[:input]), + :completion_tokens => 0)) + elseif content[:model] == "mock-meta" + user_msg = last(content[:messages]) + response = Dict( + :choices => [ + Dict(:finish_reason => "stop", + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(MaybeTags([ + Tag("yes", "category") + ]))))]))], + :model => content[:model], + :usage => Dict(:total_tokens => length(user_msg[:content]), + :prompt_tokens => length(user_msg[:content]), + :completion_tokens => 0)) + else + @info content + end + return HTTP.Response(200, JSON3.write(response)) + end + + embeddings1 = ones(Float32, 10, 4) + embeddings1[10, 3:4] .= 5.0 + embeddings1 = mapreduce(normalize, hcat, eachcol(embeddings1)) + index = ChunkIndex(id = :TestChunkIndex1, + chunks = ["chunk1", "chunk2", "chunk3", "chunk4"], + sources = ["source1", "source2", "source3", "source4"], + embeddings = embeddings1) + question = "test question" + + ## Test with SimpleRetriever + simple = SimpleRetriever() + + result = retrieve(simple, index, question; + rephraser_kwargs = (; model = "mock-gen"), + embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test result.question == question + @test result.rephrased_questions == [question] + @test result.answer == nothing + @test result.final_answer == nothing + @test result.reranked_candidates.positions == [2, 1, 4, 3] + @test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"] + @test result.sources isa Vector{String} + + # Reduce number of candidates + result = retrieve(simple, index, question; + top_n = 2, top_k = 3, + rephraser_kwargs = (; model = "mock-gen"), + embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test result.emb_candidates.positions == [2, 1, 4] + @test result.reranked_candidates.positions == [2, 1] + + # with default dispatch + result = retrieve(index, question; + top_n = 2, top_k = 3, + rephraser_kwargs = (; model = "mock-gen"), + embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test result.emb_candidates.positions == [2, 1, 4] + @test result.reranked_candidates.positions == [2, 1] + + ## AdvancedRetriever + adv = AdvancedRetriever() + result = retrieve(adv, index, question; + reranker = NoReranker(), # we need to disable cohere as we cannot test it + rephraser_kwargs = (; model = "mock-gen"), + embedder_kwargs = (; model = "mock-emb2"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test result.question == question + @test result.rephrased_questions == [question, "Query: test question\n\nPassage:"] # from the template we use + @test result.answer == nothing + @test result.final_answer == nothing + @test result.reranked_candidates.positions == [2, 1, 4, 3] + @test result.context == ["chunk2", "chunk1", "chunk4", "chunk3"] + @test result.sources isa Vector{String} + + # clean up + close(echo_server) +end \ No newline at end of file diff --git a/test/Experimental/RAGTools/runtests.jl b/test/Experimental/RAGTools/runtests.jl index 53841259..660bcc40 100644 --- a/test/Experimental/RAGTools/runtests.jl +++ b/test/Experimental/RAGTools/runtests.jl @@ -2,7 +2,7 @@ using Test using SparseArrays, LinearAlgebra using PromptingTools.Experimental.RAGTools using PromptingTools -using AbstractTrees +using PromptingTools.AbstractTrees const PT = PromptingTools using JSON3, HTTP diff --git a/test/Experimental/RAGTools/types.jl b/test/Experimental/RAGTools/types.jl index 8cad2cd6..a7ac73e0 100644 --- a/test/Experimental/RAGTools/types.jl +++ b/test/Experimental/RAGTools/types.jl @@ -1,5 +1,7 @@ -using PromptingTools.Experimental.RAGTools: ChunkIndex, MultiIndex, CandidateChunks -using PromptingTools.Experimental.RAGTools: embeddings, chunks, tags, tags_vocab, sources +using PromptingTools.Experimental.RAGTools: ChunkIndex, MultiIndex, CandidateChunks, + AbstractCandidateChunks +using PromptingTools.Experimental.RAGTools: embeddings, chunks, tags, tags_vocab, sources, + RAGResult @testset "ChunkIndex" begin # Test constructors and basic accessors @@ -101,20 +103,63 @@ end chunk_sym = Symbol("TestChunkIndex") cc1 = CandidateChunks(index_id = chunk_sym, positions = [1, 3], - distances = [0.1, 0.2]) + scores = [0.1, 0.2]) @test Base.length(cc1) == 2 + out = Base.first(cc1, 1) + @test out.positions == [3] + @test out.scores == [0.2] # Test intersection & cc2 = CandidateChunks(index_id = chunk_sym, positions = [2, 4], - distances = [0.3, 0.4]) + scores = [0.3, 0.4]) @test isempty((cc1 & cc2).positions) cc3 = CandidateChunks(index_id = chunk_sym, positions = [1, 4], - distances = [0.3, 0.4]) + scores = [0.3, 0.5]) joint = (cc1 & cc3) @test joint.positions == [1] - @test joint.distances == [0.2] + @test joint.scores == [0.3] + joint2 = (cc2 & cc3) + @test joint2.positions == [4] + @test joint2.scores == [0.5] + + # long positions intersection + cc5 = CandidateChunks(index_id = chunk_sym, + positions = [5, 6, 7, 8, 9, 10, 4], + scores = 0.1 * ones(7)) + joint5 = (cc2 & cc5) + @test joint5.positions == [4] + @test joint5.scores == [0.4] + + # wrong index + cc4 = CandidateChunks(index_id = :xyz, + positions = [2, 4], + scores = [0.3, 0.4]) + joint4 = (cc2 & cc4) + @test isempty(joint4.positions) + @test isempty(joint4.scores) + @test isempty(joint4) == true + + # Test unknown type + struct RandomCandidateChunks123 <: AbstractCandidateChunks end + @test_throws ArgumentError (cc1&RandomCandidateChunks123()) + + # Test vcat + vcat1 = vcat(cc1, cc2) + @test Base.length(vcat1) == 4 + vcat2 = vcat(cc1, cc3) + @test vcat2.positions == [4, 1, 3] + @test vcat2.scores == [0.5, 0.3, 0.2] + # wrong index + @test_throws ArgumentError vcat(cc1, cc4) + # uknown type + @test_throws ArgumentError vcat(cc1, RandomCandidateChunks123()) + + # Test copy + cc1_copy = copy(cc1) + @test cc1 == cc1_copy + @test cc1.positions !== cc1_copy.positions # not the same array end @testset "getindex with CandidateChunks" begin @@ -134,7 +179,7 @@ end # Test to get chunks based on valid CandidateChunks candidate_chunks = CandidateChunks(index_id = chunk_sym, positions = [1, 3], - distances = [0.1, 0.2]) + scores = [0.1, 0.2]) @test collect(test_chunk_index[candidate_chunks]) == ["First chunk", "Third chunk"] @test collect(test_chunk_index[candidate_chunks, :chunks]) == ["First chunk", "Third chunk"] @@ -146,7 +191,7 @@ end # Test with empty positions, which should result in an empty array candidate_chunks_empty = CandidateChunks(index_id = chunk_sym, positions = Int[], - distances = Float32[]) + scores = Float32[]) @test isempty(test_chunk_index[candidate_chunks_empty]) @test isempty(test_chunk_index[candidate_chunks_empty, :chunks]) @test isempty(test_chunk_index[candidate_chunks_empty, :embeddings]) @@ -155,14 +200,14 @@ end # Test with positions out of bounds, should handle gracefully without errors candidate_chunks_oob = CandidateChunks(index_id = chunk_sym, positions = [10, -1], - distances = [0.5, 0.6]) + scores = [0.5, 0.6]) @test_throws AssertionError test_chunk_index[candidate_chunks_oob] # Test with an incorrect index_id, which should also result in an empty array wrong_sym = Symbol("InvalidIndex") candidate_chunks_wrong_id = CandidateChunks(index_id = wrong_sym, positions = [1, 2], - distances = [0.3, 0.4]) + scores = [0.3, 0.4]) @test isempty(test_chunk_index[candidate_chunks_wrong_id]) # Test when chunks are requested from a MultiIndex, only chunks from the corresponding ChunkIndex should be returned @@ -187,11 +232,11 @@ end # Multi-Candidate CandidateChunks cc1 = CandidateChunks(index_id = :TestChunkIndex1, positions = [1, 2], - distances = [0.3, 0.4]) + scores = [0.3, 0.4]) cc2 = CandidateChunks(index_id = :TestChunkIndex2, positions = [2], - distances = [0.1]) - cc = CandidateChunks(; index_id = :multi, positions = [cc1, cc2], distances = zeros(2)) + scores = [0.1]) + cc = CandidateChunks(; index_id = :multi, positions = [cc1, cc2], scores = zeros(2)) ci1 = ChunkIndex(id = :TestChunkIndex1, chunks = ["chunk1", "chunk2"], sources = ["source1", "source2"]) @@ -205,3 +250,31 @@ end mi = MultiIndex(; id = :multi, indexes = [ci1, ci2]) @test mi[cc] == ["chunk1", "chunk2", "chunk2"] end + +@testset "RAGResult" begin + result = RAGResult(; question = "a", answer = "b", final_answer = "c") + result2 = RAGResult(; question = "a", answer = "b", final_answer = "c") + @test result == result2 + + result3 = copy(result) + @test result == result3 + @test result !== result3 + + ## pprint checks - empty context fails + io = IOBuffer() + @test_throws AssertionError PT.pprint(io, result) + + ## RAG Details dispatch + answer = "This is a test answer." + sources_ = ["Source 1", "Source 2", "Source 3"] + result = RAGResult(; + question = "?", final_answer = answer, context = sources_, sources = sources_) + io = IOBuffer() + PT.pprint(io, result; add_context = true) + output = String(take!(io)) + @test occursin("This is a test answer.", output) + @test occursin("\nQUESTION", output) + @test occursin("\nSOURCES\n", output) + @test occursin("\nCONTEXT\n", output) + @test occursin("1. Source 1", output) +end diff --git a/test/Experimental/RAGTools/utils.jl b/test/Experimental/RAGTools/utils.jl index ca00dff5..346573aa 100644 --- a/test/Experimental/RAGTools/utils.jl +++ b/test/Experimental/RAGTools/utils.jl @@ -247,11 +247,10 @@ end ``` and `inline code`.""" sentences, group_ids = split_into_code_and_sentences(input) - sentences - @test sentences == ["Here is a code block: \n", "```julia", "\n", + @test sentences == ["Here is a code block: ", "\n", "```julia", "\n", "code here", "\n", "```", "\n", "and ", "`inline code`", "."] @test join(sentences, "") == input - @test group_ids == [1, 2, 2, 2, 2, 2, 3, 4, 5, 6] + @test group_ids == [1, 2, 3, 3, 3, 3, 3, 4, 5, 6, 7] ## Multi-faceted code input = """Here is a code block: @@ -271,11 +270,12 @@ end """ sentences, group_ids = split_into_code_and_sentences(input) @test sentences == - ["Here is a code block: \n", "```julia", "\n", "code here", "\n", "```", "\n", - "and ", "`inline code`", ".", "\n", "Sentences here.\n", "Bullets:\n-", - " I like this\n-", " But does it work?\n", "```julia", "\n", "another code", - "\n", "```", "\n", "1. ", "Tester\n", "Third sentence -", " but what happened.\n"] + [ + "Here is a code block: ", "\n", "```julia", "\n", "code here", "\n", "```", "\n", + "and ", "`inline code`", ".", "\n", "Sentences here.", "\n", "Bullets:", "\n", "- ", + "I like this", "\n", "- ", "But does it work?", "\n", "```julia", "\n", "another code", + "\n", "```", "\n", "1. ", "Tester", "\n", "Third sentence - but what happened.", "\n"] @test join(sentences, "") == input - @test group_ids == [ - 1, 2, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 12, 12, 13, 14, 15, 16, 17] + @test group_ids == [1, 2, 3, 3, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 19, 19, 19, 19, 20, 21, 22, 23, 24, 25] end \ No newline at end of file diff --git a/test/llm_ollama.jl b/test/llm_ollama.jl index d3092f4e..e486f0d2 100644 --- a/test/llm_ollama.jl +++ b/test/llm_ollama.jl @@ -15,37 +15,37 @@ using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, _encode_l messages = [UserMessage("I am {{name}}")] expected_output = [ Dict("role" => "system", "content" => "Act as a helpful AI assistant"), - Dict("role" => "user", "content" => "I am John Doe"), + Dict("role" => "user", "content" => "I am John Doe") ] @test render(schema, messages; name = "John Doe") == expected_output # Test message rendering with system and user messages messages = [ SystemMessage("This is a system generated message."), - UserMessage("A user generated reply."), + UserMessage("A user generated reply.") ] expected_output = [ Dict("role" => "system", "content" => "This is a system generated message."), - Dict("role" => "user", "content" => "A user generated reply."), + Dict("role" => "user", "content" => "A user generated reply.") ] @test render(schema, messages) == expected_output # Test message rendering with images messages = [ UserMessageWithImages("User message with an image"; - image_url = ["https://example.com/image.jpg"]), + image_url = ["https://example.com/image.jpg"]) ] expected_output = [ Dict("role" => "system", "content" => "Act as a helpful AI assistant"), Dict("role" => "user", "content" => "User message with an image", - "images" => ["https://example.com/image.jpg"]), + "images" => ["https://example.com/image.jpg"]) ] @test render(schema, messages) == expected_output # Test message with local image messages = [ UserMessageWithImages("User message with an image"; - image_path = joinpath(@__DIR__, "data", "julia.png"), base64_only = true), + image_path = joinpath(@__DIR__, "data", "julia.png"), base64_only = true) ] raw_img = _encode_local_image(joinpath(@__DIR__, "data", "julia.png"); base64_only = true) @@ -53,7 +53,7 @@ using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, _encode_l Dict("role" => "system", "content" => "Act as a helpful AI assistant"), Dict("role" => "user", "content" => "User message with an image", - "images" => [raw_img]), + "images" => [raw_img]) ] @test render(schema, messages) == expected_output end @@ -103,7 +103,7 @@ end conversation = [SystemMessage("Today's weather is {{weather}}.")] # Mock dry run replacing the template variable expected_convo_output = [ - SystemMessage(; content = "Today's weather is sunny.", variables = [:weather]), + SystemMessage(; content = "Today's weather is sunny.", variables = [:weather]) ] @test aigenerate(schema, conversation; @@ -140,7 +140,7 @@ end "content" => "hi", "images" => [ _encode_local_image(joinpath(@__DIR__, "data", "julia.png"), - base64_only = true), + base64_only = true) ])] @test_throws AssertionError aiscan(schema, "hi";