Merge pull request #624 from AMHermansen/model-refactor
Refactored model and standardmodel
AMHermansen authored Nov 6, 2023
2 parents cddc567 + 39879ef commit baf5e37
Showing 2 changed files with 220 additions and 268 deletions.
260 changes: 5 additions & 255 deletions src/graphnet/models/model.py
@@ -1,20 +1,12 @@
"""Base class(es) for building models."""

from abc import ABC, abstractmethod
from collections import OrderedDict
from abc import ABC
import dill
import os.path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers.logger import Logger as LightningLogger
from pytorch_lightning import LightningModule
import torch
from torch import Tensor
from torch.utils.data import DataLoader, SequentialSampler
from torch_geometric.data import Data

from graphnet.utilities.logging import Logger
@@ -23,258 +15,16 @@
    ModelConfig,
    ModelConfigSaverABC,
)
from graphnet.training.callbacks import ProgressBar


class Model(
    Logger, Configurable, LightningModule, ABC, metaclass=ModelConfigSaverABC
):
    """Base class for all models in graphnet."""

    @abstractmethod
    def forward(self, x: Union[Tensor, Data]) -> Union[Tensor, Data]:
        """Forward pass."""
    """Base class for all components in graphnet."""

    @staticmethod
    def _construct_trainer(
        max_epochs: int = 10,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        ckpt_path: Optional[str] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> Trainer:

        if gpus:
            accelerator = "gpu"
            devices = gpus
        else:
            accelerator = "cpu"
            devices = 1

        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            callbacks=callbacks,
            log_every_n_steps=log_every_n_steps,
            logger=logger,
            gradient_clip_val=gradient_clip_val,
            strategy=distribution_strategy,
            default_root_dir=ckpt_path,
            **trainer_kwargs,
        )

        return trainer
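# A hedged sketch of the Trainer that a call like _construct_trainer(gpus=[0, 1])
# would build, with the remaining arguments left at the defaults shown above.
# The concrete values are illustrative and not part of the diff.
from pytorch_lightning import Trainer  # mirrors the import at the top of the file

example_trainer = Trainer(
    accelerator="gpu",      # any truthy `gpus` selects the GPU accelerator
    devices=[0, 1],         # the requested GPU indices
    max_epochs=10,
    log_every_n_steps=1,
    gradient_clip_val=None,
    strategy="ddp",
    default_root_dir=None,  # `ckpt_path` doubles as the Trainer's root dir
)
# With gpus=None (or an empty list) the helper instead falls back to
# accelerator="cpu", devices=1.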

    def fit(
        self,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        *,
        max_epochs: int = 10,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        ckpt_path: Optional[str] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> None:
        """Fit `Model` using `pytorch_lightning.Trainer`."""
        # Checks
        if callbacks is None:
            callbacks = self._create_default_callbacks(
                val_dataloader=val_dataloader,
            )
        elif val_dataloader is not None:
            callbacks = self._add_early_stopping(
                val_dataloader=val_dataloader, callbacks=callbacks
            )

        self.train(mode=True)
        trainer = self._construct_trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            callbacks=callbacks,
            ckpt_path=ckpt_path,
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            gradient_clip_val=gradient_clip_val,
            distribution_strategy=distribution_strategy,
            **trainer_kwargs,
        )

        try:
            trainer.fit(
                self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
            )
        except KeyboardInterrupt:
            self.warning("[ctrl+c] Exiting gracefully.")
            pass
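# A minimal usage sketch of fit(); `my_model`, `train_loader`, and `val_loader`
# are hypothetical stand-ins for a concrete Model subclass and its dataloaders.
my_model.fit(
    train_loader,
    val_loader,
    max_epochs=20,
    gpus=[0],               # forwarded to _construct_trainer
    gradient_clip_val=1.0,  # optional; passed straight through to the Trainer
)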

    def _create_default_callbacks(self, val_dataloader: DataLoader) -> List:
        callbacks = [ProgressBar()]
        callbacks = self._add_early_stopping(
            val_dataloader=val_dataloader, callbacks=callbacks
        )
        return callbacks

    def _add_early_stopping(
        self, val_dataloader: DataLoader, callbacks: List
    ) -> List:
        if val_dataloader is None:
            return callbacks
        has_early_stopping = False
        assert isinstance(callbacks, list)
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                has_early_stopping = True

        if not has_early_stopping:
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=5,
                )
            )
            self.warning_once(
                "Got validation dataloader but no EarlyStopping callback. An "
                "EarlyStopping callback has been added automatically with "
                "patience=5 and monitor = 'val_loss'."
            )
        return callbacks
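# Supplying an EarlyStopping callback of your own keeps fit() from appending
# the default patience=5 / monitor="val_loss" one (a sketch; `my_model` and the
# dataloaders are hypothetical, as above; ProgressBar and EarlyStopping are the
# imports shown at the top of the file).
my_model.fit(
    train_loader,
    val_loader,
    callbacks=[ProgressBar(), EarlyStopping(monitor="val_loss", patience=10)],
)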

    def predict(
        self,
        dataloader: DataLoader,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
    ) -> List[Tensor]:
        """Return predictions for `dataloader`.

        Returns a list of Tensors, one for each model output.
        """
        self.train(mode=False)

        callbacks = self._create_default_callbacks(
            val_dataloader=None,
        )

        inference_trainer = self._construct_trainer(
            gpus=gpus,
            distribution_strategy=distribution_strategy,
            callbacks=callbacks,
        )

        predictions_list = inference_trainer.predict(self, dataloader)
        assert len(predictions_list), "Got no predictions"

        nb_outputs = len(predictions_list[0])
        predictions: List[Tensor] = [
            torch.cat([preds[ix] for preds in predictions_list], dim=0)
            for ix in range(nb_outputs)
        ]

        return predictions
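# A sketch of predict(): the result is one Tensor per model output, each
# concatenated over all batches of a (hypothetical) `test_loader`.
preds = my_model.predict(dataloader=test_loader, gpus=[0])
print([p.shape for p in preds])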

    def predict_as_dataframe(
        self,
        dataloader: DataLoader,
        prediction_columns: List[str],
        *,
        additional_attributes: Optional[List[str]] = None,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
    ) -> pd.DataFrame:
        """Return predictions for `dataloader` as a DataFrame.

        Include `additional_attributes` as additional columns in the output
        DataFrame.
        """
        # Check(s)
        if additional_attributes is None:
            additional_attributes = []
        assert isinstance(additional_attributes, list)

        if (
            not isinstance(dataloader.sampler, SequentialSampler)
            and additional_attributes
        ):
            print(dataloader.sampler)
            raise UserWarning(
                "DataLoader has a `sampler` that is not `SequentialSampler`, "
                "indicating that shuffling is enabled. Using "
                "`predict_as_dataframe` with `additional_attributes` assumes "
                "that the sequence of batches in `dataloader` are "
                "deterministic. Either call this method with a `dataloader` "
                "which doesn't resample batches; or do not request "
                "`additional_attributes`."
            )
        self.info(f"Column names for predictions are: \n {prediction_columns}")
        predictions_torch = self.predict(
            dataloader=dataloader,
            gpus=gpus,
            distribution_strategy=distribution_strategy,
        )
        predictions = (
            torch.cat(predictions_torch, dim=1).detach().cpu().numpy()
        )
        assert len(prediction_columns) == predictions.shape[1], (
            f"Number of provided column names ({len(prediction_columns)}) and "
            f"number of output columns ({predictions.shape[1]}) don't match."
        )

        # Get additional attributes
        attributes: Dict[str, List[np.ndarray]] = OrderedDict(
            [(attr, []) for attr in additional_attributes]
        )

        for batch in dataloader:
            for attr in attributes:
                attribute = batch[attr]
                if isinstance(attribute, torch.Tensor):
                    attribute = attribute.detach().cpu().numpy()

                # Check if node level predictions
                # If true, additional attributes are repeated
                # to make dimensions fit
                if len(predictions) != len(dataloader.dataset):
                    if len(attribute) < np.sum(
                        batch.n_pulses.detach().cpu().numpy()
                    ):
                        attribute = np.repeat(
                            attribute, batch.n_pulses.detach().cpu().numpy()
                        )
                    try:
                        assert len(attribute) == len(batch.x)
                    except AssertionError:
                        self.warning_once(
                            "Could not automatically adjust length "
                            f"of additional attribute {attr} to match length of "
                            f"predictions. Make sure {attr} is a graph-level or "
                            "node-level attribute. Attribute skipped."
                        )
                        pass
                attributes[attr].extend(attribute)

        data = np.concatenate(
            [predictions]
            + [
                np.asarray(values)[:, np.newaxis]
                for values in attributes.values()
            ],
            axis=1,
        )

        results = pd.DataFrame(
            data, columns=prediction_columns + additional_attributes
        )
        return results
def _get_batch_size(data: List[Data]) -> int:
    return sum([torch.numel(torch.unique(d.batch)) for d in data])
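# A sketch of predict_as_dataframe(); the column name and the `event_no`
# attribute are hypothetical examples. Requesting extra attributes assumes the
# dataloader yields batches in a fixed order, which is what the sampler check
# above enforces.
df = my_model.predict_as_dataframe(
    test_loader,
    prediction_columns=["energy_pred"],
    additional_attributes=["event_no"],
    gpus=[0],
)
df.to_csv("predictions.csv", index=False)  # hypothetical output path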

    def save(self, path: str) -> None:
        """Save entire model to `path`."""
