Commit 6654759

some pre-hooks

Aske-Rosted committed Oct 27, 2023
1 parent 96c1c72 commit 6654759

Showing 6 changed files with 162 additions and 88 deletions.
1 change: 1 addition & 0 deletions src/graphnet/models/components/pool.py
@@ -65,6 +65,7 @@ def sum_pool_and_distribute(
     tensor_unpooled = tensor_pooled[inv]
     return tensor_unpooled
 
+
 @torch.jit.script
 def _group_identical(
     tensor: Tensor, batch: Optional[LongTensor] = None
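Aside (illustration, not part of the commit): _group_identical is used in nodes.py below to give rows with identical ID columns a shared integer group index. A minimal sketch of the assumed behaviour in terms of torch.unique:

import torch

def group_identical_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Identical rows receive the same group index (the "inverse" map).
    _, inverse = torch.unique(tensor, dim=0, return_inverse=True)
    return inverse

rows = torch.tensor([[1, 2], [1, 2], [3, 4]])
print(group_identical_sketch(rows))  # tensor([0, 0, 1])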
23 changes: 10 additions & 13 deletions src/graphnet/models/gnn/rnn_dynedge.py
@@ -20,16 +20,16 @@ def __init__(
         self,
         nb_inputs: int,
         *,
-        #nb_neighbours: int = 8,
+        # nb_neighbours: int = 8,
         RNN_layers: int = 2,
         RNN_hidden_size: int = 64,
         RNN_dropout: float = 0.5,
         features_subset: Optional[Union[List[int], slice]] = None,
         dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
-        #post_processing_layer_sizes: Optional[List[int]] = None,
+        # post_processing_layer_sizes: Optional[List[int]] = None,
         readout_layer_sizes: Optional[List[int]] = None,
         global_pooling_schemes: Optional[Union[str, List[str]]] = None,
-        #add_global_variables_after_pooling: bool = False,
+        # add_global_variables_after_pooling: bool = False,
     ):
         """Initialize the RNN_DynEdge model.
@@ -48,10 +48,10 @@ def __init__(
             global_pooling_schemes (Optional[Union[str, List[str]]], optional): Pooling schemes to use. Defaults to None.
             add_global_variables_after_pooling (bool, optional): Whether to add global variables after pooling. Defaults to False.
         """
-        #self._nb_neighbours = nb_neighbours
+        # self._nb_neighbours = nb_neighbours
         self._nb_inputs = nb_inputs
         self._RNN_layers = RNN_layers
-        self._RNN_hidden_size = RNN_hidden_size # RNN_hidden_size
+        self._RNN_hidden_size = RNN_hidden_size  # RNN_hidden_size
         self._RNN_dropout = RNN_dropout
 
         self._features_subset = features_subset
@@ -83,8 +83,8 @@ def __init__(
         # )
         if readout_layer_sizes is None:
             readout_layer_sizes = [
-            256,
-            128,
+                256,
+                128,
             ]
         self._readout_layer_sizes = readout_layer_sizes
 
@@ -94,7 +94,7 @@ def __init__(
             num_layers=self._RNN_layers,
             nb_inputs=2,
             hidden_size=self._RNN_hidden_size,
-            #nb_neighbours=self._nb_neighbours,
+            # nb_neighbours=self._nb_neighbours,
             RNN_dropout=self._RNN_dropout,
         )
 
@@ -105,7 +105,6 @@ def __init__(
         # dropout=self._RNN_dropout,
         # )
 
-
         # self._dynedge = DynEdge(
         # nb_inputs=self._RNN_hidden_size + 5,
         # nb_neighbours=self._nb_neighbours,
@@ -117,17 +116,15 @@ def __init__(
         # )
         self._dynedge_tito = DynEdgeTITO(
             nb_inputs=self._RNN_hidden_size + 5,
-            #nb_neighbours=self._nb_neighbours,
-            #dyntrans_layer_sizes=self._dynedge_layer_sizes,
+            # nb_neighbours=self._nb_neighbours,
+            # dyntrans_layer_sizes=self._dynedge_layer_sizes,
             features_subset=self._features_subset,
             global_pooling_schemes=self._global_pooling_schemes,
         )
 
-
     def forward(self, data: Data) -> torch.Tensor:
         """Apply learnable forward pass of the RNN and DynEdge models."""
         data = self._rnn(data)
         readout = self._dynedge_tito(data)
 
         return readout
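For orientation (not part of the diff): the model runs an RNN over each DOM's pulse series and passes the result to DynEdgeTITO; the "+ 5" in nb_inputs suggests five DOM-level features are appended to each RNN embedding. A hedged instantiation sketch with example argument values (the import path and chosen values are assumptions):

from graphnet.models.gnn.rnn_dynedge import RNN_DynEdge  # assumed import path

model = RNN_DynEdge(
    nb_inputs=7,  # example value: number of raw pulse features
    RNN_layers=2,
    RNN_hidden_size=64,
    RNN_dropout=0.5,
    global_pooling_schemes=["min", "max", "mean"],  # example schemes
)
# forward(data): data -> self._rnn(data) -> self._dynedge_tito(data) -> readout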
200 changes: 137 additions & 63 deletions src/graphnet/models/graphs/nodes/nodes.py
@@ -12,6 +12,7 @@
 from graphnet.models.components.pool import _group_identical
 from time import time
 
+
 class NodeDefinition(Model):  # pylint: disable=too-few-public-methods
     """Base class for graph building."""
 
@@ -123,18 +124,23 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
         # sort by time
         x = x[x[:, self._time_index].sort().indices]
         # undo log10 scaling since we want to sum up charge
-        x[:,self._charge_index] = torch.pow(10,x[:,self._charge_index])
+        x[:, self._charge_index] = torch.pow(10, x[:, self._charge_index])
         # shift time to positive values with a small offset
-        x[:,self._time_index] += -min(x[:,self._time_index])
+        x[:, self._time_index] += -min(x[:, self._time_index])
         # Group pulses on the same DOM
         dom_index = _group_identical(x[:, self._id_columns])
 
         val, ind = dom_index.sort(stable=True)
-        counts = torch.concat([torch.tensor([0]),val.bincount().cumsum(-1)[:-1]])
-        unique_doms = x[:, self._id_columns+[self._time_index]][ind][counts]
+        counts = torch.concat(
+            [torch.tensor([0]), val.bincount().cumsum(-1)[:-1]]
+        )
+        unique_doms = x[:, self._id_columns + [self._time_index]][ind][counts]
 
         time_series = [
-            x[dom_index == index_key][:,[self._charge_index,self._time_index]] for index_key in dom_index.unique()
+            x[dom_index == index_key][
+                :, [self._charge_index, self._time_index]
+            ]
+            for index_key in dom_index.unique()
         ]
         # add total charge to unique dom features and apply log10 scaling
         charge = torch.stack(
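Aside on the counts construction above (worked example with assumed values): after a stable sort of dom_index, bincount().cumsum(-1) gives the offset at which each DOM's block of pulses starts, so indexing with counts selects each DOM's first pulse (the earliest one, since x was sorted by time) without a Python loop:

import torch

dom_index = torch.tensor([1, 0, 1, 0, 2])
val, ind = dom_index.sort(stable=True)  # val = [0, 0, 1, 1, 2]
counts = torch.concat([torch.tensor([0]), val.bincount().cumsum(-1)[:-1]])
print(counts)       # tensor([0, 2, 4]) -- start offset of each DOM's block
print(ind[counts])  # tensor([1, 0, 4]) -- row of each DOM's earliest pulse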
@@ -150,77 +156,139 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
 
 
 @torch.jit.script
-def log10powsum(tensor:torch.Tensor) -> torch.Tensor:
-    return torch.log10(torch.sum(torch.pow(10,tensor)))
+def log10powsum(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.log10(torch.sum(torch.pow(10, tensor)))
 
+
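Worked example (not part of the diff): charge is stored in log10 scale, so summation has to happen in linear scale; two pulses of 1 PE each (log10 charge 0.0) combine to log10(2) ≈ 0.301:

import torch

charges_log10 = torch.tensor([0.0, 0.0])  # two pulses of 1 PE each
print(torch.log10(torch.sum(torch.pow(10, charges_log10))))  # tensor(0.3010)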
 @torch.jit.script
-def pad_charge(tensor:torch.Tensor, time_range:torch.Tensor) -> torch.Tensor:
+def pad_charge(tensor: torch.Tensor, time_range: torch.Tensor) -> torch.Tensor:
     """Pad charge tensor to have same length as time range.
 
     Args:
         tensor: tensor of shape (num_pulses,2) with time and charge.
 
     Returns:
-        tensor: padded charge tensor of shape (len(time_range),1)."""
-    padded = torch.ones(len(time_range))*torch.tensor(-16.)
+        tensor: padded charge tensor of shape (len(time_range),1).
+    """
+    padded = torch.ones(len(time_range)) * torch.tensor(-16.0)
     for val in time_range:
-        if (tensor[:,0] == val).any():
-            padded[time_range == val] = tensor[tensor[:,0] == val,1]
+        if (tensor[:, 0] == val).any():
+            padded[time_range == val] = tensor[tensor[:, 0] == val, 1]
 
     return padded
 
 
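Behaviour sketch for pad_charge with assumed values: bins in time_range that receive no pulse keep the fill value -16.0 (a stand-in for the log10 charge of an empty bin):

import torch

time_range = torch.tensor([1.0, 2.0, 4.0])
pulses = torch.tensor([[2.0, 0.5]])  # one pulse: time 2.0, charge 0.5
# pad_charge(pulses, time_range) returns tensor([-16.0, 0.5, -16.0]):
# only a bin whose value exactly matches a pulse time is filled.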
 @torch.jit.script
-def return_closest(tt:torch.Tensor, tr:torch.Tensor) -> torch.Tensor:
-    return torch.argmin(abs(tt - tr.unsqueeze(1)),dim=0)
+def return_closest(tt: torch.Tensor, tr: torch.Tensor) -> torch.Tensor:
+    return torch.argmin(abs(tt - tr.unsqueeze(1)), dim=0)
 
 
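For each entry of tt, return_closest gives the index of the nearest value in tr via a broadcasted difference matrix; a quick check with assumed values:

import torch

tt = torch.tensor([0.9, 3.2])        # pulse times
tr = torch.tensor([1.0, 2.0, 4.0])   # bin centres
# abs(tt - tr.unsqueeze(1)) has shape (3, 2); argmin over dim 0 picks
# the closest bin for each pulse.
print(torch.argmin(abs(tt - tr.unsqueeze(1)), dim=0))  # tensor([0, 2])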
 @torch.jit.script
-def sum_charge(x:torch.Tensor,dom_index:torch.Tensor, id_columns:list[int], time_index:int, charge_index:int) -> torch.Tensor:
+def sum_charge(
+    x: torch.Tensor,
+    dom_index: torch.Tensor,
+    id_columns: list[int],
+    time_index: int,
+    charge_index: int,
+) -> torch.Tensor:
     """Sum charge of pulses in the same time bin.
 
     Args:
         tensor: tensor of shape (num_pulses,2) with time and charge.
 
     Returns:
-        tensor: padded charge tensor of shape (len(time_range),1)."""
-    x = torch.stack([torch.hstack([
-        x[dom_index == index][:,id_columns+[time_index]][0],log10powsum(x[dom_index == index][:,charge_index])]) for index in torch.unique(dom_index)])
+        tensor: padded charge tensor of shape (len(time_range),1).
+    """
+    x = torch.stack(
+        [
+            torch.hstack(
+                [
+                    x[dom_index == index][:, id_columns + [time_index]][0],
+                    log10powsum(x[dom_index == index][:, charge_index]),
+                ]
+            )
+            for index in torch.unique(dom_index)
+        ]
+    )
     return x
 
 
 @torch.jit.script
-def create_time_series(x:torch.Tensor,dom_index:torch.Tensor, id_columns:list[int], time_index:int, charge_index:int, time_range:torch.Tensor) -> list[torch.Tensor]:
-    """Create time series
+def create_time_series(
+    x: torch.Tensor,
+    dom_index: torch.Tensor,
+    id_columns: list[int],
+    time_index: int,
+    charge_index: int,
+    time_range: torch.Tensor,
+) -> list[torch.Tensor]:
+    """Create time series.
 
     Args:
         tensor: padded charge tensor of shape (len(time_range),1).
         dom_index: indexing of doms for grouping.
         id_columns: list of columns that uniquely identify a DOM.
         time_index: index of time column.
         charge_index: index of charge column.
         time_range: time range to be used for padding.
 
     Returns:
-        tensor: time series of shape (num_pulses,1)."""
+        tensor: time series of shape (num_pulses,1).
+    """
     time_series, unique_doms = [], []
     for index_key in torch.unique(dom_index):
-        unique_doms.append(torch.hstack([x[dom_index == index_key][0][id_columns],x[dom_index == index_key][0][time_index],x[dom_index == index_key][:,charge_index].max()]))
-        time_series.append(pad_charge(x[dom_index == index_key][:,[time_index]+[charge_index]],time_range))
+        unique_doms.append(
+            torch.hstack(
+                [
+                    x[dom_index == index_key][0][id_columns],
+                    x[dom_index == index_key][0][time_index],
+                    x[dom_index == index_key][:, charge_index].max(),
+                ]
+            )
+        )
+        time_series.append(
+            pad_charge(
+                x[dom_index == index_key][:, [time_index] + [charge_index]],
+                time_range,
+            )
+        )
     unique_doms = torch.vstack(unique_doms)
     time_series = torch.vstack(time_series)
     return [unique_doms, time_series]
 
 
 @torch.jit.script
-def get_unique_dom_features(x:torch.Tensor,dom_index:torch.Tensor, id_columns:list[int], time_index:int, charge_index:int, time_range:torch.Tensor) -> torch.Tensor:
-    unique_doms = []
-    for index_key in torch.unique(dom_index):
-        unique_doms.append(torch.hstack([x[dom_index == index_key][0][id_columns],x[dom_index == index_key][0][time_index],x[dom_index == index_key][:,charge_index].max()]))
-    unique_doms = torch.vstack(unique_doms)
-    return unique_doms
+def get_unique_dom_features(
+    x: torch.Tensor,
+    dom_index: torch.Tensor,
+    id_columns: list[int],
+    time_index: int,
+    charge_index: int,
+    time_range: torch.Tensor,
+) -> torch.Tensor:
+    unique_doms = []
+    for index_key in torch.unique(dom_index):
+        unique_doms.append(
+            torch.hstack(
+                [
+                    x[dom_index == index_key][0][id_columns],
+                    x[dom_index == index_key][0][time_index],
+                    x[dom_index == index_key][:, charge_index].max(),
+                ]
+            )
+        )
+    unique_doms = torch.vstack(unique_doms)
+    return unique_doms
 
 
-def create_sparse_charge_series(dom_index:torch.Tensor,time_index:torch.Tensor,values:torch.Tensor) -> torch.Tensor:
-    i = torch.vstack([dom_index,time_index])
+def create_sparse_charge_series(
+    dom_index: torch.Tensor, time_index: torch.Tensor, values: torch.Tensor
+) -> torch.Tensor:
+    i = torch.vstack([dom_index, time_index])
     v = values
-    s = torch.sparse_coo_tensor(i, v, (dom_index.max()+1, time_index.max()+1))
+    s = torch.sparse_coo_tensor(
+        i, v, (dom_index.max() + 1, time_index.max() + 1)
+    )
     s = s.coalesce()
     return s
 
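The point of routing the charges through a sparse COO tensor: coalesce() sums values that share a coordinate, so charge from pulses landing in the same (DOM, time-bin) cell is accumulated without an explicit loop. Illustration with assumed values:

import torch

dom = torch.tensor([0, 0, 1])
time_bin = torch.tensor([2, 2, 0])
charge = torch.tensor([0.5, 0.25, 1.0])
s = torch.sparse_coo_tensor(torch.vstack([dom, time_bin]), charge, (2, 3))
print(s.coalesce().to_dense())  # duplicate (0, 2) entries summed to 0.75
# tensor([[0.0000, 0.0000, 0.7500],
#         [1.0000, 0.0000, 0.0000]])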
@@ -237,7 +305,7 @@ def create_sparse_charge_series(
 # sorted_time_series = [time_series[i] for i in sort_index]
 # return sorted_time_series, sort_index
 
-
+
 class NodesAsDOMTimeWindow(NodeDefinition):
     """Represent each node as DOM a with a time series of pulses."""
 
@@ -272,41 +340,56 @@ def __init__(
         self._granularity = granularity
 
         super().__init__()
 
-
     def _construct_nodes(self, x: torch.Tensor) -> Data:
         """Construct nodes from raw node features ´x´."""
         # sort by time
         x = x[x[:, self._time_index].sort().indices]
         # undo log10 scaling since we want to sum up charge
-        x[:,self._charge_index] = torch.pow(10,x[:,self._charge_index])
+        x[:, self._charge_index] = torch.pow(10, x[:, self._charge_index])
         # shift time to positive values with a small offset
-        x[:,self._time_index] += (0.1-min(x[:,self._time_index]))
+        x[:, self._time_index] += 0.1 - min(x[:, self._time_index])
 
         # create time range
-        self._time_range = torch.logspace(torch.log2(min(x[:,self._time_index])),(torch.log2(max(x[:,self._time_index]))),self._granularity,base=2.)
+        self._time_range = torch.logspace(
+            torch.log2(min(x[:, self._time_index])),
+            (torch.log2(max(x[:, self._time_index]))),
+            self._granularity,
+            base=2.0,
+        )
 
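Aside on the time range above (illustration with assumed values): the bins are logarithmically spaced in base 2 between the earliest and latest shifted pulse time, so early times are binned more finely than late ones:

import math
import torch

t_min, t_max, granularity = 0.1, 102.4, 11
bins = torch.logspace(math.log2(t_min), math.log2(t_max), granularity, base=2.0)
print(bins)  # tensor([0.1000, 0.2000, 0.4000, ..., 102.4000]) -- doubles each step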
-        # group pulses on the same DOM 
+        # group pulses on the same DOM
         dom_index = _group_identical(x[:, self._id_columns])
 
         # get unique dom features
         # unique_doms = get_unique_dom_features(x,dom_index,self._id_columns,self._time_index,self._charge_index,self._time_range)
         val, ind = dom_index.sort(stable=True)
-        counts = torch.concat([torch.tensor([0]),val.bincount().cumsum(-1)[:-1]])
-        unique_doms = x[:, self._id_columns+[self._time_index]][ind][counts]
+        counts = torch.concat(
+            [torch.tensor([0]), val.bincount().cumsum(-1)[:-1]]
+        )
+        unique_doms = x[:, self._id_columns + [self._time_index]][ind][counts]
 
         # get coarse time index
-        coarse_time_index = return_closest(x[:,self._time_index],self._time_range)
+        coarse_time_index = return_closest(
+            x[:, self._time_index], self._time_range
+        )
 
         # Create torch sparse tensor summing up charge in the same time bin
-        time_series = create_sparse_charge_series(dom_index,coarse_time_index,x[:,self._charge_index])
+        time_series = create_sparse_charge_series(
+            dom_index, coarse_time_index, x[:, self._charge_index]
+        )
 
         # add total charge to unique dom features
-        unique_doms = torch.hstack([unique_doms,torch._sparse_sum(time_series,dim=1).to_dense().unsqueeze(1)])
+        unique_doms = torch.hstack(
+            [
+                unique_doms,
+                torch._sparse_sum(time_series, dim=1).to_dense().unsqueeze(1),
+            ]
+        )
         # apply inverse hyperbolic sine to charge values (handles zeros unlike log scaling)
-        unique_doms[:,-1] = torch.asinh(5*unique_doms[:,-1])/5
-        time_series = torch.asinh(5*time_series)/5
-
+        unique_doms[:, -1] = torch.asinh(5 * unique_doms[:, -1]) / 5
+        time_series = torch.asinh(5 * time_series) / 5
         # convert to dense tensor
 
         time_series = time_series.to_dense()
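On the asinh transform above: asinh grows roughly logarithmically for large charges but is linear near zero, so the empty bins of the dense time series map to 0 rather than the -inf a log10 scaling would produce. A quick comparison:

import torch

charge = torch.tensor([0.0, 0.1, 10.0])
print(torch.asinh(5 * charge) / 5)  # tensor([0.0000, 0.0962, 0.9211])
print(torch.log10(charge))          # tensor([-inf, -1.0000, 1.0000])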
@@ -321,13 +404,4 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
 
         # unique_doms, time_series = create_time_series(x,dom_index,self._id_columns,self._time_index,self._charge_index,self._time_range)
 
-
-        return Data(x=unique_doms, time_series = time_series)
-
-
-
-
-
-
-
+        return Data(x=unique_doms, time_series=time_series)