diff --git a/src/graphnet/data/dataloader.py b/src/graphnet/data/dataloader.py index 4c02c595c..bbb189d12 100644 --- a/src/graphnet/data/dataloader.py +++ b/src/graphnet/data/dataloader.py @@ -5,7 +5,7 @@ import torch.utils.data from torch_geometric.data import Batch, Data -from graphnet.data.dataset import Dataset +from graphnet.data.dataset import Dataset, EnsembleDataset from graphnet.utilities.config import DatasetConfig @@ -63,16 +63,33 @@ def from_dataset_config( "`shuffle` is automatically inferred from the selection name, " "and thus should not specified as an argument." ) - datasets = Dataset.from_config(config) + + if isinstance(config.path, list): + datasets: Union[Dict[str, Dataset], Dict[str, EnsembleDataset]] = {} # type: ignore + dataset_col: Dict[str, list] = {} + for key in config.selection.keys(): + dataset_col[key] = [] + save_path = config.path.copy() + for path in config.path: + config.path = path + tmp_dataset: Dict[str, Dataset] = Dataset.from_config( + config + ) + for key in config.selection.keys(): + dataset_col[key].append(tmp_dataset[key]) + config.path = save_path + for key in config.selection.keys(): + datasets[key] = EnsembleDataset(dataset_col[key]) + else: + datasets = Dataset.from_config(config) assert isinstance(datasets, dict) data_loaders: Dict[str, DataLoader] = {} - for name, dataset in datasets.items(): + for name, dataset_item in datasets.items(): data_loaders[name] = cls( - dataset, + dataset_item, shuffle=do_shuffle(name), **kwargs, ) - return data_loaders else: @@ -80,6 +97,17 @@ def from_dataset_config( "When passing a `DatasetConfig` with a single selections, you " "need to specify `shuffle` as an argument." ) - dataset = Dataset.from_config(config) - assert isinstance(dataset, Dataset) - return cls(dataset, **kwargs) + if isinstance(config.path, list): + dataset_list: List[Any] = [] + save_path = config.path.copy() + for path in config.path: + config.path = path + dataset_list.append(Dataset.from_config(config)) + ensembleset = EnsembleDataset(dataset_list) + config.path = save_path + assert isinstance(ensembleset, EnsembleDataset) + return cls(ensembleset, **kwargs) + else: + dataset = Dataset.from_config(config) + assert isinstance(dataset, Dataset) + return cls(dataset, **kwargs)