Merge pull request #25 from epinzur/crag-datasets
added support for crag task_1 dataset
epinzur authored Jun 25, 2024
2 parents 8e1d263 + 0a706ff commit 882b16f
Showing 7 changed files with 103 additions and 9 deletions.
13 changes: 12 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ pyyaml = "^6.0.1"
cerberus = "^1.3.5"
pydantic = "^2.7.3"
setuptools = "^70.0.0"
+aiofiles = "^24.1.0"

[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
8 changes: 3 additions & 5 deletions ragulate/cli_commands/download.py
@@ -1,4 +1,4 @@
-from ragulate.datasets import LlamaDataset
+from ragulate.datasets import get_dataset


def setup_download(subparsers):
@@ -22,7 +22,5 @@ def setup_download(subparsers):


def call_download(dataset_name: str, kind: str, **kwargs):
-    if not kind == "llama":
-        raise ("Currently only Llama Datasets are supported. Set param `-k llama`")
-    llama = LlamaDataset(dataset_name=dataset_name)
-    llama.download_dataset()
+    dataset = get_dataset(name=dataset_name, kind=kind)
+    dataset.download_dataset()
19 changes: 19 additions & 0 deletions ragulate/cli_commands/query.py
@@ -55,6 +55,15 @@ def setup_query(subparsers):
help=("The name of a dataset to query", "This can be passed multiple times."),
action="append",
)
+    query_parser.add_argument(
+        "--subset",
+        type=str,
+        help=(
+            "The subset of the dataset to query",
+            "Only valid when a single dataset is passed.",
+        ),
+        action="append",
+    )
query_parser.set_defaults(func=lambda args: call_query(**vars(args)))

def call_query(
@@ -64,10 +73,20 @@ def call_query(
var_name: List[str],
var_value: List[str],
dataset: List[str],
+    subset: List[str],
**kwargs,
):

datasets = [find_dataset(name=name) for name in dataset]

+    if len(subset) > 0:
+        if len(datasets) > 1:
+            raise ValueError(
+                "Only can set `subset` param when there is one dataset"
+            )
+        else:
+            datasets[0].subsets = subset

ingredients = convert_vars_to_ingredients(
var_names=var_name, var_values=var_value
)
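The new --subset flag narrows a query to part of a dataset and can be repeated; passing it alongside more than one --dataset raises a ValueError. A sketch of the programmatic equivalent via the new subsets setter (the subset names here are hypothetical, not from this commit):

    from ragulate.datasets import find_dataset

    dataset = find_dataset(name="crag_task_1")
    # Same effect as `--subset music --subset movie` on the CLI:
    dataset.subsets = ["music", "movie"]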
2 changes: 2 additions & 0 deletions ragulate/datasets/__init__.py
@@ -1,9 +1,11 @@
from .base_dataset import BaseDataset
+from .crag_dataset import CragDataset
from .llama_dataset import LlamaDataset
from .utils import find_dataset, get_dataset

__all__ = [
"BaseDataset",
"CragDataset",
"LlamaDataset",
"find_dataset",
"get_dataset",
50 changes: 49 additions & 1 deletion ragulate/datasets/base_dataset.py
@@ -1,13 +1,20 @@
+import bz2
+import tempfile
from abc import ABC, abstractmethod
-from os import path
+from os import makedirs, path
from pathlib import Path
from typing import Dict, List, Optional, Tuple

+import aiofiles
+import aiohttp
+from tqdm.asyncio import tqdm


class BaseDataset(ABC):

root_storage_path: str
name: str
+    _subsets: List[str] = []

def __init__(
self, dataset_name: str, root_storage_path: Optional[str] = "datasets"
@@ -27,6 +34,14 @@ def list_files_at_path(self, path: str) -> List[str]:
if f.is_file() and not f.name.startswith(".")
]

+    @property
+    def subsets(self) -> List[str]:
+        return self._subsets
+
+    @subsets.setter
+    def subsets(self, value: List[str]):
+        self._subsets = value

@abstractmethod
def sub_storage_path(self) -> str:
"""the sub-path to store the dataset in"""
@@ -42,3 +57,36 @@ def get_source_file_paths(self) -> List[str]:
@abstractmethod
def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
"""gets a list of queries and golden_truth answers for a dataset"""

+    async def _download_file(self, session, url, temp_file_path):
+        async with session.get(url) as response:
+            file_size = int(response.headers.get('Content-Length', 0))
+            chunk_size = 1024
+            with tqdm(total=file_size, unit='B', unit_scale=True, desc=f'Downloading {url.split("/")[-1]}') as progress_bar:
+                async with aiofiles.open(temp_file_path, 'wb') as temp_file:
+                    async for chunk in response.content.iter_chunked(chunk_size):
+                        await temp_file.write(chunk)
+                        progress_bar.update(len(chunk))
+
+    async def _decompress_file(self, temp_file_path, output_file_path):
+        makedirs(path.dirname(output_file_path), exist_ok=True)
+        with open(temp_file_path, 'rb') as temp_file:
+            decompressed_size = 0
+            with bz2.BZ2File(temp_file, 'rb') as bz2_file:
+                async with aiofiles.open(output_file_path, 'wb') as output_file:
+                    with tqdm(unit='B', unit_scale=True, desc=f'Decompressing {output_file_path}') as progress_bar:
+                        while True:
+                            chunk = bz2_file.read(1024)
+                            if not chunk:
+                                break
+                            await output_file.write(chunk)
+                            decompressed_size += len(chunk)
+                            progress_bar.update(len(chunk))
+
+    async def _download_and_decompress(self, url, output_file_path):
+        async with aiohttp.ClientSession() as session:
+            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+                temp_file_path = temp_file.name
+
+            await self._download_file(session, url, temp_file_path)
+            await self._decompress_file(temp_file_path, output_file_path)
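These helpers stream a download to a temporary file with a tqdm progress bar, then decompress it with bz2 in 1 KiB chunks. Both are coroutines, so a subclass's download_dataset() needs an event loop to drive them; a minimal sketch under that assumption (the URL and paths are illustrative, not from this commit):

    import asyncio

    from ragulate.datasets import CragDataset

    dataset = CragDataset(dataset_name="crag_task_1")
    # asyncio.run() supplies the event loop the async helpers require.
    asyncio.run(
        dataset._download_and_decompress(
            url="https://example.com/crag/task_1.jsonl.bz2",  # hypothetical URL
            output_file_path="datasets/crag/crag_task_1/task_1.jsonl",
        )
    )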
19 changes: 17 additions & 2 deletions ragulate/datasets/utils.py
@@ -1,17 +1,32 @@
+import os
from typing import List

from .base_dataset import BaseDataset
+from .crag_dataset import CragDataset
from .llama_dataset import LlamaDataset


-# TODO: implement this when adding additional dataset kinds
def find_dataset(name:str) -> BaseDataset:
root_path = "datasets"
name = name.lower()
for kind in os.listdir(root_path):
kind_path = os.path.join(root_path, kind)
if os.path.isdir(kind_path):
for dataset in os.listdir(kind_path):
dataset_path = os.path.join(kind_path, dataset)
if os.path.isdir(dataset_path):
if dataset.lower() == name:
return get_dataset(name, kind)

""" searches for a downloaded dataset with this name. if found, returns it."""
return get_dataset(name, "llama")

def get_dataset(name:str, kind:str) -> BaseDataset:
kind = kind.lower()
if kind == "llama":
return LlamaDataset(dataset_name=name)
elif kind == "crag":
return CragDataset(dataset_name=name)

-    raise NotImplementedError("only llama datasets are currently supported")
+    raise NotImplementedError("only llama and crag datasets are currently supported")
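find_dataset now scans the on-disk layout datasets/<kind>/<name> to infer the kind of an already-downloaded dataset, falling back to a Llama dataset when nothing matches; get_dataset constructs a dataset of an explicit kind. A short sketch of both entry points (dataset names illustrative):

    from ragulate.datasets import find_dataset, get_dataset

    # Explicit kind, e.g. right after `ragulate download -k crag crag_task_1`:
    crag = get_dataset(name="crag_task_1", kind="crag")

    # Kind inferred from datasets/crag/crag_task_1/ on disk:
    same = find_dataset(name="crag_task_1")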
