Merge branch 'main' into automatic_ensemble
merge from main
Aske-Rosted committed Mar 21, 2024
2 parents 9cf8ad3 + 205b20e commit a4f2ef2
Showing 6 changed files with 52 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/graphnet/data/utilities/sqlite_utilities.py
@@ -46,6 +46,8 @@ def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]:
query = 'SELECT name FROM sqlite_master WHERE type == "table"'
table_names = [table[0] for table in conn.execute(query).fetchall()]

assert len(table_names) > 0, "No tables found in database."

integer_primary_key = {}
for table in table_names:
query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;"
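
The new assertion turns a silently empty table list into an immediate, descriptive failure. A minimal sketch of why it fires, using an in-memory database with no tables:

import sqlite3

# An empty database has no rows in sqlite_master, so the table list is empty.
conn = sqlite3.connect(":memory:")
query = 'SELECT name FROM sqlite_master WHERE type == "table"'
table_names = [table[0] for table in conn.execute(query).fetchall()]
assert len(table_names) > 0, "No tables found in database."  # raises AssertionError here
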
9 changes: 8 additions & 1 deletion src/graphnet/data/writers/sqlite_writer.py
@@ -160,7 +160,14 @@ def _merge_databases(
for file_count, input_file in tqdm(enumerate(files), colour="green"):

# Extract table names and index column name in database
tables, primary_key = get_primary_keys(database=input_file)
try:
tables, primary_key = get_primary_keys(database=input_file)
except AssertionError as e:
if "No tables found in database." in str(e):
self.warning(f"Database {input_file} is empty. Skipping.")
continue
else:
raise e

for table_name in tables.keys():
# Extract all data in the table from the given database
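
With that assertion in place, the merge loop can tell empty shards apart from real failures: the specific message is caught and the file is skipped, while any other `AssertionError` is re-raised. Sketched standalone (the file names are hypothetical):

from graphnet.data.utilities.sqlite_utilities import get_primary_keys

for input_file in ["shard_0.db", "shard_1.db"]:  # hypothetical paths
    try:
        tables, primary_key = get_primary_keys(database=input_file)
    except AssertionError as e:
        if "No tables found in database." in str(e):
            print(f"Database {input_file} is empty. Skipping.")  # stands in for self.warning(...)
            continue
        raise  # any other assertion is a genuine error
    # ... merge the tables from this file ...
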
8 changes: 6 additions & 2 deletions src/graphnet/deployment/deployment_module.py
@@ -3,7 +3,7 @@
from typing import Any, List, Union, Dict

import numpy as np
from torch import Tensor
from torch import Tensor, load
from torch_geometric.data import Data, Batch

from graphnet.models import Model
@@ -61,7 +61,11 @@ def _load_model(
) -> Model:
"""Load `Model` from config and insert learned weights."""
model = Model.from_config(model_config, trust=True)
model.load_state_dict(state_dict)
if isinstance(state_dict, str) and state_dict.endswith(".ckpt"):
ckpt = load(state_dict)
model.load_state_dict(ckpt["state_dict"])
else:
model.load_state_dict(state_dict)
return model

def _resolve_prediction_columns(
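
The change lets `_load_model` accept either a raw state dict or a path to a PyTorch Lightning checkpoint, which wraps the weights under a "state_dict" key alongside optimizer state and hyperparameters. A minimal sketch of the distinction (paths hypothetical):

from torch import load

ckpt = load("model.ckpt")      # Lightning checkpoint: a wrapper dict
weights = ckpt["state_dict"]   # the tensors model.load_state_dict() expects

# A bare state dict saved with torch.save(model.state_dict(), ...) needs no unwrapping:
# weights = load("weights.pth")
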
2 changes: 2 additions & 0 deletions src/graphnet/models/graphs/nodes/nodes.py
@@ -269,6 +269,8 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
"""Construct nodes from raw node features ´x´."""
# Cast to Numpy
x = x.numpy()
if x.shape[0] == 0:
return Data(x=torch.tensor(np.column_stack([x, []])))
# if there is no charge column add a dummy column of zeros with the same shape as the time column
if self._charge_index is None:
charge_index: int = len(self._keys)
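
The early return guards against events whose selection leaves zero pulses: `np.column_stack` still appends the dummy column to an empty array, so downstream code sees a consistent feature count. A small shape sketch:

import numpy as np
import torch

x = np.empty((0, 4))            # zero pulses, four raw features
out = np.column_stack([x, []])  # the empty dummy column still stacks
print(out.shape)                # (0, 5): column count stays consistent
x_tensor = torch.tensor(out)    # what the early return wraps in Data(x=...)
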
39 changes: 28 additions & 11 deletions src/graphnet/models/standard_model.py
@@ -413,6 +413,9 @@ def predict_as_dataframe(
f"number of output columns ({predictions.shape[1]}) don't match."
)

# Check if predictions are on event- or pulse-level
pulse_level_predictions = len(predictions) > len(dataloader.dataset)

# Get additional attributes
attributes: Dict[str, List[np.ndarray]] = OrderedDict(
[(attr, []) for attr in additional_attributes]
@@ -426,25 +429,39 @@
# Check if node level predictions
# If true, additional attributes are repeated
# to make dimensions fit
if len(predictions) != len(dataloader.dataset):
if pulse_level_predictions:
if len(attribute) < np.sum(
batch.n_pulses.detach().cpu().numpy()
):
attribute = np.repeat(
attribute, batch.n_pulses.detach().cpu().numpy()
)
try:
assert len(attribute) == len(batch.x)
except AssertionError:
self.warning_once(
"Could not automatically adjust length"
f"of additional attribute {attr} to match length of"
f"predictions. Make sure {attr} is a graph-level or"
"node-level attribute. Attribute skipped."
)
pass
attributes[attr].extend(attribute)

# Confirm that attributes match length of predictions
skip_attributes = []
for attr in attributes.keys():
try:
assert len(attributes[attr]) == len(predictions)
except AssertionError:
self.warning_once(
"Could not automatically adjust length"
f" of additional attribute '{attr}' to match length of"
f" predictions.This error can be caused by heavy"
" disagreement between number of examples in the"
" dataset vs. actual events in the dataloader, e.g. "
" heavy filtering of events in `collate_fn` passed to"
" `dataloader`. This can also be caused by requesting"
" pulse-level attributes for `Task`s that produce"
" event-level predictions. Attribute skipped."
)
skip_attributes.append(attr)

# Remove bad attributes
for attr in skip_attributes:
attributes.pop(attr)
additional_attributes.remove(attr)

data = np.concatenate(
[predictions]
+ [
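
The refactor replaces the per-attribute try/except with a single up-front check (`pulse_level_predictions`) plus a validation pass that drops any attribute whose length still disagrees with the predictions. The broadcasting step itself is plain `np.repeat`; a small numeric sketch (values made up):

import numpy as np

energy = np.array([10.0, 42.0, 7.0])  # one value per event
n_pulses = np.array([2, 1, 3])        # pulses per event in the batch

# Tile each event-level value once per pulse so its length matches
# the pulse-level predictions (2 + 1 + 3 = 6 rows):
per_pulse = np.repeat(energy, n_pulses)
print(per_pulse)  # [10. 10. 42.  7.  7.  7.]
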
6 changes: 6 additions & 0 deletions src/graphnet/training/weight_fitting.py
@@ -54,6 +54,7 @@ def fit(
add_to_database: bool = False,
selection: Optional[List[int]] = None,
transform: Optional[Callable] = None,
db_count_norm: Optional[int] = None,
**kwargs: Any,
) -> pd.DataFrame:
"""Fit weights.
@@ -73,6 +74,7 @@
transform: A callable method that transform the variable into a
desired space. E.g. np.log10 for energy. If given, fitting will
happen in this space.
db_count_norm: If given, the total sum of the weights for the given db will be this number.
**kwargs: Additional arguments passed to `_fit_weights`.
Returns:
@@ -94,6 +96,10 @@
if self._transform is not None:
truth[self._variable] = self._transform(truth[self._variable])
weights = self._fit_weights(truth, **kwargs)
if db_count_norm is not None:
weights[self._weight_name] = (
weights[self._weight_name] * db_count_norm / len(weights)
)

if add_to_database:
create_table_and_save_to_sql(
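
`db_count_norm` rescales every fitted weight by `db_count_norm / len(weights)`. A small numeric sketch (made-up weights; because they average to one, the rescaled weights sum exactly to `db_count_norm`):

import numpy as np

weights = np.array([0.5, 1.0, 1.5, 1.0])  # fitted weights, mean 1.0
db_count_norm = 100

rescaled = weights * db_count_norm / len(weights)
print(rescaled.sum())  # 100.0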
