diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py
index 49329e4ac..838ea40b7 100644
--- a/src/graphnet/data/dataset/dataset.py
+++ b/src/graphnet/data/dataset/dataset.py
@@ -141,7 +141,16 @@ def from_config( # type: ignore[override]
             cfg["graph_definition"] = parse_graph_definition(cfg)
         if cfg["labels"] is not None:
             cfg["labels"] = parse_labels(cfg)
-        return source._dataset_class(**cfg)
+
+        if isinstance(cfg["path"], list):
+            sources = []
+            for path in cfg["path"]:
+                cfg["path"] = path
+                sources.append(source._dataset_class(**cfg))
+            source = EnsembleDataset(sources)
+            return source
+        else:
+            return source._dataset_class(**cfg)
 
     @classmethod
     def concatenate(
diff --git a/tests/utilities/test_dataset_config.py b/tests/utilities/test_dataset_config.py
index f54e37840..b3d6aec11 100644
--- a/tests/utilities/test_dataset_config.py
+++ b/tests/utilities/test_dataset_config.py
@@ -269,3 +269,22 @@ def test_dataset_config_files(backend: str) -> None:
         )
         == 0
     )
+
+
+@pytest.mark.order(6)
+@pytest.mark.parametrize("backend", ["sqlite"])
+def test_multiple_dataset_config_dict_selection(backend: str) -> None:
+    """Test constructing Dataset with multiple data paths."""
+    # Arrange
+    config_path = CONFIG_PATHS[backend]
+
+    # Single dataset
+    config = DatasetConfig.load(config_path)
+    dataset = Dataset.from_config(config)
+    # Construct multiple datasets
+    config_ensemble = DatasetConfig.load(config_path)
+    config_ensemble.path = [config_ensemble.path, config_ensemble.path]
+
+    ensemble_dataset = Dataset.from_config(config_ensemble)
+
+    assert len(dataset) * 2 == len(ensemble_dataset)
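
For context, a minimal usage sketch of the behaviour this change enables, assuming the usual graphnet import locations; the config file name and database paths below are placeholders, not part of the patch:

    from graphnet.data.dataset import Dataset, EnsembleDataset
    from graphnet.utilities.config import DatasetConfig

    # Load an existing single-path dataset configuration (placeholder file name).
    config = DatasetConfig.load("configs/datasets/my_dataset.yml")

    # Point the same configuration at several databases; with this change,
    # Dataset.from_config builds one Dataset per path and wraps them in an
    # EnsembleDataset whose length is the sum of the constituent datasets.
    config.path = ["/data/train_part_1.db", "/data/train_part_2.db"]
    ensemble = Dataset.from_config(config)
    assert isinstance(ensemble, EnsembleDataset)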