Skip to content

Commit

Permalink
make pointgroup_ops compatible with pytorch 1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
llijiang committed Jun 9, 2020
1 parent dea7ab1 commit a41c2cd
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 24 deletions.
6 changes: 3 additions & 3 deletions lib/pointgroup_ops/functions/pointgroup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def forward(ctx, coords, batch_idxs, batch_offsets, radius, meanActive):

n = coords.size(0)

assert coords.is_contiguous()
assert batch_idxs.is_contiguous()
assert batch_offsets.is_contiguous()
assert coords.is_contiguous() and coords.is_cuda
assert batch_idxs.is_contiguous() and batch_idxs.is_cuda
assert batch_offsets.is_contiguous() and batch_offsets.is_cuda

while True:
idx = torch.cuda.IntTensor(n * meanActive).zero_()
Expand Down
12 changes: 1 addition & 11 deletions lib/pointgroup_ops/src/bfs_cluster/bfs_cluster.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,20 @@ All Rights Reserved 2020.

#include "bfs_cluster.h"

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

extern THCState *state;

/* ================================== ballquery_batch_p ================================== */
// input xyz: (n, 3) float
// input batch_idxs: (n) int
// input batch_offsets: (B+1) int, batch_offsets[-1]
// output idx: (n * meanActive) dim 0 for number of points in the ball, idx in n
// output start_len: (n, 2), int
int ballquery_batch_p(at::Tensor xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor batch_offsets_tensor, at::Tensor idx_tensor, at::Tensor start_len_tensor, int n, int meanActive, float radius){
CHECK_INPUT(xyz_tensor);
CHECK_INPUT(batch_idxs_tensor);
CHECK_INPUT(batch_offsets_tensor);

const float *xyz = xyz_tensor.data<float>();
const int *batch_idxs = batch_idxs_tensor.data<int>();
const int *batch_offsets = batch_offsets_tensor.data<int>();
int *idx = idx_tensor.data<int>();
int *start_len = start_len_tensor.data<int>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int cumsum = ballquery_batch_p_cuda(n, meanActive, radius, xyz, batch_idxs, batch_offsets, idx, start_len, stream);
return cumsum;
}
Expand Down
4 changes: 0 additions & 4 deletions lib/pointgroup_ops/src/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

#define TOTAL_THREADS 1024

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

#define THREADS_PER_BLOCK 512
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

Expand Down
6 changes: 0 additions & 6 deletions lib/pointgroup_ops/src/get_iou/get_iou.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@ All Rights Reserved 2020.

#include "get_iou.h"

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

extern THCState *state;

void get_iou(at::Tensor proposals_idx_tensor, at::Tensor proposals_offset_tensor, at::Tensor instance_labels_tensor, at::Tensor instance_pointnum_tensor, at::Tensor proposals_iou_tensor, int nInstance, int nProposal){
int *proposals_idx = proposals_idx_tensor.data<int>();
int *proposals_offset = proposals_offset_tensor.data<int>();
Expand Down

0 comments on commit a41c2cd

Please sign in to comment.