-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathknowledge_base.py
100 lines (77 loc) · 2.94 KB
/
knowledge_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import Optional
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import UnstructuredURLLoader
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain
import requests
import xml.etree.ElementTree as ET
from dotenv import load_dotenv
from loguru import logger
load_dotenv()
def extract_urls_from_sitemap(sitemap):
    """
    Extract all URLs from a sitemap XML string.

    Args:
        sitemap (str): The sitemap XML document as a string.

    Returns:
        list[str]: The URLs found in the <loc> elements of the sitemap.

    Raises:
        xml.etree.ElementTree.ParseError: If the input is not well-formed XML.
    """
    # Parse the XML from the string
    root = ET.fromstring(sitemap)
    # Sitemap files declare this namespace; ElementTree needs it for lookups
    namespace = {"ns": "http://www.sitemaps.org/schemas/sitemap/0.9"}
    # Collect the <loc> text under each <url>, skipping malformed entries
    # that lack a <loc> child (the original raised AttributeError on those)
    urls = []
    for url in root.findall("ns:url", namespace):
        loc = url.find("ns:loc", namespace)
        if loc is not None and loc.text:
            urls.append(loc.text)
    # Return the list of URLs
    return urls
class KnowledgeBase:
    """Retrieval-augmented QA knowledge base built from a website sitemap.

    Downloads a sitemap, fetches every (optionally filtered) page listed in
    it, splits the page contents into chunks, embeds the chunks into a
    Chroma vector store, and wires up a RetrievalQAWithSourcesChain so the
    indexed content can be queried with natural-language questions.
    """

    def __init__(
        self,
        sitemap_url: str,
        chunk_size: int,
        chunk_overlap: int,
        pattern: Optional[str] = None,
    ):
        """Build the knowledge base.

        Args:
            sitemap_url: URL of the sitemap.xml listing the pages to index.
            chunk_size: Maximum size of each document chunk.
            chunk_overlap: Overlap between consecutive chunks.
            pattern: Optional substring; when given, only URLs containing
                it are indexed.

        Raises:
            requests.HTTPError: If fetching the sitemap returns an error
                status code.
            requests.Timeout: If the sitemap request exceeds the timeout.
        """
        logger.info("Building the knowledge base ...")
        logger.info("Loading sitemap from {sitemap_url} ...", sitemap_url=sitemap_url)
        # Fail fast on network problems: the original ignored the HTTP
        # status (so a 404/500 error page was parsed as a sitemap) and had
        # no timeout (so a dead server hung the constructor forever).
        response = requests.get(sitemap_url, timeout=30)
        response.raise_for_status()
        urls = extract_urls_from_sitemap(response.text)
        if pattern:
            logger.info("Filtering URLs with pattern {pattern} ...", pattern=pattern)
            urls = [x for x in urls if pattern in x]
        logger.info("{n} URLs extracted", n=len(urls))
        logger.info("Loading URLs content ...")
        loader = UnstructuredURLLoader(urls)
        data = loader.load()
        logger.info("Splitting documents in chunks ...")
        doc_splitter = CharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        docs = doc_splitter.split_documents(data)
        logger.info("{n} chunks created", n=len(docs))
        logger.info("Building the vector database ...")
        embeddings = OpenAIEmbeddings()
        docsearch = Chroma.from_documents(docs, embeddings)
        logger.info("Building the retrieval chain ...")
        # map_reduce summarizes retrieved chunks independently before
        # combining them, which keeps each LLM call within context limits.
        self.chain = RetrievalQAWithSourcesChain.from_chain_type(
            ChatOpenAI(),
            chain_type="map_reduce",
            retriever=docsearch.as_retriever(),
        )
        logger.info("Knowledge base created!")

    def ask(self, query: str):
        """Answer a question using the indexed documents.

        Args:
            query: Natural-language question to answer.

        Returns:
            The chain output (a dict with "answer" and "sources" keys).
        """
        return self.chain({"question": query}, return_only_outputs=True)
if __name__ == "__main__":
    # Build the knowledge base from the Next.js API-reference docs only.
    kb = KnowledgeBase(
        sitemap_url="https://nextjs.org/sitemap.xml",
        pattern="docs/api-refe",
        chunk_size=8000,
        chunk_overlap=3000,
    )
    # Ask a question and show the answer; the original computed the
    # result but discarded it, so the script produced no output.
    res = kb.ask("How do I deploy my Next.js app?")
    print(res)