Skip to content

Commit

Permalink
Merge pull request graphnet-team#610 from ArturoLlorente/feature/sequ…
Browse files Browse the repository at this point in the history
…ence_bucketing

Feature/sequence bucketing
  • Loading branch information
ArturoLlorente authored Oct 13, 2023
2 parents 9b9dd63 + 8272bf3 commit 02902fa
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 10 deletions.
46 changes: 36 additions & 10 deletions src/graphnet/models/standard_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,22 @@ def configure_optimizers(self) -> Dict[str, Any]:
)
return config

def forward(self, data: Data) -> List[Union[Tensor, Data]]:
def forward(
self, data: Union[Data, List[Data]]
) -> List[Union[Tensor, Data]]:
"""Forward pass, chaining model components."""
assert isinstance(data, Data)
x = self._gnn(data)
if isinstance(data, Data):
data = [data]
x_list = []
for d in data:
x = self._gnn(d)
x_list.append(x)
x = torch.cat(x_list, dim=0)

preds = [task(x) for task in self._tasks]
return preds

def shared_step(self, batch: Data, batch_idx: int) -> Tensor:
def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
"""Perform shared step.
Applies the forward pass and the following loss calculation, shared
Expand All @@ -111,8 +119,12 @@ def shared_step(self, batch: Data, batch_idx: int) -> Tensor:
loss = self.compute_loss(preds, batch)
return loss

def training_step(self, train_batch: Data, batch_idx: int) -> Tensor:
def training_step(
self, train_batch: Union[Data, List[Data]], batch_idx: int
) -> Tensor:
"""Perform training step."""
if isinstance(train_batch, Data):
train_batch = [train_batch]
loss = self.shared_step(train_batch, batch_idx)
self.log(
"train_loss",
Expand All @@ -125,8 +137,12 @@ def training_step(self, train_batch: Data, batch_idx: int) -> Tensor:
)
return loss

def validation_step(self, val_batch: Data, batch_idx: int) -> Tensor:
def validation_step(
self, val_batch: Union[Data, List[Data]], batch_idx: int
) -> Tensor:
"""Perform validation step."""
if isinstance(val_batch, Data):
val_batch = [val_batch]
loss = self.shared_step(val_batch, batch_idx)
self.log(
"val_loss",
Expand All @@ -140,11 +156,21 @@ def validation_step(self, val_batch: Data, batch_idx: int) -> Tensor:
return loss

def compute_loss(
self, preds: Tensor, data: Data, verbose: bool = False
self, preds: Tensor, data: List[Data], verbose: bool = False
) -> Tensor:
"""Compute and sum losses across tasks."""
data_merged = {}
target_labels_merged = list(set(self.target_labels))
for label in target_labels_merged:
data_merged[label] = torch.cat([d[label] for d in data], dim=0)
for task in self._tasks:
if task._loss_weight is not None:
data_merged[task._loss_weight] = torch.cat(
[d[task._loss_weight] for d in data], dim=0
)

losses = [
task.compute_loss(pred, data)
task.compute_loss(pred, data_merged)
for task, pred in zip(self._tasks, preds)
]
if verbose:
Expand All @@ -154,8 +180,8 @@ def compute_loss(
), "Please reduce loss for each task separately"
return torch.sum(torch.stack(losses))

def _get_batch_size(self, data: Data) -> int:
return torch.numel(torch.unique(data.batch))
def _get_batch_size(self, data: List[Data]) -> int:
    """Return the summed number of distinct `batch` values over `data`.

    Each element of `data` contributes the count of unique entries in its
    `batch` attribute; the per-element counts are added together.
    NOTE(review): presumably each unique `batch` value identifies one
    graph/event, making this the total number of events — confirm.
    """
    return sum(torch.unique(d.batch).numel() for d in data)

def inference(self) -> None:
"""Activate inference mode."""
Expand Down
38 changes: 38 additions & 0 deletions src/graphnet/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,44 @@ def collate_fn(graphs: List[Data]) -> Batch:
return Batch.from_data_list(graphs)


class collator_sequence_buckleting:
    """Perform the sequence bucketing for the graphs in the batch.

    Graphs are ordered by their `n_pulses` attribute and then cut into
    consecutive mini-batches at the fractional positions in `batch_splits`,
    so that graphs with a similar number of pulses end up together.
    """

    def __init__(self, batch_splits: List[float] = [0.8]):
        """Set cutting points of the different mini-batches.

        Args:
            batch_splits: Fractions of the total number of graphs at which
                the sorted list of graphs is cut. The list should not
                explicitly contain the first and last cutting points, which
                are always 0 and 1 respectively.
        """
        # Store a copy: keeping a reference to the argument would alias the
        # shared default list (and any caller-owned list), so a later
        # mutation of `self.batch_splits` would leak into other instances.
        # Copying into a list also lets tuples be passed in.
        self.batch_splits = list(batch_splits)

    def __call__(self, graphs: List["Data"]) -> List["Batch"]:
        """Execute sequence bucketing on the input list of graphs.

        Args:
            graphs: A list of Data objects representing the input graphs.

        Returns:
            A list of Batch objects, each containing a mini-batch of the
            input graphs sorted by their number of pulses. Graphs with
            `n_pulses <= 1` are dropped, and empty mini-batches are omitted,
            so the list may hold fewer than `len(batch_splits) + 1` entries
            (and is empty if no graph survives the filter).
        """
        # Build a new sorted list; the caller's list is left untouched.
        kept = sorted(
            (g for g in graphs if g.n_pulses > 1),
            key=lambda g: g.n_pulses,
        )
        n_graphs = len(kept)

        # Pair consecutive cutting points: (0, s0), (s0, s1), ..., (sk, 1).
        lower_fracs = [0] + self.batch_splits
        upper_fracs = self.batch_splits + [1]

        batch_list = []
        for minp, maxp in zip(lower_fracs, upper_fracs):
            this_graphs = kept[int(minp * n_graphs) : int(maxp * n_graphs)]
            # Skip empty slices; only non-empty buckets become batches.
            if this_graphs:
                batch_list.append(Batch.from_data_list(this_graphs))
        return batch_list


# @TODO: Remove in favour of DataLoader{,.from_dataset_config}
def make_dataloader(
db: str,
Expand Down

0 comments on commit 02902fa

Please sign in to comment.