diff --git a/executor.py b/executor.py index 060e8d85..d22f50e7 100644 --- a/executor.py +++ b/executor.py @@ -66,17 +66,37 @@ def is_github_folder_url(url): return url.startswith('https://raw.githubusercontent.com/') and '.' not in os.path.basename(url) -def get_github_repo_contents(repo_url): - response = requests.get(repo_url) +def get_branch_head_sha(owner, repo, branch): + url = f"https://api.github.com/repos/{owner}/{repo}/git/ref/heads/{branch}" + response = requests.get(url) + data = response.json() + sha = data['object']['sha'] + return sha - filenames = [] - if response.status_code == 200: - contents = response.json() - for item in contents['payload']['tree']['items']: - filenames.append(item['name']) - else: - print(f"Failed to fetch contents. Status code: {response.status_code}") - return filenames +def get_github_repo_contents(repo_url): + # repo_url example: https://raw.githubusercontent.com/wxywb/history_rag/master/data/history_24/ + repo_owner = repo_url.split('/')[3] + repo_name = repo_url.split('/')[4] + branch = repo_url.split('/')[5] + folder_path = '/'.join(repo_url.split('/')[6:]) + sha = get_branch_head_sha(repo_owner, repo_name, branch) + url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/trees/{sha}?recursive=1" + try: + response = requests.get(url) + if response.status_code == 200: + data = response.json() + + raw_urls = [] + for file in data['tree']: + if file['path'].startswith(folder_path) and file['path'].endswith('.txt'): + raw_url = f"https://raw.githubusercontent.com/{repo_owner}/{repo_name}/{branch}/{file['path']}" + raw_urls.append(raw_url) + return raw_urls + else: + print(f"Failed to fetch contents. Status code: {response.status_code}") + except Exception as e: + print(f"Failed to fetch contents. Error: {str(e)}") + return [] class Executor: def __init__(self, model): @@ -109,7 +129,7 @@ def __init__(self, config): embed_model = HuggingFaceEmbedding(model_name=config.embedding.name) # 使用Qwen 通义千问模型 - if config.llm.name == "qwen": + if config.llm.name.find("qwen") != -1: llm = QwenUnofficial(temperature=config.llm.temperature, model=config.llm.name, max_tokens=2048) elif config.llm.name.find("gemini") != -1: llm = Gemini(temperature=config.llm.temperature, model_name=config.llm.name, max_tokens=2048) @@ -236,7 +256,7 @@ def __init__(self, config): self.config = config self._debug = False - if config.llm.name == "qwen": + if config.llm.name.find("qwen") != -1: llm = QwenUnofficial(temperature=config.llm.temperature, model=config.llm.name, max_tokens=2048) elif config.llm.name.find("gemini") != -1: llm = Gemini(model_name=config.llm.name, temperature=config.llm.temperature, max_tokens=2048) @@ -296,10 +316,10 @@ def build_index(self, path, overwrite): self._initialize_pipeline(self.service_context) if is_github_folder_url(path): - filenames = get_github_repo_contents(path) - for filename in filenames: - if filename.endswith('txt'): - self.build_index(self, path + f'/{filename}') + urls = get_github_repo_contents(path) + for url in urls: + print(f'(rag) 正在构建索引 {url}') + self.build_index(url, False) # already deleted original collection elif path.endswith('.txt'): self.index.insert_doc_url( url=path,