Skip to content

Commit

Permalink
feat(data): get rid of dataset bloat
Browse files Browse the repository at this point in the history
  • Loading branch information
ThePyProgrammer committed Aug 2, 2024
1 parent 6b40b9f commit 20218cc
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 154 deletions.
4 changes: 2 additions & 2 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from walledeval.data import MultipleChoiceDataset
from walledeval.data import HuggingFaceDataset
from walledeval.types import MultipleChoiceQuestion

WMDP_BIO = None

def test_loading():
global WMDP_BIO
WMDP_BIO = MultipleChoiceDataset.from_hub("cais/wmdp", "wmdp-bio", split="test")
WMDP_BIO = HuggingFaceDataset[MultipleChoiceQuestion].from_hub("cais/wmdp", "wmdp-bio", split="test")

assert WMDP_BIO.name == "cais/wmdp/wmdp-bio"

Expand Down
8 changes: 0 additions & 8 deletions walledeval/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
# walledeval/data/__init__.py
from walledeval.data.core import (
Dataset, HuggingFaceDataset,
MultipleChoiceDataset, MultipleResponseDataset,
OpenEndedDataset, PromptDataset,
AutocompleteDataset, SystemAssistedDataset,
JudgeQuestioningDataset, InjectionDataset
)

__all__ = [
"Dataset", "HuggingFaceDataset",
"MultipleChoiceDataset", "MultipleResponseDataset",
"OpenEndedDataset", "PromptDataset",
"AutocompleteDataset", "SystemAssistedDataset",
"JudgeQuestioningDataset", "InjectionDataset"
]
147 changes: 3 additions & 144 deletions walledeval/data/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# walledeval/benchmark/core.py
# walledeval/data/core.py

from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Optional, Union
Expand All @@ -7,28 +7,11 @@
from datasets import load_dataset
import datasets

from walledeval.types import (
MultipleChoiceQuestion, MultipleResponseQuestion,
OpenEndedQuestion,
Prompt,
AutocompletePrompt,
SystemAssistedPrompt,
JudgeQuestioningPrompt,
InjectionPrompt,
Range
)
from walledeval.types import Prompt, Range
from walledeval.util import process_range

__all__ = [
"Dataset", "HuggingFaceDataset",
"MultipleChoiceDataset",
"MultipleResponseDataset",
"OpenEndedDataset",
"PromptDataset",
"AutocompleteDataset",
"SystemAssistedDataset",
"JudgeQuestioningDataset",
"InjectionDataset"
"Dataset", "HuggingFaceDataset"
]

T = TypeVar('T')
Expand Down Expand Up @@ -60,67 +43,6 @@ def __init__(self, name: str, dataset: datasets.Dataset):
super().__init__(name)
self.dataset = dataset

@classmethod
def from_hub(cls, name: str,
             config: Optional[str] = None,
             split: str = "DEFAULT",
             **ds_kwargs):
    """Create an instance from ``load_dataset(name, config, **ds_kwargs)``.

    Split selection:
      * a ``split`` that is among the loaded splits is used as given;
      * the sentinel ``"DEFAULT"`` picks ``"train"``, else ``"test"``,
        else the first split listed;
      * any other value raises ``NameError``.

    The instance name is ``name[/config][/split]``. The split suffix is
    left off while ``split`` still equals ``"DEFAULT"``, which is the case
    for the ``"train"`` / ``"test"`` fallbacks.
    """
    dataset = load_dataset(name, config, **ds_kwargs)
    splits = tuple(dataset.keys())

    if split in splits:
        selected = split
    elif split != "DEFAULT":
        raise NameError(f"Requested split '{split}' not found in dataset {name}/{config}, select one of {splits}")
    elif "train" in splits:
        selected = "train"
    elif "test" in splits:
        selected = "test"
    else:
        # Rebind ``split`` so the first-split fallback appears in the name.
        split = selected = splits[0]

    chosen = dataset[selected]

    label_parts = [name]
    if config:
        label_parts.append(config)
    if split != "DEFAULT":
        label_parts.append(split)

    return cls("/".join(label_parts), chosen)

@classmethod
def from_list(cls, name: str, lst: list[dict]):
    """Create an instance called ``name`` from a list of sample dicts.

    The list is handed to ``datasets.Dataset.from_list`` unchanged.
    """
    return cls(name, datasets.Dataset.from_list(lst))

@classmethod
def from_csv(cls, filenames: Union[str, list[str]], **csv_kwargs):
    """Create an instance from one or more CSV files.

    A single path is wrapped in a list. The files are loaded through
    ``load_dataset("csv", ...)`` with ``csv_kwargs`` forwarded, and the
    ``'train'`` entry of the result is used. The instance is named after
    the first file.
    """
    if isinstance(filenames, str):
        filenames = [filenames]

    loaded = load_dataset("csv", data_files=filenames, **csv_kwargs)
    dataset = loaded['train']

    return cls(filenames[0], dataset)

@classmethod
def from_json(cls, filenames: Union[str, list[str]], **json_kwargs):
    """Create an instance from one or more JSON files.

    A single path is wrapped in a list. The files are loaded through
    ``load_dataset("json", ...)`` with ``json_kwargs`` forwarded, and the
    ``'train'`` entry of the result is used. The instance is named after
    the first file.
    """
    if isinstance(filenames, str):
        filenames = [filenames]

    loaded = load_dataset("json", data_files=filenames, **json_kwargs)
    dataset = loaded['train']

    return cls(filenames[0], dataset)


@abstractmethod
def convert(self, sample: dict) -> T:
    """Turn one raw ``sample`` dict into a ``T``; subclasses must override."""
Expand Down Expand Up @@ -264,66 +186,3 @@ def convert(self, sample: dict) -> I:
})


class MultipleChoiceDataset(_HuggingFaceDataset[MultipleChoiceQuestion]):
    """Dataset whose samples are converted to ``MultipleChoiceQuestion``."""

    def convert(self, sample: dict) -> MultipleChoiceQuestion:
        """Build a question from the ``question``, ``choices`` and ``answer`` entries."""
        question = sample["question"]
        choices = sample["choices"]
        answer = sample["answer"]
        return MultipleChoiceQuestion(
            question=question, choices=choices, answer=answer
        )


class MultipleResponseDataset(_HuggingFaceDataset[MultipleResponseQuestion]):
    """Dataset whose samples are converted to ``MultipleResponseQuestion``."""

    def convert(self, sample: dict) -> MultipleResponseQuestion:
        """Build a question from the ``question``, ``choices`` and ``answers`` entries."""
        question = sample["question"]
        choices = sample["choices"]
        answers = sample["answers"]
        return MultipleResponseQuestion(
            question=question, choices=choices, answers=answers
        )


class OpenEndedDataset(_HuggingFaceDataset[OpenEndedQuestion]):
    """Dataset whose samples are converted to ``OpenEndedQuestion``."""

    def convert(self, sample: dict) -> OpenEndedQuestion:
        """Build a question from the sample's ``question`` entry."""
        question = sample["question"]
        return OpenEndedQuestion(question=question)


class PromptDataset(_HuggingFaceDataset[Prompt]):
    """Dataset whose samples are converted to ``Prompt``."""

    def convert(self, sample: dict) -> Prompt:
        """Build a prompt from the sample's ``prompt`` entry."""
        text = sample["prompt"]
        return Prompt(prompt=text)


class AutocompleteDataset(_HuggingFaceDataset[AutocompletePrompt]):
    """Dataset whose samples are converted to ``AutocompletePrompt``."""

    def convert(self, sample: dict) -> AutocompletePrompt:
        """Build an autocomplete prompt from the sample's ``prompt`` entry."""
        text = sample["prompt"]
        return AutocompletePrompt(prompt=text)


class SystemAssistedDataset(_HuggingFaceDataset[SystemAssistedPrompt]):
    """Dataset whose samples are converted to ``SystemAssistedPrompt``."""

    def convert(self, sample: dict) -> SystemAssistedPrompt:
        """Build a prompt from the sample's ``prompt`` and ``system`` entries."""
        text = sample["prompt"]
        system = sample["system"]
        return SystemAssistedPrompt(prompt=text, system=system)


class JudgeQuestioningDataset(_HuggingFaceDataset[JudgeQuestioningPrompt]):
    """Dataset whose samples are converted to ``JudgeQuestioningPrompt``."""

    def convert(self, sample: dict) -> JudgeQuestioningPrompt:
        """Build a prompt from the sample's ``prompt`` and ``judge`` entries.

        The ``judge`` entry is passed on as ``judge_question``.
        """
        text = sample["prompt"]
        judge_question = sample["judge"]
        return JudgeQuestioningPrompt(
            prompt=text, judge_question=judge_question
        )


class InjectionDataset(_HuggingFaceDataset[InjectionPrompt]):
    """Dataset declared over ``InjectionPrompt`` samples."""

    def convert(self, sample: dict) -> InjectionPrompt:
        """Build a prompt object from the sample's ``prompt`` and ``system`` entries."""
        text = sample["prompt"]
        system = sample["system"]
        # NOTE(review): the annotation says InjectionPrompt, but a
        # SystemAssistedPrompt is constructed here. This looks like a
        # copy-paste slip; confirm InjectionPrompt's fields before changing it.
        return SystemAssistedPrompt(prompt=text, system=system)

0 comments on commit 20218cc

Please sign in to comment.