Skip to content

Commit

Permalink
fix(dataset): fix naming error for files loaded from files
Browse files Browse the repository at this point in the history
  • Loading branch information
ThePyProgrammer committed Aug 2, 2024
1 parent 7307939 commit e687fd1
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions walledeval/data/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def from_csv(cls, filenames: Union[str, list[str]], **csv_kwargs):

@classmethod
def from_json(cls, filenames: Union[str, list[str]], **json_kwargs):
[filenames] if isinstance(filenames, str) else filenames
filenames = [filenames] if isinstance(filenames, str) else filenames
dataset = load_dataset(
"json",
data_files=filenames,
Expand Down Expand Up @@ -223,9 +223,10 @@ def from_list(cls, name: str, lst: list[dict], model: type = Prompt):

@classmethod
def from_csv(cls, filenames: Union[str, list[str]], model: type = Prompt, **csv_kwargs):
filenames = [filenames] if isinstance(filenames, str) else filenames
dataset = load_dataset(
"csv",
data_files=[filenames] if isinstance(filenames, str) else filenames,
data_files=filenames,
**csv_kwargs
)['train']

Expand All @@ -237,9 +238,10 @@ def from_csv(cls, filenames: Union[str, list[str]], model: type = Prompt, **csv_

@classmethod
def from_json(cls, filenames: Union[str, list[str]], model: type = Prompt, **json_kwargs):
filenames = [filenames] if isinstance(filenames, str) else filenames
dataset = load_dataset(
"json",
data_files=[filenames] if isinstance(filenames, str) else filenames,
data_files=filenames,
**json_kwargs
)['train']

Expand Down

0 comments on commit e687fd1

Please sign in to comment.