diff --git a/README.md b/README.md index 1d7a70a..7fc1f88 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # pytorch_cmspepr -pytorch bindings for optimized knn and aggregation kernels +pytorch bindings for optimized knn and aggregation kernels. + +Now also has a C++ extension for the [Object Condensation](https://arxiv.org/abs/2002.03605) loss function. ## Example @@ -121,7 +123,7 @@ pytest tests ## Performance -The following profiling code can be used: +The following profiling code can be used (see the script [performance.py](scripts/performance.py)): ```python import time @@ -171,4 +173,7 @@ CPU (torch_cmspepr) took 0.22623349189758302 sec/evt CPU (torch_cluster) took 0.2259768319129944 sec/evt CUDA (torch_cmspepr) took 0.026673252582550048 sec/evt CUDA (torch_cluster) took 0.22262062072753908 sec/evt -``` \ No newline at end of file +``` + +Similarly, there is a profiling script available for object condensation, see [performance_oc.py](scripts/performance_oc.py). +Here a 3x speedup is achieved w.r.t. the pure-Python implementation of object condensation, but more importantly, memory consumption is drastically reduced. diff --git a/extensions/oc_cpu.cpp b/extensions/oc_cpu.cpp new file mode 100644 index 0000000..1178326 --- /dev/null +++ b/extensions/oc_cpu.cpp @@ -0,0 +1,345 @@ +#include + +// #include //size_t, just for helper function +#include +// #include + +#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") +#define I2D(i,j,Nj) j + Nj*i + +/* +Returns the squared distance between two nodes in clustering space. 
+*/ +float calc_dist_sq( + const size_t i, // index of node i + const size_t j, // index of node j + const float_t *x, // flat node feature matrix; element (node, dim) is read at x[dim + ndim*node] via I2D + const size_t ndim // number of dimensions per node, i.e. the row length used to index x + ){ + float_t distsq = 0; + if (i == j) return 0; + // std::cout << "dist_sq i=" << i << " j=" << j << std::endl; + for (size_t idim = 0; idim < ndim; idim++) { + float_t dist = x[I2D(i,idim,ndim)] - x[I2D(j,idim,ndim)]; + // std::cout + // << " idim=" << idim + // << " x[" << i << "][" << idim << "]=" << x[I2D(i,idim,ndim)] + // << " x[" << j << "][" << idim << "]=" << x[I2D(j,idim,ndim)] + // << " d=" << dist + // << " d_sq=" << dist*dist + // << std::endl; + distsq += dist * dist; + } + // std::cout << " d_sq_sum=" << distsq << std::endl; + return distsq; + } + + +void oc_kernel( + // Global event info + const float_t* beta, // beta per node + const float_t* q, // charge per node + const float_t* x, // cluster space coordinates + const size_t n_dim_cluster_space, // Number of dimensions of the cluster space + const int32_t* cond_indices, // indices of the condensation points + const int32_t* cond_counts, // nr of nodes connected to the cond point + const size_t cond_indices_start, // row split start for cond points + const size_t cond_indices_end, // row split end for cond points + const int32_t* which_cond_point, // (n_nodes,) array pointing to the cond point index + const int32_t n_nodes, // Number of nodes in the event of this node + + // To be parallelized over + const size_t i_node, // index of the node in question + + // Outputs: + float_t * V_att, + float_t * V_rep, + float_t * V_srp + ){ + + int32_t i_cond = which_cond_point[i_node]; + + // std::cout + // << "i_node=" << i_node + // << " i_cond=" << i_cond + // << " q[i_node]=" << q[i_node] + // << " cond_start=" << cond_indices_start + // << " cond_end=" << cond_indices_end + // << " n_nodes=" << n_nodes + // << std::endl; + + // V_att and V_srp + if (i_cond == -1 || i_node == (size_t)i_cond){ + // Noise node, or a 
condensation point itself + // std::cout << " Noise hit or cond point, V_att/V_srp set to 0." << std::endl; + *V_att = 0.; + *V_srp = 0.; + } + else { + float d_sq = calc_dist_sq(i_node, i_cond, x, n_dim_cluster_space); + float d = sqrt(d_sq); + float_t d_huber = d+0.00001 <= 4.0 ? d_sq : 2.0 * 4.0 * (d - 4.0) ; + *V_att = d_huber * q[i_node] * q[i_cond] / (float)n_nodes; + // V_srp must still be normalized! This is done in the V_rep loop because the + // normalization numbers are easier to access there. + *V_srp = 1. / (20.*d_sq + 1.); + // std::cout << " d_huber for i_node " << i_node << ": " + // << d_huber + // << "; d_sq=" << d_sq + // << "; V_att=" << *V_att + // << "; V_srp=" << *V_srp + // << std::endl; + } + + // V_rep + *V_rep = 0.; + for (size_t i=cond_indices_start; i q_max[y_node-1]=" << q_max[y_node-1] + // << "\n Updating i_max[y_node-1] to " << i_node + // << std::endl; + q_max[y_node-1] = q[i_node]; + i_max[y_node-1] = i_node; + } + } + + // Loop over nodes in event, use i_max to determine per node to which + // cond point it belongs + for (int32_t i_node=row_splits[i_event]; i_node(); + + + float* V_att = (float *)malloc(n_nodes * sizeof(float)); + float* V_rep = (float *)malloc(n_nodes * sizeof(float)); + float* V_srp = (float *)malloc(n_nodes * sizeof(float)); + + // Loop over events + for (size_t i_event=0; i_event0) + L_beta_noise += L_beta_noise_this_event / (float)n_noise_this_event ; + } + losses[3] = L_beta_cond_logterm / (float)n_events; + losses[4] = L_beta_noise / (float)n_events; + + free(n_cond_per_event); + free(cond_indices_row_splits); + free(cond_indices); + free(cond_counts); + free(which_cond_point); + + float V_att_sum = 0.; + float V_rep_sum = 0.; + float V_srp_sum = 0.; + for (size_t i_node=0; i_node logging.Logger: @@ -49,8 +49,9 @@ def load_ops(so_file): THISDIR = osp.dirname(osp.abspath(__file__)) load_ops(osp.join(THISDIR, "../select_knn_cpu.so")) load_ops(osp.join(THISDIR, "../select_knn_cuda.so")) - 
+load_ops(osp.join(THISDIR, "../oc_cpu.so")) from torch_cmspepr.select_knn import select_knn, knn_graph +from torch_cmspepr.oc import oc -__all__ = ['select_knn', 'knn_graph', 'logger'] +__all__ = ['select_knn', 'knn_graph', 'oc', 'logger'] diff --git a/torch_cmspepr/oc.py b/torch_cmspepr/oc.py new file mode 100644 index 0000000..607b249 --- /dev/null +++ b/torch_cmspepr/oc.py @@ -0,0 +1,49 @@ +import torch + + +@torch.jit.script +def oc( + beta: torch.FloatTensor, + q: torch.FloatTensor, + x: torch.FloatTensor, + y: torch.LongTensor, # Use long for consistency + batch: torch.LongTensor, # Use long for consistency + sB: float = 1.0, # NOTE(review): never read in the visible body and not forwarded to the op - confirm whether it is still needed +): + """ + Calculate the object condensation loss function. + + Args: + beta (torch.FloatTensor): Beta as described in https://arxiv.org/abs/2002.03605; + simply a sigmoid of the raw model output + q (torch.FloatTensor): Charge q per node; usually a function of beta. + x (torch.FloatTensor): Latent clustering space coordinates for every node. + y (torch.LongTensor): Clustering truth. WARNING: The torch.op expects y to be + nicely *incremental*. There should not be any holes in it. + batch (torch.LongTensor): Batch vector to designate event boundaries. WARNING: + It is expected that batch is *sorted*. + + Returns: + torch.FloatTensor: A len-5 tensor with the 5 loss components of the OC loss + function: V_att, V_rep, V_srp, L_beta_cond_logterm, and L_beta_noise. The + full OC loss is simply the sum of this tensor. + """ + N = beta.size(0) + assert beta.dim() == 1 + assert q.dim() == 1 + assert beta.size() == q.size() + assert x.size(0) == N + assert y.size(0) == N + assert batch.size(0) == N + device = beta.device + + # TEMPORARY: No GPU version available yet. 
+ assert device == torch.device('cpu') + + # Translate batch vector into row splits: count nodes per batch index, prepend a 0, then take the cumulative sum + counts = torch.zeros(batch.max() + 1, dtype=torch.int, device=device) + counts.scatter_add_(0, batch, torch.ones_like(batch, dtype=torch.int)) + counts = torch.cat((torch.zeros(1, dtype=torch.int, device=device), counts)) + row_splits = torch.cumsum(counts, 0).type(torch.int) + + return torch.ops.oc_cpu.oc_cpu(beta, q, x, y.type(torch.int), row_splits)