Skip to content

Commit

Permalink
Merge pull request #775 from Aske-Rosted/file_list_ensemble
Browse files Browse the repository at this point in the history
Ensemble from list of paths
  • Loading branch information
Aske-Rosted authored Dec 12, 2024
2 parents 92b150a + 79b0479 commit b4f7cbc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/graphnet/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,16 @@ def from_config( # type: ignore[override]
cfg["graph_definition"] = parse_graph_definition(cfg)
if cfg["labels"] is not None:
cfg["labels"] = parse_labels(cfg)
return source._dataset_class(**cfg)

if isinstance(cfg["path"], list):
sources = []
for path in cfg["path"]:
cfg["path"] = path
sources.append(source._dataset_class(**cfg))
source = EnsembleDataset(sources)
return source
else:
return source._dataset_class(**cfg)

@classmethod
def concatenate(
Expand Down
19 changes: 19 additions & 0 deletions tests/utilities/test_dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,22 @@ def test_dataset_config_files(backend: str) -> None:
)
== 0
)


@pytest.mark.order(6)
@pytest.mark.parametrize("backend", ["sqlite"])
def test_multiple_dataset_config_dict_selection(backend: str) -> None:
    """Check that a list-valued `path` produces a dataset of combined length.

    The same configuration file is loaded twice: once unchanged, and once
    with its `path` replaced by a two-element list that repeats the original
    path. The dataset built from the second configuration is expected to be
    exactly twice as long as the one built from the first.
    """
    config_file = CONFIG_PATHS[backend]

    # Reference: dataset built from the unmodified, single-path config.
    single_config = DatasetConfig.load(config_file)
    single_dataset = Dataset.from_config(single_config)

    # Same config loaded afresh, with its path duplicated into a list.
    multi_config = DatasetConfig.load(config_file)
    multi_config.path = [multi_config.path, multi_config.path]
    multi_dataset = Dataset.from_config(multi_config)

    # Two copies of the same source should give twice the number of entries.
    expected_length = len(single_dataset) * 2
    assert expected_length == len(multi_dataset)

0 comments on commit b4f7cbc

Please sign in to comment.