From 95ba8df994a8910a6c16f30e94d84643fe98dbcd Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Wed, 11 Dec 2024 14:43:34 +0900 Subject: [PATCH] Ensemble from list of paths --- src/graphnet/data/dataset/dataset.py | 11 ++++++++++- tests/utilities/test_dataset_config.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py index 110da8a73..db274d906 100644 --- a/src/graphnet/data/dataset/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -141,7 +141,16 @@ def from_config( # type: ignore[override] cfg["graph_definition"] = parse_graph_definition(cfg) if cfg["labels"] is not None: cfg["labels"] = parse_labels(cfg) - return source._dataset_class(**cfg) + + if isinstance(cfg["path"], list): + sources = [] + for path in cfg["path"]: + cfg["path"] = path + sources.append(source._dataset_class(**cfg)) + source = EnsembleDataset(sources) + return source + else: + return source._dataset_class(**cfg) @classmethod def concatenate( diff --git a/tests/utilities/test_dataset_config.py b/tests/utilities/test_dataset_config.py index f54e37840..b3d6aec11 100644 --- a/tests/utilities/test_dataset_config.py +++ b/tests/utilities/test_dataset_config.py @@ -269,3 +269,22 @@ def test_dataset_config_files(backend: str) -> None: ) == 0 ) + + +@pytest.mark.order(6) +@pytest.mark.parametrize("backend", ["sqlite"]) +def test_multiple_dataset_config_dict_selection(backend: str) -> None: + """Test constructing Dataset with multiple data paths.""" + # Arrange + config_path = CONFIG_PATHS[backend] + + # Single dataset + config = DatasetConfig.load(config_path) + dataset = Dataset.from_config(config) + # Construct multiple datasets + config_ensemble = DatasetConfig.load(config_path) + config_ensemble.path = [config_ensemble.path, config_ensemble.path] + + ensemble_dataset = Dataset.from_config(config_ensemble) + + assert len(dataset) * 2 == len(ensemble_dataset)