Move download datasets logic out of download_utils.py (run-llama#9253)
* add modifications to allow for pdf download

* lint

* lint

* separate content vars

* move download datasets into separate sub module

* wip

* update commandline

* move util functions to utils

* use renamed module

* refactor to use new module name

* refactor to use new module name

* cr

* lint

* add missing import

* add missing import

* add missing import; add back download_llama_datasets

* fix command line

* get metadata from hub

* fix command line

* wip
nerdai authored Dec 2, 2023
1 parent f49ac17 commit 1f9ba34
Showing 8 changed files with 433 additions and 195 deletions.
21 changes: 15 additions & 6 deletions llama_index/command_line/command_line.py
@@ -2,7 +2,8 @@
 from typing import Any, Optional
 
 from llama_index.llama_dataset.download import (
-    LLAMA_DATASETS_URL,
+    LLAMA_DATASETS_LFS_URL,
+    LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
     download_llama_dataset,
 )
 from llama_index.llama_pack.download import LLAMA_HUB_URL, download_llama_pack
@@ -29,7 +30,8 @@ def handle_download_llama_dataset(
     llama_dataset_class: Optional[str] = None,
     download_dir: Optional[str] = None,
     llama_hub_url: str = LLAMA_HUB_URL,
-    llama_datasets_url: str = LLAMA_DATASETS_URL,
+    llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
+    llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
     **kwargs: Any,
 ) -> None:
     assert llama_dataset_class is not None
@@ -39,10 +41,11 @@
         llama_dataset_class=llama_dataset_class,
         download_dir=download_dir,
         llama_hub_url=llama_hub_url,
-        llama_datasets_url=llama_datasets_url,
+        llama_datasets_lfs_url=llama_datasets_lfs_url,
+        llama_datasets_source_files_tree_url=llama_datasets_source_files_tree_url,
     )
 
-    print(f"Successfully downloaded {llama_datasets_url} to {download_dir}")
+    print(f"Successfully downloaded {llama_dataset_class} to {download_dir}")
 
 
 def main() -> None:
@@ -106,9 +109,15 @@ def main() -> None:
         help="URL to llama hub.",
     )
     llamadataset_parser.add_argument(
-        "--llama-dataset-url",
+        "--llama-datasets-lfs-url",
         type=str,
-        default=LLAMA_DATASETS_URL,
+        default=LLAMA_DATASETS_LFS_URL,
         help="URL to llama datasets.",
     )
+    llamadataset_parser.add_argument(
+        "--llama-datasets-source-files-tree-url",
+        type=str,
+        default=LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
+        help="URL to llama datasets source files tree.",
+    )
     llamadataset_parser.set_defaults(
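
For reference, a minimal sketch of calling the updated handler from Python rather than the CLI; the dataset class name and download directory below are placeholders, not values taken from this commit:

from llama_index.command_line.command_line import handle_download_llama_dataset

# Uses the default LFS and source-files tree URLs added in this change; pass
# llama_datasets_lfs_url / llama_datasets_source_files_tree_url to override them.
handle_download_llama_dataset(
    llama_dataset_class="PaulGrahamEssayDataset",
    download_dir="./data",
)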
228 changes: 228 additions & 0 deletions llama_index/download/dataset.py
@@ -0,0 +1,228 @@
"""Download."""

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import requests
import tqdm

from llama_index.download.module import LLAMA_HUB_URL
from llama_index.download.utils import (
    get_file_content,
    get_file_content_bytes,
    initialize_directory,
)

LLAMA_DATASETS_LFS_URL = (
    f"https://media.githubusercontent.com/media/run-llama/llama_datasets/main"
)

LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = (
    "https://github.com/run-llama/llama_datasets/tree/main"
)
LLAMA_RAG_DATASET_FILENAME = "rag_dataset.json"
LLAMA_SOURCE_FILES_PATH = "source_files"


PATH_TYPE = Union[str, Path]


def _get_source_files_list(source_tree_url: str, path: str) -> List[str]:
    """Get the list of source files to download."""
    resp = requests.get(source_tree_url + path + "?recursive=1")
    payload = resp.json()["payload"]
    return [item["name"] for item in payload["tree"]["items"]]


def get_dataset_info(
    local_dir_path: PATH_TYPE,
    remote_dir_path: PATH_TYPE,
    remote_source_dir_path: PATH_TYPE,
    dataset_class: str,
    refresh_cache: bool = False,
    library_path: str = "library.json",
    source_files_path: str = "source_files",
    disable_library_cache: bool = False,
) -> Dict:
    """Get dataset info."""
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    local_library_path = f"{local_dir_path}/{library_path}"
    dataset_id = None
    source_files = []

    # Check cache first
    if not refresh_cache and os.path.exists(local_library_path):
        with open(local_library_path) as f:
            library = json.load(f)
        if dataset_class in library:
            dataset_id = library[dataset_class]["id"]
            source_files = library[dataset_class].get("source_files", [])

    # Fetch up-to-date library from remote repo if dataset_id not found
    if dataset_id is None:
        library_raw_content, _ = get_file_content(
            str(remote_dir_path), f"/{library_path}"
        )
        library = json.loads(library_raw_content)
        if dataset_class not in library:
            raise ValueError("Loader class name not found in library")

        dataset_id = library[dataset_class]["id"]
        source_files = _get_source_files_list(
            str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}"
        )

        # create cache dir if needed
        local_library_dir = os.path.dirname(local_library_path)
        if not disable_library_cache:
            if not os.path.exists(local_library_dir):
                os.makedirs(local_library_dir)

            # Update cache
            with open(local_library_path, "w") as f:
                f.write(library_raw_content)

    if dataset_id is None:
        raise ValueError("Dataset class name not found in library")

    return {
        "dataset_id": dataset_id,
        "source_files": source_files,
    }


def download_dataset_and_source_files(
    local_dir_path: PATH_TYPE,
    remote_lfs_dir_path: PATH_TYPE,
    source_files_dir_path: PATH_TYPE,
    dataset_id: str,
    source_files: List[str],
    refresh_cache: bool = False,
    base_file_name: str = "rag_dataset.json",
    override_path: bool = False,
    show_progress: bool = False,
) -> None:
    """Download dataset and source files."""
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    if override_path:
        module_path = str(local_dir_path)
    else:
        module_path = f"{local_dir_path}/{dataset_id}"

    if refresh_cache or not os.path.exists(module_path):
        os.makedirs(module_path, exist_ok=True)
        os.makedirs(f"{module_path}/{source_files_dir_path}", exist_ok=True)

        rag_dataset_raw_content, _ = get_file_content(
            str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
        )

        with open(f"{module_path}/{base_file_name}", "w") as f:
            f.write(rag_dataset_raw_content)

        # Get content of source files
        if show_progress:
            source_files_iterator = tqdm.tqdm(source_files)
        else:
            source_files_iterator = source_files
        for source_file in source_files_iterator:
            if ".pdf" in source_file:
                source_file_raw_content_bytes, _ = get_file_content_bytes(
                    str(remote_lfs_dir_path),
                    f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                )
                with open(
                    f"{module_path}/{source_files_dir_path}/{source_file}", "wb"
                ) as f:
                    f.write(source_file_raw_content_bytes)
            else:
                source_file_raw_content, _ = get_file_content(
                    str(remote_lfs_dir_path),
                    f"/{dataset_id}/{source_files_dir_path}/{source_file}",
                )
                with open(
                    f"{module_path}/{source_files_dir_path}/{source_file}", "w"
                ) as f:
                    f.write(source_file_raw_content)


def download_llama_dataset(
    dataset_class: str,
    llama_hub_url: str = LLAMA_HUB_URL,
    llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
    llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    refresh_cache: bool = False,
    custom_dir: Optional[str] = None,
    custom_path: Optional[str] = None,
    source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
    library_path: str = "library.json",
    base_file_name: str = "rag_dataset.json",
    disable_library_cache: bool = False,
    override_path: bool = False,
    show_progress: bool = False,
) -> Any:
"""Download a module from LlamaHub.
Can be a loader, tool, pack, or more.
Args:
loader_class: The name of the llama module class you want to download,
such as `GmailOpenAIAgentPack`.
refresh_cache: If true, the local cache will be skipped and the
loader will be fetched directly from the remote repo.
custom_dir: Custom dir name to download loader into (under parent folder).
custom_path: Custom dirpath to download loader into.
library_path: File name of the library file.
use_gpt_index_import: If true, the loader files will use
llama_index as the base dependency. By default (False),
the loader files use llama_index as the base dependency.
NOTE: this is a temporary workaround while we fully migrate all usages
to llama_index.
is_dataset: whether or not downloading a LlamaDataset
Returns:
A Loader, A Pack, An Agent, or A Dataset
"""
    # create directory / get path
    dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)

    # fetch info from library.json file
    dataset_info = get_dataset_info(
        local_dir_path=dirpath,
        remote_dir_path=llama_hub_url,
        remote_source_dir_path=llama_datasets_source_files_tree_url,
        dataset_class=dataset_class,
        refresh_cache=refresh_cache,
        library_path=library_path,
        disable_library_cache=disable_library_cache,
    )
    dataset_id = dataset_info["dataset_id"]
    source_files = dataset_info["source_files"]

    download_dataset_and_source_files(
        local_dir_path=dirpath,
        remote_lfs_dir_path=llama_datasets_lfs_url,
        source_files_dir_path=source_files_dirpath,
        dataset_id=dataset_id,
        source_files=source_files,
        refresh_cache=refresh_cache,
        base_file_name=base_file_name,
        override_path=override_path,
        show_progress=show_progress,
    )

    if override_path:
        module_path = str(dirpath)
    else:
        module_path = f"{dirpath}/{dataset_id}"

    return (
        f"{module_path}/{LLAMA_RAG_DATASET_FILENAME}",
        f"{module_path}/{LLAMA_SOURCE_FILES_PATH}",
    )
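
As a quick usage sketch of the new module (not part of the commit; the dataset class name and target directory are assumptions):

from llama_index.download.dataset import download_llama_dataset

# Resolves the dataset id from library.json on LlamaHub, then downloads
# rag_dataset.json and every source file (PDFs as bytes, everything else as text).
rag_dataset_path, source_files_dir = download_llama_dataset(
    dataset_class="PaulGrahamEssayDataset",
    custom_path="./data/paul_graham",
    show_progress=True,
)
# rag_dataset_path points at the downloaded rag_dataset.json and
# source_files_dir at the downloaded source_files directory.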