diff --git a/walledeval/data/core.py b/walledeval/data/core.py index d536649..f3a3e8b 100644 --- a/walledeval/data/core.py +++ b/walledeval/data/core.py @@ -108,7 +108,7 @@ def from_csv(cls, filenames: Union[str, list[str]], **csv_kwargs): @classmethod def from_json(cls, filenames: Union[str, list[str]], **json_kwargs): - [filenames] if isinstance(filenames, str) else filenames + filenames = [filenames] if isinstance(filenames, str) else filenames dataset = load_dataset( "json", data_files=filenames, @@ -223,9 +223,10 @@ def from_list(cls, name: str, lst: list[dict], model: type = Prompt): @classmethod def from_csv(cls, filenames: Union[str, list[str]], model: type = Prompt, **csv_kwargs): + filenames = [filenames] if isinstance(filenames, str) else filenames dataset = load_dataset( "csv", - data_files=[filenames] if isinstance(filenames, str) else filenames, + data_files=filenames, **csv_kwargs )['train'] @@ -237,9 +238,10 @@ def from_csv(cls, filenames: Union[str, list[str]], model: type = Prompt, **csv_ @classmethod def from_json(cls, filenames: Union[str, list[str]], model: type = Prompt, **json_kwargs): + filenames = [filenames] if isinstance(filenames, str) else filenames dataset = load_dataset( "json", - data_files=[filenames] if isinstance(filenames, str) else filenames, + data_files=filenames, **json_kwargs )['train']