Implement cascade task (#137)
• Implement cascading task processing and generation
• Implement sub-path task addition
• Implement arbitrary task execution
• Implement automatic consumption of unstarted tasks
xingwanying authored Jun 5, 2024
2 parents 79c20d6 + b317fbf commit 8e20ed1
Showing 6 changed files with 206 additions and 23 deletions.
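
Taken together, the changes below implement a simple cascade: add_task records a NOT_STARTED row in the new rag_tasks table for a repository path, and trigger_task consumes tasks, expanding a tree task into child tasks for each Markdown file or subdirectory and embedding a blob task's file via retrieval.add_knowledge_by_doc. A minimal driver sketch of that flow, assuming the server modules from this commit are importable and the GITHUB_TOKEN / SUPABASE_* / OpenAI credentials they read are configured (the repository name is illustrative):

from data_class import GitDocConfig
from rag_helper import task

# An empty file_path marks the root task as a 'tree', so the whole branch cascades.
task.add_task(GitDocConfig(repo_name="octocat/Hello-World", branch="master", file_path=""))

# Drain the queue oldest-first: tree tasks fan out into child tasks,
# blob tasks embed a single Markdown file; None means no unstarted tasks remain.
while task.trigger_task(None) is not None:
    pass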
2 changes: 2 additions & 0 deletions .gitignore
@@ -46,3 +46,5 @@ next-env.d.ts
 /server/.aws-sam/*
 .aws-sam/*
 
+# IDE
+.idea
File renamed without changes.
43 changes: 25 additions & 18 deletions server/rag/retrieval.py → server/rag_helper/retrieval.py
@@ -1,22 +1,23 @@
 import json
 from typing import Any
-from langchain_openai import OpenAIEmbeddings
 
 from langchain_community.vectorstores import SupabaseVectorStore
-from db.supabase.client import get_client
+from langchain_openai import OpenAIEmbeddings
 
 from data_class import GitDocConfig, GitIssueConfig, S3Config
-from rag.github_file_loader import GithubFileLoader
+from db.supabase.client import get_client
+from rag_helper.github_file_loader import GithubFileLoader
 from uilts.env import get_env_variable
 
+
 supabase_url = get_env_variable("SUPABASE_URL")
 supabase_key = get_env_variable("SUPABASE_SERVICE_KEY")
-ACCESS_TOKEN=get_env_variable("GITHUB_TOKEN")
+ACCESS_TOKEN = get_env_variable("GITHUB_TOKEN")
+
+TABLE_NAME = "rag_docs"
+QUERY_NAME = "match_rag_docs"
+CHUNK_SIZE = 2000
+CHUNK_OVERLAP = 20
 
-TABLE_NAME="rag_docs"
-QUERY_NAME="match_rag_docs"
-CHUNK_SIZE=2000
-CHUNK_OVERLAP=20
-
 def convert_document_to_dict(document):
     return document.page_content,
@@ -25,11 +26,11 @@ def convert_document_to_dict(document):
 def init_retriever():
     embeddings = OpenAIEmbeddings()
     vector_store = SupabaseVectorStore(
-        embedding=embeddings,
-        client=get_client(),
-        table_name=TABLE_NAME,
-        query_name=QUERY_NAME,
-        chunk_size=CHUNK_SIZE,
+        embedding=embeddings,
+        client=get_client(),
+        table_name=TABLE_NAME,
+        query_name=QUERY_NAME,
+        chunk_size=CHUNK_SIZE,
     )
 
     return vector_store.as_retriever()
@@ -40,6 +41,7 @@ def init_s3_Loader(config: S3Config):
     loader = S3DirectoryLoader(config.s3_bucket, prefix=config.file_path)
     return loader
 
+
 def init_github_issue_loader(config: GitIssueConfig):
     from langchain_community.document_loaders import GitHubIssuesLoader
 
@@ -51,6 +53,8 @@ def init_github_issue_loader(config: GitIssueConfig):
         state=config.state
     )
     return loader
+
+
 def init_github_file_loader(config: GitDocConfig):
     loader = GithubFileLoader(
         repo=config.repo_name,
@@ -63,6 +67,7 @@ def init_github_file_loader(config: GitDocConfig):
     )
     return loader
 
+
 def supabase_embedding(documents, **kwargs: Any):
     from langchain_text_splitters import CharacterTextSplitter
 
@@ -90,7 +95,7 @@ def add_knowledge_by_issues(config: GitIssueConfig, ):
         loader = init_github_issue_loader(config)
         documents = loader.load()
         store = supabase_embedding(documents, repo_name=config.repo_name)
-        if(store):
+        if (store):
            return json.dumps({
                "success": True,
                "message": "Knowledge added successfully!",
@@ -106,6 +111,7 @@ def add_knowledge_by_issues(config: GitIssueConfig, ):
             "message": str(e)
         })
 
+
 def add_knowledge_by_doc(config: GitDocConfig):
     loader = init_github_file_loader(config)
     documents = loader.load()
@@ -116,14 +122,14 @@ def add_knowledge_by_doc(config: GitDocConfig):
         .eq('repo_name', config.repo_name)
         .eq('commit_id', loader.commit_id)
         .eq('file_path', config.file_path).execute()
-        )
-    if (is_added_query.data == []):
+    )
+    if not is_added_query.data:
         is_equal_query = (
             supabase.table(TABLE_NAME)
             .select("*")
             .eq('file_sha', loader.file_sha)
         ).execute()
-        if (is_equal_query.data == []):
+        if not is_equal_query.data:
             store = supabase_embedding(documents,
                                        repo_name=config.repo_name,
                                        commit_id=loader.commit_id,
@@ -149,6 +155,7 @@ def add_knowledge_by_doc(config: GitDocConfig):
     else:
         return True
 
+
 def search_knowledge(query: str):
     retriever = init_retriever()
     docs = retriever.invoke(query)
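
The reworked add_knowledge_by_doc also dedupes embeddings: it first checks rag_docs for an existing repo_name/commit_id/file_path row, then for an identical file_sha, and only calls supabase_embedding when both lookups come back empty. A usage sketch, assuming the SUPABASE_* / GITHUB_TOKEN / OpenAI credentials the module reads are set (repository and path are illustrative):

from data_class import GitDocConfig
from rag_helper import retrieval

# First call embeds the file; a later call for an already-indexed commit
# (or an unchanged file_sha) returns without re-embedding.
config = GitDocConfig(repo_name="octocat/Hello-World", branch="master", file_path="README.md")
retrieval.add_knowledge_by_doc(config)

# Query the vector store the documents were embedded into.
print(retrieval.search_knowledge("What does this repository do?"))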
151 changes: 151 additions & 0 deletions server/rag_helper/task.py
@@ -0,0 +1,151 @@
+from enum import Enum, auto
+from typing import Optional, Dict
+
+from github import Auth, Repository
+from github import Github
+from langchain_core.utils import get_from_env
+
+from data_class import GitDocConfig
+from db.supabase.client import get_client
+from rag_helper import retrieval
+
+g = Github(auth=Auth.Token(get_from_env("access_token", 'GITHUB_TOKEN')))
+
+TABLE_NAME = "rag_tasks"
+
+
+class TaskStatus(Enum):
+    NOT_STARTED = auto()
+    IN_PROGRESS = auto()
+    COMPLETED = auto()
+    ON_HOLD = auto()
+    CANCELLED = auto()
+    ERROR = auto()
+
+
+def add_task(config: GitDocConfig,
+             extra: Optional[Dict[str, Optional[str]]] = {"node_type": None, "from_task_id": None}):
+    repo = g.get_repo(config.repo_name)
+    commit_id = config.commit_id if config.commit_id else repo.get_branch(config.branch).commit.sha
+
+    if config.file_path == '' or config.file_path is None:
+        extra["node_type"] = 'tree'
+
+    if not extra.get("node_type"):
+        content = repo.get_contents(config.file_path, ref=commit_id)
+        if isinstance(content, list):
+            extra["node_type"] = 'tree'
+        else:
+            extra["node_type"] = 'blob'
+
+    sha = get_path_sha(repo, commit_id, config.file_path)
+
+    supabase = get_client()
+
+    data = {
+        "repo_name": config.repo_name,
+        "commit_id": commit_id,
+        "status": TaskStatus.NOT_STARTED.name,
+        "node_type": extra["node_type"],
+        "from_task_id": extra["from_task_id"],
+        "path": config.file_path,
+        "sha": sha
+    }
+
+    return supabase.table(TABLE_NAME).insert(data).execute()
+
+
+def get_path_sha(repo: Repository.Repository, sha: str, path: Optional[str] = None):
+    if not path:
+        return sha
+    else:
+        tree_data = repo.get_git_tree(sha)
+        for item in tree_data.tree:
+            if path.split("/")[0] == item.path:
+                return get_path_sha(repo, item.sha, "/".join(path.split("/")[1:]))
+
+
+def get_oldest_task():
+    supabase = get_client()
+
+    response = (supabase
+                .table(TABLE_NAME)
+                .select("*")
+                .eq("status", TaskStatus.NOT_STARTED.name)
+                .order("created_at", desc=False)
+                .limit(1)
+                .execute())
+
+    return response.data[0] if (len(response.data) > 0) else None
+
+
+def get_task_by_id(task_id):
+    supabase = get_client()
+
+    response = (supabase
+                .table(TABLE_NAME)
+                .select("*")
+                .eq("id", task_id)
+                .execute())
+    return response.data[0] if (len(response.data) > 0) else None
+
+
+def handle_tree_task(task):
+    supabase = get_client()
+    (supabase
+     .table(TABLE_NAME)
+     .update({"status": TaskStatus.IN_PROGRESS.name})
+     .eq('id', task["id"])
+     .execute()
+     )
+
+    repo = g.get_repo(task["repo_name"])
+    tree_data = repo.get_git_tree(task["sha"])
+
+    task_list = list(filter(lambda item: item["path"].endswith('.md') or item["node_type"] == 'tree', map(lambda item: {
+        "repo_name": task["repo_name"],
+        "commit_id": task["commit_id"],
+        "status": TaskStatus.NOT_STARTED.name,
+        "node_type": item.type,
+        "from_task_id": task["id"],
+        "path": "/".join(filter(lambda s: s, [task["path"], item.path])),
+        "sha": item.sha
+    }, tree_data.tree)))
+
+    if len(task_list) > 0:
+        supabase.table(TABLE_NAME).insert(task_list).execute()
+    return (supabase.table(TABLE_NAME).update(
+        {"metadata": {"tree": list(map(lambda item: item.raw_data, tree_data.tree))},
+         "status": TaskStatus.COMPLETED.name})
+            .eq("id", task["id"])
+            .execute())
+
+
+def handle_blob_task(task):
+    supabase = get_client()
+    (supabase
+     .table(TABLE_NAME)
+     .update({"status": TaskStatus.IN_PROGRESS.name})
+     .eq('id', task["id"])
+     .execute()
+     )
+
+    result = retrieval.add_knowledge_by_doc(GitDocConfig(repo_name=task["repo_name"],
+                                                         file_path=task["path"],
+                                                         commit_id=task["commit_id"]
+                                                         ))
+
+    return (supabase.table(TABLE_NAME).update(
+        {"status": TaskStatus.COMPLETED.name})
+            .eq("id", task["id"])
+            .execute())
+
+
+def trigger_task(task_id: Optional[str]):
+    task = get_task_by_id(task_id) if task_id else get_oldest_task()
+    if task is None:
+        return task
+    if task['node_type'] == 'tree':
+        return handle_tree_task(task)
+    else:
+        return handle_blob_task(task)
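
get_path_sha above resolves the sha for a sub-path by matching one path segment per recursion level against the current tree. A self-contained illustration of that walk with a fake in-memory tree (the data and helper here are hypothetical; the real function uses PyGithub's get_git_tree):

from typing import Optional

# Hypothetical stand-in for repo.get_git_tree(sha).tree
FAKE_TREES = {
    "root_sha": [("docs", "docs_sha"), ("README.md", "readme_sha")],
    "docs_sha": [("guide.md", "guide_sha")],
}

def get_path_sha(sha: str, path: Optional[str] = None) -> Optional[str]:
    if not path:
        return sha
    head, _, rest = path.partition("/")  # peel the first segment, like path.split("/")[0]
    for name, child_sha in FAKE_TREES[sha]:
        if name == head:
            return get_path_sha(child_sha, rest)
    return None  # segment not found in this tree

assert get_path_sha("root_sha", "docs/guide.md") == "guide_sha"
assert get_path_sha("root_sha", "") == "root_sha"  # empty path resolves to the tree itself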
31 changes: 27 additions & 4 deletions server/routers/rag.py
@@ -1,7 +1,10 @@
 import json
+from typing import Optional
 
 from fastapi import APIRouter, Depends
-from rag import retrieval
 
+from data_class import GitDocConfig, GitIssueConfig
+from rag_helper import retrieval, task
 from verify.rate_limit import verify_rate_limit
 
 router = APIRouter(
@@ -14,7 +17,7 @@
 @router.post("/rag/add_knowledge_by_doc", dependencies=[Depends(verify_rate_limit)])
 def add_knowledge_by_doc(config: GitDocConfig):
     try:
-        result=retrieval.add_knowledge_by_doc(config)
+        result = retrieval.add_knowledge_by_doc(config)
         if (result):
             return json.dumps({
                 "success": True,
@@ -31,12 +34,32 @@ def add_knowledge_by_doc(config: GitDocConfig):
             "message": str(e)
         })
 
+
 @router.post("/rag/add_knowledge_by_issues", dependencies=[Depends(verify_rate_limit)])
 def add_knowledge_by_issues(config: GitIssueConfig):
-    data=retrieval.add_knowledge_by_issues(config)
+    data = retrieval.add_knowledge_by_issues(config)
     return data
 
+
 @router.post("/rag/search_knowledge", dependencies=[Depends(verify_rate_limit)])
 def search_knowledge(query: str):
-    data=retrieval.search_knowledge(query)
+    data = retrieval.search_knowledge(query)
     return data
+
+
+@router.post("/rag/add_task", dependencies=[Depends(verify_rate_limit)])
+def add_task(config: GitDocConfig):
+    try:
+        data = task.add_task(config)
+        return data
+    except Exception as e:
+        return json.dumps({
+            "success": False,
+            "message": str(e)
+        })
+
+
+@router.post("/rag/trigger_task", dependencies=[Depends(verify_rate_limit)])
+def trigger_task(task_id: Optional[str] = None):
+    data = task.trigger_task(task_id)
+    return data
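
The two new endpoints expose the task queue over HTTP. A sketch of driving them with requests (base URL, port, and the JSON payload shape are assumptions; the fields mirror how GitDocConfig is used in this diff, and the task id is a placeholder):

import requests

BASE = "http://localhost:8000"  # wherever the FastAPI app is served

# Queue a cascading task for a repository sub-path.
requests.post(f"{BASE}/rag/add_task", json={
    "repo_name": "octocat/Hello-World",
    "branch": "master",
    "file_path": "docs",
})

# Consume the oldest NOT_STARTED task; tree tasks enqueue children, so a
# scheduler would call this repeatedly. Pass task_id to run a specific task.
requests.post(f"{BASE}/rag/trigger_task")
requests.post(f"{BASE}/rag/trigger_task", params={"task_id": "some-task-id"})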
2 changes: 1 addition & 1 deletion server/tools/knowledge.py
@@ -1,5 +1,5 @@
 from langchain.tools import tool
-from rag import retrieval
+from rag_helper import retrieval
 
 
 @tool
