Skip to content

Commit

Permalink
Merge branch 'automatic_ensemble' into automatic_ensemble_new
Browse files: browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed May 17, 2024
2 parents c579dc6 + ad17223 commit d0d9874
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions src/graphnet/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import torch.utils.data
from torch_geometric.data import Batch, Data

from graphnet.data.dataset import Dataset
from graphnet.data.dataset import Dataset, EnsembleDataset
from graphnet.utilities.config import DatasetConfig


@@ -63,23 +63,51 @@ def from_dataset_config(
"`shuffle` is automatically inferred from the selection name, "
"and thus should not specified as an argument."
)
datasets = Dataset.from_config(config)

if isinstance(config.path, list):
datasets: Union[Dict[str, Dataset], Dict[str, EnsembleDataset]] = {} # type: ignore
dataset_col: Dict[str, list] = {}
for key in config.selection.keys():
dataset_col[key] = []
save_path = config.path.copy()
for path in config.path:
config.path = path
tmp_dataset: Dict[str, Dataset] = Dataset.from_config(
config
)
for key in config.selection.keys():
dataset_col[key].append(tmp_dataset[key])
config.path = save_path
for key in config.selection.keys():
datasets[key] = EnsembleDataset(dataset_col[key])
else:
datasets = Dataset.from_config(config)
assert isinstance(datasets, dict)
data_loaders: Dict[str, DataLoader] = {}
for name, dataset in datasets.items():
for name, dataset_item in datasets.items():
data_loaders[name] = cls(
dataset,
dataset_item,
shuffle=do_shuffle(name),
**kwargs,
)

return data_loaders

else:
assert "shuffle" in kwargs, (
"When passing a `DatasetConfig` with a single selections, you "
"need to specify `shuffle` as an argument."
)
dataset = Dataset.from_config(config)
assert isinstance(dataset, Dataset)
return cls(dataset, **kwargs)
if isinstance(config.path, list):
dataset_list: List[Any] = []
save_path = config.path.copy()
for path in config.path:
config.path = path
dataset_list.append(Dataset.from_config(config))
ensembleset = EnsembleDataset(dataset_list)
config.path = save_path
assert isinstance(ensembleset, EnsembleDataset)
return cls(ensembleset, **kwargs)
else:
dataset = Dataset.from_config(config)
assert isinstance(dataset, Dataset)
return cls(dataset, **kwargs)

0 comments on commit d0d9874

Please sign in to comment.