From da9c9c745976c95b58fdac50249fe709bb45db40 Mon Sep 17 00:00:00 2001 From: Thomas Klijnsma Date: Wed, 11 Oct 2023 15:52:50 -0500 Subject: [PATCH 1/5] c++ extension for object condensation loss function --- extensions/oc_cpu.cpp | 345 ++++++++++++++++++++++++++++++++++++++ scripts/performance_oc.py | 73 ++++++++ setup.py | 3 +- tests/test_oc.py | 180 ++++++++++++++++++++ torch_cmspepr/__init__.py | 5 +- torch_cmspepr/oc.py | 48 ++++++ 6 files changed, 651 insertions(+), 3 deletions(-) create mode 100644 extensions/oc_cpu.cpp create mode 100644 scripts/performance_oc.py create mode 100644 tests/test_oc.py create mode 100644 torch_cmspepr/oc.py diff --git a/extensions/oc_cpu.cpp b/extensions/oc_cpu.cpp new file mode 100644 index 0000000..1178326 --- /dev/null +++ b/extensions/oc_cpu.cpp @@ -0,0 +1,345 @@ +#include + +// #include //size_t, just for helper function +#include +// #include + +#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") +#define I2D(i,j,Nj) j + Nj*i + +/* +Returns the squared distance between two nodes in clustering space. 
+*/ +float calc_dist_sq( + const size_t i, // index of node i + const size_t j, // index of node j + const float_t *x, // node feature matrix + const size_t ndim // number of dimensions + ){ + float_t distsq = 0; + if (i == j) return 0; + // std::cout << "dist_sq i=" << i << " j=" << j << std::endl; + for (size_t idim = 0; idim < ndim; idim++) { + float_t dist = x[I2D(i,idim,ndim)] - x[I2D(j,idim,ndim)]; + // std::cout + // << " idim=" << idim + // << " x[" << i << "][" << idim << "]=" << x[I2D(i,idim,ndim)] + // << " x[" << j << "][" << idim << "]=" << x[I2D(j,idim,ndim)] + // << " d=" << dist + // << " d_sq=" << dist*dist + // << std::endl; + distsq += dist * dist; + } + // std::cout << " d_sq_sum=" << distsq << std::endl; + return distsq; + } + + +void oc_kernel( + // Global event info + const float_t* beta, // beta per node + const float_t* q, // charge per node + const float_t* x, // cluster space coordinates + const size_t n_dim_cluster_space, // Number of dimensions of the cluster space + const int32_t* cond_indices, // indices of the condensation points + const int32_t* cond_counts, // nr of nodes connected to the cond point + const size_t cond_indices_start, // row split start for cond points + const size_t cond_indices_end, // row split end for cond points + const int32_t* which_cond_point, // (n_nodes,) array pointing to the cond point index + const int32_t n_nodes, // Number of nodes in the event of this node + + // To be parallellized over + const size_t i_node, // index of the node in question + + // Outputs: + float_t * V_att, + float_t * V_rep, + float_t * V_srp + ){ + + int32_t i_cond = which_cond_point[i_node]; + + // std::cout + // << "i_node=" << i_node + // << " i_cond=" << i_cond + // << " q[i_node]=" << q[i_node] + // << " cond_start=" << cond_indices_start + // << " cond_end=" << cond_indices_end + // << " n_nodes=" << n_nodes + // << std::endl; + + // V_att and V_srp + if (i_cond == -1 || i_node == (size_t)i_cond){ + // Noise node, or a 
condensation point itself + // std::cout << " Noise hit or cond point, V_att/V_srp set to 0." << std::endl; + *V_att = 0.; + *V_srp = 0.; + } + else { + float d_sq = calc_dist_sq(i_node, i_cond, x, n_dim_cluster_space); + float d = sqrt(d_sq); + float_t d_huber = d+0.00001 <= 4.0 ? d_sq : 2.0 * 4.0 * (d - 4.0) ; + *V_att = d_huber * q[i_node] * q[i_cond] / (float)n_nodes; + // V_srp must still be normalized! This is done in the V_rep loop because the + // normalization numbers are easier to access there. + *V_srp = 1. / (20.*d_sq + 1.); + // std::cout << " d_huber for i_node " << i_node << ": " + // << d_huber + // << "; d_sq=" << d_sq + // << "; V_att=" << *V_att + // << "; V_srp=" << *V_srp + // << std::endl; + } + + // V_rep + *V_rep = 0.; + for (size_t i=cond_indices_start; i q_max[y_node-1]=" << q_max[y_node-1] + // << "\n Updating i_max[y_node-1] to " << i_node + // << std::endl; + q_max[y_node-1] = q[i_node]; + i_max[y_node-1] = i_node; + } + } + + // Loop over nodes in event, use i_max to determine per node to which + // cond point it belongs + for (int32_t i_node=row_splits[i_event]; i_node(); + + + float* V_att = (float *)malloc(n_nodes * sizeof(float)); + float* V_rep = (float *)malloc(n_nodes * sizeof(float)); + float* V_srp = (float *)malloc(n_nodes * sizeof(float)); + + // Loop over events + for (size_t i_event=0; i_event0) + L_beta_noise += L_beta_noise_this_event / (float)n_noise_this_event ; + } + losses[3] = L_beta_cond_logterm / (float)n_events; + losses[4] = L_beta_noise / (float)n_events; + + free(n_cond_per_event); + free(cond_indices_row_splits); + free(cond_indices); + free(cond_counts); + free(which_cond_point); + + float V_att_sum = 0.; + float V_rep_sum = 0.; + float V_srp_sum = 0.; + for (size_t i_node=0; i_node Date: Wed, 11 Oct 2023 15:59:24 -0500 Subject: [PATCH 2/5] formatted and linted --- scripts/performance_oc.py | 35 +++++---- tests/test_oc.py | 150 ++++++++++++++++++++------------------ torch_cmspepr/oc.py | 11 +-- 3 files 
changed, 106 insertions(+), 90 deletions(-) diff --git a/scripts/performance_oc.py b/scripts/performance_oc.py index e6d6a51..542eb1d 100644 --- a/scripts/performance_oc.py +++ b/scripts/performance_oc.py @@ -10,26 +10,29 @@ def make_random_event(n_nodes=10000, n_events=5): model_out = torch.rand((n_nodes, 32)) # Varying event sizes - event_fracs = torch.normal(torch.ones(n_events), .1) + event_fracs = torch.normal(torch.ones(n_events), 0.1) event_fracs /= event_fracs.sum() event_sizes = (event_fracs * n_nodes).type(torch.int) - event_sizes[-1] += n_nodes - event_sizes.sum() # Make sure it adds up to n_nodes + event_sizes[-1] += n_nodes - event_sizes.sum() # Make sure it adds up to n_nodes batch = torch.arange(n_events).repeat_interleave(event_sizes) - row_splits = torch.cat((torch.zeros(1, dtype=torch.int), torch.cumsum(event_sizes, 0))) + row_splits = torch.cat( + (torch.zeros(1, dtype=torch.int), torch.cumsum(event_sizes, 0)) + ) ys = [] for i_event in range(n_events): - n_clusters = torch.randint(3, 8, (1,)).item() # Somewhere between 3 and 8 particles + # Somewhere between 3 and 8 particles + n_clusters = torch.randint(3, 8, (1,)).item() cluster_fracs = torch.randint(50, 200, (n_clusters,)).type(torch.float) - cluster_fracs[0] += 200 # Boost the amount of noise relatively + cluster_fracs[0] += 200 # Boost the amount of noise relatively cluster_fracs /= cluster_fracs.sum() cluster_sizes = (cluster_fracs * event_sizes[i_event]).type(torch.int) # Make sure it adds up to n_nodes in this event cluster_sizes[-1] += event_sizes[i_event] - cluster_sizes.sum() ys.append(torch.arange(n_clusters).repeat_interleave(cluster_sizes)) y = torch.cat(ys) - + y = y.type(torch.int) row_splits = row_splits.type(torch.int) return model_out, y, batch, row_splits @@ -43,18 +46,20 @@ def test_oc_performance(): return objectcondensation.ObjectCondensation.beta_term_option = 'short_range_potential' - objectcondensation.ObjectCondensation.sB = 1. 
+ objectcondensation.ObjectCondensation.sB = 1.0 - t_py = 0. - t_cpp = 0. + t_py = 0.0 + t_cpp = 0.0 N = 1000 for i_test in tqdm.tqdm(range(N)): # Don't count prep work in performance model_out, y, batch, row_splits = make_random_event() data = Data(y=y.type(torch.long), batch=batch) - beta = torch.sigmoid(model_out[:,0]).contiguous() - q = objectcondensation.calc_q_betaclip(torch.sigmoid(model_out[:,0])).contiguous() - x = model_out[:,1:].contiguous() + beta = torch.sigmoid(model_out[:, 0]).contiguous() + q = objectcondensation.calc_q_betaclip( + torch.sigmoid(model_out[:, 0]) + ).contiguous() + x = model_out[:, 1:].contiguous() t0 = time.perf_counter() objectcondensation.oc_loss(model_out, data) @@ -62,12 +67,12 @@ def test_oc_performance(): torch_cmspepr.oc(beta, q, x, y, batch) t2 = time.perf_counter() - t_py += t1-t0 - t_cpp += t2-t1 + t_py += t1 - t0 + t_cpp += t2 - t1 print(f'Average python time: {t_py/N:.4f}') print(f'Average cpp time: {t_cpp/N:.4f}') print(f'Speed up is {t_py/t_cpp:.2f}x') -test_oc_performance() \ No newline at end of file +test_oc_performance() diff --git a/tests/test_oc.py b/tests/test_oc.py index d348acb..dd56334 100644 --- a/tests/test_oc.py +++ b/tests/test_oc.py @@ -1,5 +1,5 @@ import os.path as osp -from math import sqrt, log +from math import log import torch from torch_geometric.data import Data @@ -7,65 +7,67 @@ SO_DIR = osp.dirname(osp.dirname(osp.abspath(__file__))) -def calc_q_betaclip(beta, qmin=1.): +def calc_q_betaclip(beta, qmin=1.0): return (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin # Single event class single: + # fmt: off model_out = torch.FloatTensor([ - # Event 0 - # beta x0 x1 y - [0.01, 0.40, 0.40], # 0 - [0.02, 0.10, 0.90], # 0 - [0.12, 0.70, 0.70], # 1 <- d_sq to cond point = 0.02^2 + 0.02^2 = 0.0008; d=0.0283 - [0.01, 0.90, 0.10], # 0 - [0.13, 0.72, 0.72], # 1 <-- cond point for y=1 + # Event 0 + # beta x0 x1 y + [0.01, 0.40, 0.40], # 0 + [0.02, 0.10, 0.90], # 0 + [0.12, 0.70, 0.70], # 1 <- d_sq to 
cond point = 0.02^2 + 0.02^2 = 0.0008; d=0.0283 + [0.01, 0.90, 0.10], # 0 + [0.13, 0.72, 0.72], # 1 <-- cond point for y=1 ]) - x = model_out[:,1:].contiguous() + # fmt: on + x = model_out[:, 1:].contiguous() y = torch.LongTensor([0, 0, 1, 0, 1]) batch = torch.LongTensor([0, 0, 0, 0, 0]) - beta = torch.sigmoid(model_out[:,0]).contiguous() + beta = torch.sigmoid(model_out[:, 0]).contiguous() q = calc_q_betaclip(beta) @classmethod def d(cls, i, j): - return ((cls.x[i]-cls.x[j])**2).sum() + return ((cls.x[i] - cls.x[j]) ** 2).sum() # Manual OC: @classmethod def losses(cls): - x = single.x - y = single.y beta = single.beta q = single.q d = single.d - V_att = d(2,4) * cls.q[2] * q[4] / 5. # Since d is small, d == d_huber + V_att = d(2, 4) * q[2] * q[4] / 5.0 # Since d is small, d == d_huber V_rep = ( - torch.exp(-4.*d(0,4)) * q[0] * q[4] - + torch.exp(-4.*d(1,4)) * q[1] * q[4] - + torch.exp(-4.*d(3,4)) * q[3] * q[4] - ) / 5. - V_srp = -1./(20.*d(2,4) + 1.) * beta[4] / 2. - L_beta_cond_logterm = -0.2 * log(beta[4]+1e-9) - L_beta_noise = (beta[0]+beta[1]+beta[3]) / 3. 
- - losses_man = torch.FloatTensor([V_att, V_rep, V_srp, L_beta_cond_logterm, L_beta_noise]) + torch.exp(-4.0 * d(0, 4)) * q[0] * q[4] + + torch.exp(-4.0 * d(1, 4)) * q[1] * q[4] + + torch.exp(-4.0 * d(3, 4)) * q[3] * q[4] + ) / 5.0 + V_srp = -1.0 / (20.0 * d(2, 4) + 1.0) * beta[4] / 2.0 + L_beta_cond_logterm = -0.2 * log(beta[4] + 1e-9) + L_beta_noise = (beta[0] + beta[1] + beta[3]) / 3.0 + + losses_man = torch.FloatTensor( + [V_att, V_rep, V_srp, L_beta_cond_logterm, L_beta_noise] + ) return losses_man def test_oc_cpu_single(): torch.ops.load_library(osp.join(SO_DIR, 'oc_cpu.so')) - + losses_cpp = torch.ops.oc_cpu.oc_cpu( single.beta, single.q, single.x, single.y.type(torch.int), torch.IntTensor([0, 5]), - ) - + ) + losses_man = single.losses() print(f'{losses_man=}') print(f'{losses_cpp=}') @@ -74,13 +76,8 @@ def test_oc_cpu_single(): def test_oc_python_single(): import torch_cmspepr - losses = torch_cmspepr.oc( - single.beta, - single.q, - single.x, - single.y, - single.batch - ) + + losses = torch_cmspepr.oc(single.beta, single.q, single.x, single.y, single.batch) losses_man = single.losses() print(f'{losses_man=}') print(f'{losses=}') @@ -88,24 +85,23 @@ def test_oc_python_single(): class multiple: - model_out = torch.FloatTensor( - [ - # Event 0 - # beta x0 x1 idx y - [0.01, 0.40, 0.40], # 0 0 - [0.02, 0.10, 0.90], # 1 0 - [0.12, 0.70, 0.70], # 2 1 <- d_sq to cond point = 0.02^2 + 0.02^2 = 0.0008; d=0.0283 - [0.01, 0.90, 0.10], # 3 0 - [0.13, 0.72, 0.72], # 4 1 <-- cond point for y=1 - # Event 1 - [0.11, 0.40, 0.40], # 5 2 - [0.02, 0.10, 0.90], # 6 0 - [0.12, 0.70, 0.70], # 7 1 <-- cond point for y=1 - [0.01, 0.90, 0.10], # 8 0 - [0.13, 0.72, 0.72], # 9 2 <-- cond point for y=2 - [0.11, 0.72, 0.72], # 10 1 - ] - ) + # fmt: off + model_out = torch.FloatTensor([ + # Event 0 + # beta x0 x1 idx y + [0.01, 0.40, 0.40], # 0 0 + [0.02, 0.10, 0.90], # 1 0 + [0.12, 0.70, 0.70], # 2 1 <- d_sq to cond point = 0.02^2 + 0.02^2 = 0.0008; d=0.0283 + [0.01, 0.90, 0.10], # 3 
0 + [0.13, 0.72, 0.72], # 4 1 <-- cond point for y=1 + # Event 1 + [0.11, 0.40, 0.40], # 5 2 + [0.02, 0.10, 0.90], # 6 0 + [0.12, 0.70, 0.70], # 7 1 <-- cond point for y=1 + [0.01, 0.90, 0.10], # 8 0 + [0.13, 0.72, 0.72], # 9 2 <-- cond point for y=2 + [0.11, 0.72, 0.72], # 10 1 + ]) x = model_out[:,1:].contiguous() y = torch.LongTensor([ 0, 0, 1, 0, 1, # Event 0 @@ -115,8 +111,9 @@ class multiple: 0, 0, 0, 0, 0, # Event 0 1, 1, 1, 1, 1, 1 # Event 1 ]) + # fmt: on row_splits = torch.IntTensor([0, 5, 11]) - beta = torch.sigmoid(model_out[:,0]).contiguous() + beta = torch.sigmoid(model_out[:, 0]).contiguous() q = calc_q_betaclip(beta).contiguous() @@ -130,20 +127,26 @@ def test_oc_cpu_batch(): objectcondensation.ObjectCondensation.beta_term_option = 'short_range_potential' objectcondensation.ObjectCondensation.sB = 1.0 - - loss_py = objectcondensation.oc_loss(multiple.model_out, Data(y=multiple.y, batch=multiple.batch)) - losses_py = torch.FloatTensor([ - loss_py["V_att"], loss_py["V_rep"], - loss_py["L_beta_sig"], loss_py["L_beta_cond_logterm"], - loss_py["L_beta_noise"] - ]) + + loss_py = objectcondensation.oc_loss( + multiple.model_out, Data(y=multiple.y, batch=multiple.batch) + ) + losses_py = torch.FloatTensor( + [ + loss_py["V_att"], + loss_py["V_rep"], + loss_py["L_beta_sig"], + loss_py["L_beta_cond_logterm"], + loss_py["L_beta_noise"], + ] + ) losses_cpp = torch.ops.oc_cpu.oc_cpu( multiple.beta, multiple.q, multiple.x, multiple.y.type(torch.int), multiple.row_splits, - ) + ) print(losses_py) print(losses_cpp) # Lots of rounding errors in python vs c++, can't compare too rigorously @@ -152,28 +155,35 @@ def test_oc_cpu_batch(): def test_oc_python_batch(): import torch_cmspepr + try: import cmspepr_hgcal_core.objectcondensation as objectcondensation except ImportError: print('Install cmspepr_hgcal_core to run this test') return - + objectcondensation.ObjectCondensation.beta_term_option = 'short_range_potential' objectcondensation.ObjectCondensation.sB = 1.0 - 
loss_py = objectcondensation.oc_loss(multiple.model_out, Data(y=multiple.y, batch=multiple.batch)) - losses_py = torch.FloatTensor([ - loss_py["V_att"], loss_py["V_rep"], - loss_py["L_beta_sig"], loss_py["L_beta_cond_logterm"], - loss_py["L_beta_noise"] - ]) + loss_py = objectcondensation.oc_loss( + multiple.model_out, Data(y=multiple.y, batch=multiple.batch) + ) + losses_py = torch.FloatTensor( + [ + loss_py["V_att"], + loss_py["V_rep"], + loss_py["L_beta_sig"], + loss_py["L_beta_cond_logterm"], + loss_py["L_beta_noise"], + ] + ) losses = torch_cmspepr.oc( multiple.beta, multiple.q, multiple.x, multiple.y.type(torch.int), - multiple.batch - ) + multiple.batch, + ) print(losses_py) print(losses) # Lots of rounding errors in python vs c++, can't compare too rigorously diff --git a/torch_cmspepr/oc.py b/torch_cmspepr/oc.py index ba791d6..67fdbef 100644 --- a/torch_cmspepr/oc.py +++ b/torch_cmspepr/oc.py @@ -1,14 +1,15 @@ import torch + # @torch.jit.script def oc( beta: torch.FloatTensor, q: torch.FloatTensor, x: torch.FloatTensor, - y: torch.LongTensor, # Use long for consistency - batch: torch.LongTensor, # Use long for consistency - sB: float = 1. - ): + y: torch.LongTensor, # Use long for consistency + batch: torch.LongTensor, # Use long for consistency + sB: float = 1.0, +): """ Calculate the object condensation loss function. 
@@ -40,7 +41,7 @@ def oc( assert device == torch.device('cpu') # Translate batch vector into row splits - counts = torch.zeros(batch.max()+1, dtype=torch.int, device=device) + counts = torch.zeros(batch.max() + 1, dtype=torch.int, device=device) counts.scatter_add_(0, batch, torch.ones_like(batch, dtype=torch.int)) counts = torch.cat((torch.zeros(1, dtype=torch.int, device=device), counts)) row_splits = torch.cumsum(counts, 0).type(torch.int) From 0c87d1468dbd435d99db4781426c7d21e011434d Mon Sep 17 00:00:00 2001 From: Thomas Klijnsma Date: Wed, 11 Oct 2023 16:03:50 -0500 Subject: [PATCH 3/5] updated readme --- README.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1d7a70a..7fc1f88 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # pytorch_cmspepr -pytorch bindings for optimized knn and aggregation kernels +pytorch bindings for optimized knn and aggregation kernels. + +Now also has a C++ extension for the [Object Condensation](https://arxiv.org/abs/2002.03605) loss function. ## Example @@ -121,7 +123,7 @@ pytest tests ## Performance -The following profiling code can be used: +The following profiling code can be used (see the script [performance.py](scripts/performance.py)): ```python import time @@ -171,4 +173,7 @@ CPU (torch_cmspepr) took 0.22623349189758302 sec/evt CPU (torch_cluster) took 0.2259768319129944 sec/evt CUDA (torch_cmspepr) took 0.026673252582550048 sec/evt CUDA (torch_cluster) took 0.22262062072753908 sec/evt -``` \ No newline at end of file +``` + +Similarly, there is a profiling script available for object condensation, see [performance_oc.py](scripts/performance_oc.py). +Here a 3x speed up is achieved w.r.t. the pure-Python implementation of object condensation, but more importantly, memory consumption is drastically reduced. 
From 9354cadca893a72dec9370204fa18d60d2cb594c Mon Sep 17 00:00:00 2001 From: Thomas Klijnsma Date: Wed, 11 Oct 2023 16:09:12 -0500 Subject: [PATCH 4/5] version bump --- setup.py | 2 +- torch_cmspepr/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 2f78c42..a98a717 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ def repr_ext(ext): tests_require = ['pytest', 'pytest-cov', 'scipy'] setup( name='torch_cmspepr', - version='1.0.0', + version='1.1.0', author='Lindsey Gray , Jan Kieseler , Thomas Klijnsma ', author_email='Lindsey.Gray@cern.ch', url='', diff --git a/torch_cmspepr/__init__.py b/torch_cmspepr/__init__.py index 3497964..8ddbad9 100644 --- a/torch_cmspepr/__init__.py +++ b/torch_cmspepr/__init__.py @@ -3,7 +3,7 @@ import logging import torch -__version__ = '1.0.0' +__version__ = '1.1.0' def setup_logger(name: str = "cmspepr") -> logging.Logger: From 94088cee31e64a979e30fa1df40a899cc92ecf68 Mon Sep 17 00:00:00 2001 From: Thomas Klijnsma Date: Wed, 11 Oct 2023 16:11:23 -0500 Subject: [PATCH 5/5] reenabled jit for oc python interface --- torch_cmspepr/oc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_cmspepr/oc.py b/torch_cmspepr/oc.py index 67fdbef..607b249 100644 --- a/torch_cmspepr/oc.py +++ b/torch_cmspepr/oc.py @@ -1,7 +1,7 @@ import torch -# @torch.jit.script +@torch.jit.script def oc( beta: torch.FloatTensor, q: torch.FloatTensor,