Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New RM] Add AzureAISearch #198

Merged
merged 5 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ You could also install the source code which allows you to modify the behavior o
Currently, our package supports:

- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` as retrieval module components

:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**

Expand Down
12 changes: 8 additions & 4 deletions examples/storm_examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
"""

import os

from argparse import ArgumentParser
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, AzureAISearch
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -72,6 +73,7 @@ def main(args):

# STORM is a knowledge curation system which consumes information from the retrieval module.
# Currently, the information source is the Internet and we use search engine API as the retrieval module.

match args.retriever:
case 'bing':
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
Expand All @@ -87,8 +89,10 @@ def main(args):
rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True)
case 'searxng':
rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k)
case 'azure_ai_search':
rm = AzureAISearch(azure_ai_search_api_key=os.getenv('AZURE_AI_SEARCH_API_KEY'), k=engine_args.search_top_k)
case _:
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"')

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -113,7 +117,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'azure_ai_search'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand All @@ -138,4 +142,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
109 changes: 108 additions & 1 deletion knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from .utils import WebPageHelper

from azure.core.credentials import AzureKeyCredential
dfusion-dev marked this conversation as resolved.
Show resolved Hide resolved
from azure.search.documents import SearchClient

class YouRM(dspy.Retrieve):
def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):
Expand Down Expand Up @@ -77,7 +79,6 @@ def forward(

return collected_results


class BingSearch(dspy.Retrieve):
def __init__(
self,
Expand Down Expand Up @@ -1093,3 +1094,109 @@ def forward(
collected_results.append(r)

return collected_results

class AzureAISearch(dspy.Retrieve):
    """Retrieve information from custom queries using Azure AI Search.

    General documentation can be found at:
    https://learn.microsoft.com/en-us/azure/search/search-create-service-portal
    Python documentation and examples can be found at:
    https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python

    Requires ``pip install azure-search-documents``.

    Each search hit is read through the fields ``metadata_storage_path``
    (used as the URL), ``title`` and ``chunk`` (used as the only snippet), so
    the target index must expose those three fields.
    """

    def __init__(
        self,
        azure_ai_search_api_key=None,
        azure_ai_search_url=None,
        azure_ai_search_index_name=None,
        k=3,
        is_valid_source: Callable = None,
    ):
        """
        Params:
            azure_ai_search_api_key: Azure AI Search API key. Falls back to the
                AZURE_AI_SEARCH_API_KEY environment variable. Check out
                https://learn.microsoft.com/en-us/azure/search/search-security-api-keys?tabs=rest-use%2Cportal-find%2Cportal-query
                "API key" section.
            azure_ai_search_url: Custom Azure AI Search endpoint URL. Falls back
                to the AZURE_AI_SEARCH_URL environment variable. Check out
                https://learn.microsoft.com/en-us/azure/search/search-create-service-portal#name-the-service
            azure_ai_search_index_name: Custom Azure AI Search index name. Falls
                back to the AZURE_AI_SEARCH_INDEX_NAME environment variable.
                Check out
                https://learn.microsoft.com/en-us/azure/search/search-how-to-create-search-index?tabs=portal
            k: Number of top results to retrieve per query.
            is_valid_source: Optional function that takes a URL and returns a
                boolean; results whose URL is rejected are dropped.

        Raises:
            RuntimeError: If any of the three settings is supplied neither as an
                argument nor through its environment variable.
        """
        super().__init__(k=k)
        self.azure_ai_search_api_key = self._resolve_setting(
            azure_ai_search_api_key,
            "azure_ai_search_api_key",
            "AZURE_AI_SEARCH_API_KEY",
        )
        self.azure_ai_search_url = self._resolve_setting(
            azure_ai_search_url,
            "azure_ai_search_url",
            "AZURE_AI_SEARCH_URL",
        )
        self.azure_ai_search_index_name = self._resolve_setting(
            azure_ai_search_index_name,
            "azure_ai_search_index_name",
            "AZURE_AI_SEARCH_INDEX_NAME",
        )

        # Number of queries issued since the last get_usage_and_reset() call.
        self.usage = 0

        # If not None, is_valid_source shall be a function that takes a URL and
        # returns a boolean. Without one, every source is accepted.
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True

    @staticmethod
    def _resolve_setting(value, param_name, env_var):
        """Return the explicit value if truthy, else the environment variable; raise if both are missing."""
        resolved = value or os.environ.get(env_var)
        if not resolved:
            raise RuntimeError(
                f"You must supply {param_name} or set environment variable {env_var}"
            )
        return resolved

    def get_usage_and_reset(self):
        """Return the query count accumulated so far and reset the counter to zero."""
        usage = self.usage
        self.usage = 0

        return {"AzureAISearch": usage}

    def forward(
        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
    ):
        """Search with Azure AI Search for self.k top passages for query or queries

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of urls to exclude from the search results.

        Returns:
            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)
        # exclude_urls is only read, never mutated; a set gives O(1) membership tests.
        excluded = set(exclude_urls or [])
        collected_results = []

        client = SearchClient(
            endpoint=self.azure_ai_search_url,
            index_name=self.azure_ai_search_index_name,
            credential=AzureKeyCredential(self.azure_ai_search_api_key),
        )
        try:
            for query in queries:
                try:
                    # https://learn.microsoft.com/en-us/python/api/azure-search-documents/azure.search.documents.searchclient?view=azure-python#azure-search-documents-searchclient-search
                    # Ask for self.k hits so the `k` constructor argument is honored.
                    results = client.search(search_text=query, top=self.k)

                    for result in results:
                        url = result["metadata_storage_path"]
                        # Honor the caller's exclusion list and source filter.
                        if url in excluded or not self.is_valid_source(url):
                            continue
                        collected_results.append(
                            {
                                "url": url,
                                "title": result["title"],
                                "description": "N/A",
                                "snippets": [result["chunk"]],
                            }
                        )
                except Exception as e:
                    # Best effort: a failing query is logged and skipped so the
                    # remaining queries can still contribute results.
                    logging.error(f"Error occurs when searching query {query}: {e}")
        finally:
            # Release the underlying HTTP session created for this call.
            client.close()

        return collected_results


3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ trafilatura
langchain-huggingface
qdrant-client
langchain-qdrant
numpy==1.26.4
numpy==1.26.4
dfusion-dev marked this conversation as resolved.
Show resolved Hide resolved
azure-search-documents==11.5.1
Loading