Graph Definition #558
I have addressed the initial comments from @AMHermansen, refactored the last few detectors and introduced GraphDefinition in the deployment code. In relation to this, I've created #560 which contains a list of to-do items for the deployment modules which fall outside the scope of this PR. What's left is to update |
features: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t]
graph_definition:
  arguments:
    columns: [0, 1, 2]
I think the naming of this "columns" argument (which I understand as the columns included in the "distance" calculation of your edge definition) becomes very vague. At first glance it is hard to tell what this feature does. Something like "edge_defining_columns" would be more descriptive, I think.
I agree that the name of that argument is not unambiguous, but I also fail to find better alternatives that are not very long. The name of this argument is the same as what we used to call it in the KNNGraphBuilder (see here) and I do think that the doc string for this argument (see here) is pretty clear. So even though it's a bit challenging to understand what that argument does when one reads the config file, I think we should keep it as-is and refer users to the docs instead (which is the intended usage anyway). Is that OK with you? @Aske-Rosted
Completely fine.
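To make the role of the "columns" argument concrete, here is a minimal pure-Python sketch of a k-nearest-neighbour edge construction that measures distance only on the selected feature columns. The function name knn_edges and the toy node values are hypothetical and are not taken from the PR; the actual implementation lives in the edge definition and may differ.

```python
import math


def knn_edges(nodes, columns, k):
    """Link every node to its k nearest neighbours.

    Distances use only the feature indices listed in `columns`.
    Returns (source, target) pairs, where source is a neighbour of target.
    NOTE: hypothetical helper for illustration, not the graphnet API.
    """
    edges = []
    for i, a in enumerate(nodes):
        candidates = []
        for j, b in enumerate(nodes):
            if i == j:
                continue
            distance = math.dist([a[c] for c in columns], [b[c] for c in columns])
            candidates.append((distance, j))
        candidates.sort()
        edges.extend((j, i) for _, j in candidates[:k])
    return edges


# Features per node: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] (toy values)
nodes = [
    [0.0, 0.0, 0.0, 100.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 5.0, 0.0, 100.0],
]

# columns=[0, 1, 2]: neighbours are chosen from spatial distance only.
spatial = knn_edges(nodes, columns=[0, 1, 2], k=1)

# Including column 3 (time) changes which node is closest to node 0.
with_time = knn_edges(nodes, columns=[0, 1, 2, 3], k=1)
```

With columns=[0, 1, 2], node 0 is linked to node 1 (its spatial neighbour); once time is included, node 2 becomes its nearest neighbour instead.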
torch.multiprocessing.set_sharing_strategy("file_system")

del has_torch_package
Is there a reason for deleting has_torch_package here but keeping it in src/graphnet/data/dataset/sqlite/__init__.py?
I actually don't know. I did not add this line; the difference you point out here is also present in the main branch currently. I have added the del statement to the sqlite part now too.
@@ -133,7 +133,7 @@ def _construct_model(
     fn_kwargs={"trust": trust},
 )

-# Construct model based on arguments
+# Construct model based on
was this change intentional?
Nopes. Fixed.
I have just finished looking through everything, and I think it looks great! I also believe that this will add the necessary flexibility for a lot of the things I've been trying to implement, where I previously had to go make changes in the dataset class, e.g. #521. I am looking forward to trying it out.
Hello @RasmusOrsoe I've noticed some more minor things.
) -> Data:
    for idx, feature in enumerate(node_feature_names):
        try:
            node_features[:, idx] = self.feature_map()[feature](  # type: ignore
I'm not sure if you're supposed to call self.feature_map, since it is classified as a property?
Hey. Thanks for pointing this out. In fact, the property decorator in detector.py is poorly placed - it is actually completely redundant. I've removed it.
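For context on why the reviewer's question matters in general: calling an attribute that really is a property means calling whatever the property returns. The sketch below uses two hypothetical toy classes (not the classes in detector.py) to show the difference, assuming the feature map is a dict of callables.

```python
class DetectorWithProperty:
    """Hypothetical class: feature_map is exposed as a property."""

    @property
    def feature_map(self):
        return {"t": lambda x: x / 10.0}


class DetectorWithMethod:
    """Hypothetical class: feature_map is a plain method."""

    def feature_map(self):
        return {"t": lambda x: x / 10.0}


# A plain method must be called before indexing into the returned dict.
scaled = DetectorWithMethod().feature_map()["t"](50.0)

# With a property, `.feature_map` already is the dict, so adding `()`
# tries to call the dict itself and raises TypeError.
try:
    DetectorWithProperty().feature_map()["t"](50.0)
    property_call_failed = False
except TypeError:
    property_call_failed = True

# The property must be accessed without parentheses.
scaled_via_property = DetectorWithProperty().feature_map["t"](50.0)
```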
# Assume all features in Detector is used.
node_feature_names = list(self._detector.feature_map().keys())  # type: ignore
self._node_feature_names = node_feature_names
if dtype is None:
If I understand this code correctly, you're just getting the same behavior as if you changed the default value of dtype in the constructor to torch.float instead of None, but in a slightly more complicated way :)
You're absolutely right. Fixed :-)
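The simplification agreed on in this thread can be illustrated with a small sketch. The class names are hypothetical, and the string "float32" is only a stand-in for torch.float so that the example runs without PyTorch.

```python
from typing import Optional

DEFAULT_DTYPE = "float32"  # stand-in for torch.float (assumption for this sketch)


class WithNoneDefault:
    """The pattern the reviewer questioned: None default, resolved in the body."""

    def __init__(self, dtype: Optional[str] = None):
        if dtype is None:
            dtype = DEFAULT_DTYPE
        self.dtype = dtype


class WithDirectDefault:
    """The suggested simplification: put the real default in the signature."""

    def __init__(self, dtype: str = DEFAULT_DTYPE):
        self.dtype = dtype
```

Both classes behave identically when dtype is omitted. The only difference shows up if a caller passes dtype=None explicitly: the first variant still falls back to the default, while the second stores None.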
Thank you very much for your comments. I believe I have now addressed all of them, and I have updated the |
Great work!
…n_debug Graph Definition
This PR addresses the ongoing discussion in #462 and #521 (which has been a roadblock for a while) by changing Model such that it now consists of the modules
Model = [GraphDefinition, GNN, Task]
where GraphDefinition is a single, problem/model-dependent class that contains all the code responsible for data representations.
TLDR: Model, Dataset and GraphNeTI3Module now depend on GraphDefinition, which allows us to easily represent data as sequences, images, or whatever your heart desires. This change is breaking; older config files and pickled models are not compatible with these changes, but state_dicts are.
Conceptually, GraphDefinition contains all the code that alters the raw data from Dataset before it's passed to GNN. It's a single, swappable module that can be passed to Dataset and deployment modules. GraphDefinition consists of multiple submodules, and the data flow is
GraphDefinition = [Detector, NodeDefinition, EdgeDefinition]
and can be seen here. The definition exists at a point in the dataflow where events are unbatched, meaning that the construction of data representations can be done on CPU and in parallel, before it's batched and sent to the GPU. That means that the sequence creation included in #521 becomes much simpler and likely faster @Aske-Rosted, and should also be useful for the transformer exploration by @MoustHolmes.
The modules are defined as
NodeDefinition:
A generic class that defines what a node represents. Problem-specific versions can be implemented by overwriting the abstract method _construct_nodes. This is the playground we've been missing for a while; it gives us the freedom to fully define exactly how we want the data to be structured for our Models. Here, one can use nodes to represent DOMs (by using Coarsening or some other method), create images for CNNs, define sequences or other forms of data representations. Our standard of representing pulses as nodes is just
EdgeDefinition:
A generic class that defines how edges are drawn between nodes in the graph. This is essentially a refactor of our GraphBuilder. One can create problem-specific implementations by overwriting the abstract method
Detector:
Virtually unchanged from its known form. It is in charge of standardizing data and is now able to work on a subset of the feature space that it is defined on. I cleaned the class up a little bit. In the future, it will hold detector-specific geometry tables as mentioned in #462.
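The three-stage data flow described above can be sketched in plain Python. Everything here is a toy stand-in: the class names, the scaling constants in the feature map, the _construct_edges method name and the chain-shaped edge rule are assumptions made for illustration, and only _construct_nodes is a name taken from the description. The real classes operate on tensors and return graph objects.

```python
from abc import ABC, abstractmethod


class ToyDetector:
    """Standardizes raw features; works on any subset of its feature map."""

    def feature_map(self):
        # Scaling constants are invented for this sketch.
        return {
            "sensor_pos_x": lambda v: v / 500.0,
            "t": lambda v: (v - 1.0e4) / 3.0e4,
        }

    def __call__(self, pulses, feature_names):
        fmap = self.feature_map()
        return [
            [fmap[name](value) for name, value in zip(feature_names, pulse)]
            for pulse in pulses
        ]


class ToyNodeDefinition(ABC):
    def __call__(self, pulses):
        return self._construct_nodes(pulses)

    @abstractmethod
    def _construct_nodes(self, pulses):
        """Problem-specific: decide what a node represents."""


class PulsesAsNodes(ToyNodeDefinition):
    def _construct_nodes(self, pulses):
        return pulses  # one node per pulse, features unchanged


class ToyEdgeDefinition(ABC):
    def __call__(self, nodes):
        return self._construct_edges(nodes)

    @abstractmethod
    def _construct_edges(self, nodes):
        """Problem-specific: decide how nodes are connected (hypothetical name)."""


class ChainEdges(ToyEdgeDefinition):
    def _construct_edges(self, nodes):
        return [(i, i + 1) for i in range(len(nodes) - 1)]


class ToyGraphDefinition:
    """Applies Detector, then NodeDefinition, then EdgeDefinition to one event."""

    def __init__(self, detector, node_definition, edge_definition):
        self.detector = detector
        self.node_definition = node_definition
        self.edge_definition = edge_definition

    def __call__(self, pulses, feature_names):
        standardized = self.detector(pulses, feature_names)
        nodes = self.node_definition(standardized)
        edges = self.edge_definition(nodes)
        return {"x": nodes, "edge_index": edges}


graph_definition = ToyGraphDefinition(ToyDetector(), PulsesAsNodes(), ChainEdges())
graph = graph_definition(
    pulses=[[250.0, 1.0e4], [500.0, 4.0e4], [0.0, 1.0e4]],
    feature_names=["sensor_pos_x", "t"],
)
```

Because each stage is a separate object, swapping the data representation only means passing a different node or edge definition; nothing in the dataset or the model has to change.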
Our usual k-nn graph with nodes representing pulses can then be created like so:
Alternatively, you can also just import this graph definition directly, as it's included in the PR:
It is the problem-specific implementation of a graph definition that defines the number of input parameters to our GNNs, available through graph_definition.nb_outputs. When we instantiate a Model, the syntax is now:
Other things to note:
Dataset is now simpler, as graph-altering code has been moved to GraphDefinition
graphnet.data to avoid circular imports.
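The relationship between graph_definition.nb_outputs and the GNN input size mentioned above can be sketched as follows. This is a toy illustration under stated assumptions: the class names, the constructor arguments (graph_definition, gnn, tasks) and the consistency check are hypothetical and are not copied from the PR; only the nb_outputs attribute name comes from the description.

```python
class SketchGraphDefinition:
    """Toy graph definition that reports how many features each node carries."""

    def __init__(self, feature_names):
        self._feature_names = list(feature_names)

    @property
    def nb_outputs(self):
        return len(self._feature_names)


class SketchGNN:
    """Toy GNN that only records its expected number of input features."""

    def __init__(self, nb_inputs):
        self.nb_inputs = nb_inputs


class SketchModel:
    """Toy model composed of [GraphDefinition, GNN, Task]."""

    def __init__(self, graph_definition, gnn, tasks):
        # Hypothetical safeguard: the GNN must accept what the graph definition emits.
        if gnn.nb_inputs != graph_definition.nb_outputs:
            raise ValueError("GNN input size does not match graph definition output")
        self.graph_definition = graph_definition
        self.gnn = gnn
        self.tasks = list(tasks)


graph_definition = SketchGraphDefinition(
    ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]
)

# The GNN input size is read from the graph definition rather than hard-coded.
gnn = SketchGNN(nb_inputs=graph_definition.nb_outputs)
model = SketchModel(graph_definition=graph_definition, gnn=gnn, tasks=["energy"])
```

Reading the input size from the graph definition keeps the two in sync when the feature list changes.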