forked from run-llama/llama_index
Commit
Move download datasets logic out of download_utils.py (run-llama#9253)
* add modifications to allow for pdf download
* lint
* lint
* separate content vars
* move download datasets into separate sub module
* wip
* update commandline
* move util functions to utils
* use renamed module
* refactor to use new module name
* refactor to use new module name
* cr
* lint
* add missing import
* add missing import
* add missing import; add back download_llama_datasets
* fix command line
* get metadata from hub
* fix command line
* wip
Showing 8 changed files with 433 additions and 195 deletions.
New file, 228 lines added:
"""Download.""" | ||
|
||
import json | ||
import os | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
import requests | ||
import tqdm | ||
|
||
from llama_index.download.module import LLAMA_HUB_URL | ||
from llama_index.download.utils import ( | ||
get_file_content, | ||
get_file_content_bytes, | ||
initialize_directory, | ||
) | ||
|
||
LLAMA_DATASETS_LFS_URL = ( | ||
f"https://media.githubusercontent.com/media/run-llama/llama_datasets/main" | ||
) | ||
|
||
LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = ( | ||
"https://github.com/run-llama/llama_datasets/tree/main" | ||
) | ||
LLAMA_RAG_DATASET_FILENAME = "rag_dataset.json" | ||
LLAMA_SOURCE_FILES_PATH = "source_files" | ||
|
||
|
||
PATH_TYPE = Union[str, Path] | ||
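For orientation, the first constant points at `media.githubusercontent.com`, GitHub's endpoint for the raw contents of Git-LFS-tracked files, while the tree URL is used only to list a dataset's source files. A sketch of how the helpers below compose these constants into a raw-content URL (the dataset id is hypothetical, taken from `library.json` at runtime):

```python
# Illustrative only: dataset ids come from library.json at runtime.
dataset_id = "llama_datasets/paul_graham_essay"  # hypothetical id
url = f"{LLAMA_DATASETS_LFS_URL}/{dataset_id}/{LLAMA_RAG_DATASET_FILENAME}"
# -> https://media.githubusercontent.com/media/run-llama/llama_datasets/main/
#    llama_datasets/paul_graham_essay/rag_dataset.json
```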
```python
def _get_source_files_list(source_tree_url: str, path: str) -> List[str]:
    """Get the list of source files to download."""
    resp = requests.get(source_tree_url + path + "?recursive=1")
    payload = resp.json()["payload"]
    return [item["name"] for item in payload["tree"]["items"]]
```
```python
def get_dataset_info(
    local_dir_path: PATH_TYPE,
    remote_dir_path: PATH_TYPE,
    remote_source_dir_path: PATH_TYPE,
    dataset_class: str,
    refresh_cache: bool = False,
    library_path: str = "library.json",
    source_files_path: str = "source_files",
    disable_library_cache: bool = False,
) -> Dict:
    """Get dataset info."""
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    local_library_path = f"{local_dir_path}/{library_path}"
    dataset_id = None
    source_files = []

    # Check cache first
    if not refresh_cache and os.path.exists(local_library_path):
        with open(local_library_path) as f:
            library = json.load(f)
        if dataset_class in library:
            dataset_id = library[dataset_class]["id"]
            source_files = library[dataset_class].get("source_files", [])

    # Fetch up-to-date library from remote repo if dataset_id not found
    if dataset_id is None:
        library_raw_content, _ = get_file_content(
            str(remote_dir_path), f"/{library_path}"
        )
        library = json.loads(library_raw_content)
        if dataset_class not in library:
            raise ValueError("Loader class name not found in library")

        dataset_id = library[dataset_class]["id"]
        source_files = _get_source_files_list(
            str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}"
        )

        # create cache dir if needed
        local_library_dir = os.path.dirname(local_library_path)
        if not disable_library_cache:
            if not os.path.exists(local_library_dir):
                os.makedirs(local_library_dir)

            # Update cache
            with open(local_library_path, "w") as f:
                f.write(library_raw_content)

    if dataset_id is None:
        raise ValueError("Dataset class name not found in library")

    return {
        "dataset_id": dataset_id,
        "source_files": source_files,
    }
```
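For context, `library.json` maps a dataset's class name to the metadata this function reads. A minimal sketch of the shape it assumes and a sample call (the class name and id value are illustrative, not taken from the commit):

```python
# Assumed library.json shape (illustrative):
# {
#     "PaulGrahamEssayDataset": {"id": "llama_datasets/paul_graham_essay"}
# }
info = get_dataset_info(
    local_dir_path="./llamadatasets",  # local cache directory
    remote_dir_path=LLAMA_HUB_URL,  # repo that hosts library.json
    remote_source_dir_path=LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    dataset_class="PaulGrahamEssayDataset",
)
print(info["dataset_id"], info["source_files"])
```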
```python
def download_dataset_and_source_files(
    local_dir_path: PATH_TYPE,
    remote_lfs_dir_path: PATH_TYPE,
    source_files_dir_path: PATH_TYPE,
    dataset_id: str,
    source_files: List[str],
    refresh_cache: bool = False,
    base_file_name: str = "rag_dataset.json",
    override_path: bool = False,
    show_progress: bool = False,
) -> None:
    """Download dataset and source files."""
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    if override_path:
        module_path = str(local_dir_path)
    else:
        module_path = f"{local_dir_path}/{dataset_id}"

    if refresh_cache or not os.path.exists(module_path):
        os.makedirs(module_path, exist_ok=True)
        os.makedirs(f"{module_path}/{source_files_dir_path}", exist_ok=True)

        rag_dataset_raw_content, _ = get_file_content(
            str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
        )

        with open(f"{module_path}/{base_file_name}", "w") as f:
            f.write(rag_dataset_raw_content)

        # Get content of source files
        if show_progress:
            source_files_iterator = tqdm.tqdm(source_files)
        else:
            source_files_iterator = source_files
        for source_file in source_files_iterator:
            if ".pdf" in source_file:
                source_file_raw_content_bytes, _ = get_file_content_bytes(
                    str(remote_lfs_dir_path),
                    f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                )
                with open(
                    f"{module_path}/{source_files_dir_path}/{source_file}", "wb"
                ) as f:
                    f.write(source_file_raw_content_bytes)
            else:
                source_file_raw_content, _ = get_file_content(
                    str(remote_lfs_dir_path),
                    f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                )
                with open(
                    f"{module_path}/{source_files_dir_path}/{source_file}", "w"
                ) as f:
                    f.write(source_file_raw_content)
```
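One detail worth noting: the `".pdf" in source_file` test picks the binary code path by substring match, so a text file whose name merely contains `.pdf` (say, `paper.pdf.txt`) would also be fetched as bytes. A stricter suffix check is a possible variant (a suggestion, not part of the commit):

```python
from pathlib import Path

def _is_pdf(filename: str) -> bool:
    """Suggested variant of the substring test: match only a .pdf extension."""
    return Path(filename).suffix.lower() == ".pdf"
```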
```python
def download_llama_dataset(
    dataset_class: str,
    llama_hub_url: str = LLAMA_HUB_URL,
    llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
    llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    refresh_cache: bool = False,
    custom_dir: Optional[str] = None,
    custom_path: Optional[str] = None,
    source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
    library_path: str = "library.json",
    base_file_name: str = "rag_dataset.json",
    disable_library_cache: bool = False,
    override_path: bool = False,
    show_progress: bool = False,
) -> Any:
    """Download a LlamaDataset from LlamaHub.

    Args:
        dataset_class: The name of the LlamaDataset class you want to download,
            such as `PaulGrahamEssayDataset`.
        refresh_cache: If true, the local cache will be skipped and the
            dataset will be fetched directly from the remote repo.
        custom_dir: Custom dir name to download the dataset into (under parent folder).
        custom_path: Custom dirpath to download the dataset into.
        source_files_dirpath: Name of the directory holding the dataset's source files.
        library_path: File name of the library file.
        disable_library_cache: If true, skip writing library.json to the local cache.
        override_path: If true, download directly into the parent folder.
        show_progress: If true, show a progress bar while downloading source files.

    Returns:
        A tuple of paths to the downloaded `rag_dataset.json` and the
        `source_files` directory.
    """
    # create directory / get path
    dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)

    # fetch info from library.json file
    dataset_info = get_dataset_info(
        local_dir_path=dirpath,
        remote_dir_path=llama_hub_url,
        remote_source_dir_path=llama_datasets_source_files_tree_url,
        dataset_class=dataset_class,
        refresh_cache=refresh_cache,
        library_path=library_path,
        disable_library_cache=disable_library_cache,
    )
    dataset_id = dataset_info["dataset_id"]
    source_files = dataset_info["source_files"]

    download_dataset_and_source_files(
        local_dir_path=dirpath,
        remote_lfs_dir_path=llama_datasets_lfs_url,
        source_files_dir_path=source_files_dirpath,
        dataset_id=dataset_id,
        source_files=source_files,
        refresh_cache=refresh_cache,
        base_file_name=base_file_name,
        override_path=override_path,
        show_progress=show_progress,
    )

    if override_path:
        module_path = str(dirpath)
    else:
        module_path = f"{dirpath}/{dataset_id}"

    return (
        f"{module_path}/{LLAMA_RAG_DATASET_FILENAME}",
        f"{module_path}/{LLAMA_SOURCE_FILES_PATH}",
    )
```
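Putting it together, a usage sketch for the whole flow (the dataset class name is assumed to exist in the remote library):

```python
# Hypothetical end-to-end usage of the function above.
rag_dataset_path, source_files_dir = download_llama_dataset(
    "PaulGrahamEssayDataset",
    refresh_cache=True,
    show_progress=True,
)
print(rag_dataset_path)  # <download dir>/<dataset_id>/rag_dataset.json
print(source_files_dir)  # <download dir>/<dataset_id>/source_files
```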