Skip to content

Commit

Permalink
Merge pull request #38 from BetterAndBetterII/fix_build_bug
Browse files Browse the repository at this point in the history
修复pipeline 无法从GitHub批量构建索引的bug,增加qwen-max等模型名字匹配
  • Loading branch information
wxywb authored Feb 6, 2024
2 parents 1be52dc + 1a0ef7f commit fa38cf0
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,37 @@ def is_github_folder_url(url):
return url.startswith('https://raw.githubusercontent.com/') and '.' not in os.path.basename(url)


def get_github_repo_contents(repo_url):
response = requests.get(repo_url)
def get_branch_head_sha(owner, repo, branch):
url = f"https://api.github.com/repos/{owner}/{repo}/git/ref/heads/{branch}"
response = requests.get(url)
data = response.json()
sha = data['object']['sha']
return sha

filenames = []
if response.status_code == 200:
contents = response.json()
for item in contents['payload']['tree']['items']:
filenames.append(item['name'])
else:
print(f"Failed to fetch contents. Status code: {response.status_code}")
return filenames
def get_github_repo_contents(repo_url):
# repo_url example: https://raw.githubusercontent.com/wxywb/history_rag/master/data/history_24/
repo_owner = repo_url.split('/')[3]
repo_name = repo_url.split('/')[4]
branch = repo_url.split('/')[5]
folder_path = '/'.join(repo_url.split('/')[6:])
sha = get_branch_head_sha(repo_owner, repo_name, branch)
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/trees/{sha}?recursive=1"
try:
response = requests.get(url)
if response.status_code == 200:
data = response.json()

raw_urls = []
for file in data['tree']:
if file['path'].startswith(folder_path) and file['path'].endswith('.txt'):
raw_url = f"https://raw.githubusercontent.com/{repo_owner}/{repo_name}/{branch}/{file['path']}"
raw_urls.append(raw_url)
return raw_urls
else:
print(f"Failed to fetch contents. Status code: {response.status_code}")
except Exception as e:
print(f"Failed to fetch contents. Error: {str(e)}")
return []

class Executor:
def __init__(self, model):
Expand Down Expand Up @@ -109,7 +129,7 @@ def __init__(self, config):
embed_model = HuggingFaceEmbedding(model_name=config.embedding.name)

# 使用Qwen 通义千问模型
if config.llm.name == "qwen":
if config.llm.name.find("qwen") != -1:
llm = QwenUnofficial(temperature=config.llm.temperature, model=config.llm.name, max_tokens=2048)
elif config.llm.name.find("gemini") != -1:
llm = Gemini(temperature=config.llm.temperature, model_name=config.llm.name, max_tokens=2048)
Expand Down Expand Up @@ -236,7 +256,7 @@ def __init__(self, config):
self.config = config
self._debug = False

if config.llm.name == "qwen":
if config.llm.name.find("qwen") != -1:
llm = QwenUnofficial(temperature=config.llm.temperature, model=config.llm.name, max_tokens=2048)
elif config.llm.name.find("gemini") != -1:
llm = Gemini(model_name=config.llm.name, temperature=config.llm.temperature, max_tokens=2048)
Expand Down Expand Up @@ -296,10 +316,10 @@ def build_index(self, path, overwrite):
self._initialize_pipeline(self.service_context)

if is_github_folder_url(path):
filenames = get_github_repo_contents(path)
for filename in filenames:
if filename.endswith('txt'):
self.build_index(self, path + f'/{filename}')
urls = get_github_repo_contents(path)
for url in urls:
print(f'(rag) 正在构建索引 {url}')
self.build_index(url, False) # already deleted original collection
elif path.endswith('.txt'):
self.index.insert_doc_url(
url=path,
Expand Down

0 comments on commit fa38cf0

Please sign in to comment.