OP ROIPooling CPU fix and DType support (apache#3011)
* OP ROIPooling CPU fix

* OP ROIPooling DType support
precedenceguo authored and winstywang committed Aug 15, 2016
1 parent 15d1727 commit 3a05d99
Showing 4 changed files with 68 additions and 36 deletions.
42 changes: 31 additions & 11 deletions src/operator/roi_pooling-inl.h
@@ -41,7 +41,7 @@ struct ROIPoolingParam : public dmlc::Parameter<ROIPoolingParam> {
}
};

-template<typename xpu>
+template<typename xpu, typename DType>
class ROIPoolingOp : public Operator {
public:
explicit ROIPoolingOp(ROIPoolingParam p) {
@@ -61,10 +61,10 @@ class ROIPoolingOp : public Operator {
CHECK_EQ(out_data[roipool::kMaxIdx].shape_[0], in_data[roipool::kBox].shape_[0]);
Stream<xpu> *s = ctx.get_stream<xpu>();

-    Tensor<xpu, 4> data = in_data[roipool::kData].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 2> bbox = in_data[roipool::kBox].get<xpu, 2, real_t>(s);
-    Tensor<xpu, 4> out = out_data[roipool::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4, DType> data = in_data[roipool::kData].get<xpu, 4, DType>(s);
+    Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
+    Tensor<xpu, 4, DType> out = out_data[roipool::kOut].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
CHECK_EQ(data.CheckContiguous(), true);
CHECK_EQ(bbox.CheckContiguous(), true);
CHECK_EQ(out.CheckContiguous(), true);
@@ -90,10 +90,10 @@ class ROIPoolingOp : public Operator {
CHECK_EQ(req[roipool::kOut], kWriteTo);
Stream<xpu> *s = ctx.get_stream<xpu>();

-    Tensor<xpu, 4> grad_out = out_grad[roipool::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 2> bbox = in_data[roipool::kBox].get<xpu, 2, real_t>(s);
-    Tensor<xpu, 4> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> grad_in = in_grad[roipool::kData].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4, DType> grad_out = out_grad[roipool::kOut].get<xpu, 4, DType>(s);
+    Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
+    Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> grad_in = in_grad[roipool::kData].get<xpu, 4, DType>(s);
CHECK_EQ(grad_out.CheckContiguous(), true);
CHECK_EQ(bbox.CheckContiguous(), true);
CHECK_EQ(max_idx.CheckContiguous(), true);
@@ -108,7 +108,7 @@ class ROIPoolingOp : public Operator {

// Declare Factory function, used for dispatch specialization
template<typename xpu>
-Operator* CreateOp(ROIPoolingParam param);
+Operator* CreateOp(ROIPoolingParam param, int dtype);

#if DMLC_USE_CXX11
class ROIPoolingProp : public OperatorProperty {
@@ -162,6 +162,20 @@ class ROIPoolingProp : public OperatorProperty {
return true;
}

+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    CHECK_EQ(in_type->size(), 2);
+    int dtype = (*in_type)[0];
+    CHECK_EQ(dtype, (*in_type)[1]);
+    CHECK_NE(dtype, -1) << "Input must have specified type";
+
+    out_type->clear();
+    out_type->push_back(dtype);
+    out_type->push_back(dtype);
+    return true;
+  }
+
OperatorProperty* Copy() const override {
ROIPoolingProp* roi_pooling_sym = new ROIPoolingProp();
roi_pooling_sym->param_ = this->param_;
@@ -180,7 +194,13 @@
return {out_grad[roipool::kOut], in_data[roipool::kBox], out_data[roipool::kMaxIdx]};
}

-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented.";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;

private:
ROIPoolingParam param_;
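The InferType hook added above is what carries a user-chosen dtype through the operator: both inputs (data and rois) must share one known dtype (-1 means "not yet inferred"), and both outputs, the pooled features and the argmax index map, inherit it. A minimal driver sketch of that contract (hypothetical code, not part of the patch; it assumes the mxnet/mshadow headers are available):

```cpp
#include <vector>
#include "./roi_pooling-inl.h"  // hypothetical include path

void infer_type_sketch() {
  mxnet::op::ROIPoolingProp prop;
  // Both inputs carry float32 (mshadow's runtime dtype enum).
  std::vector<int> in_type = {mshadow::kFloat32, mshadow::kFloat32};
  std::vector<int> out_type, aux_type;
  CHECK(prop.InferType(&in_type, &out_type, &aux_type));
  // out_type now holds {mshadow::kFloat32, mshadow::kFloat32}:
  // one entry for the pooled output, one for the max-index output.
}
```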
34 changes: 20 additions & 14 deletions src/operator/roi_pooling.cc
@@ -132,22 +132,23 @@ inline void ROIPoolBackward(const Tensor<cpu, 4, Dtype> &in_grad,
for (int h = 0; h < height_; ++h) {
for (int w = 0; w < width_; ++w) {
int offset_bottom_diff = (b * channels_ + c) * height_ * width_;
-          offset_bottom_diff += h * height_ + w;
+          offset_bottom_diff += h * width_ + w;

Dtype gradient = 0;
// Accumulate gradient over all ROIs that pooled this element
for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
-            int roi_batch_ind = bottom_rois[0];
+            const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5;
+            int roi_batch_ind = offset_bottom_rois[0];
assert(roi_batch_ind >= 0);
assert(roi_batch_ind < batch_size_);
if (b != roi_batch_ind) {
continue;
}

-            int roi_start_w = round(bottom_rois[1] * spatial_scale_);
-            int roi_start_h = round(bottom_rois[2] * spatial_scale_);
-            int roi_end_w = round(bottom_rois[3] * spatial_scale_);
-            int roi_end_h = round(bottom_rois[4] * spatial_scale_);
+            int roi_start_w = round(offset_bottom_rois[1] * spatial_scale_);
+            int roi_start_h = round(offset_bottom_rois[2] * spatial_scale_);
+            int roi_end_w = round(offset_bottom_rois[3] * spatial_scale_);
+            int roi_end_h = round(offset_bottom_rois[4] * spatial_scale_);

bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
h >= roi_start_h && h <= roi_end_h);
@@ -191,9 +192,6 @@ inline void ROIPoolBackward(const Tensor<cpu, 4, Dtype> &in_grad,
}
}
}

-            // Increment ROI data pointer
-            bottom_rois += bbox.size(1);
}
bottom_diff[offset_bottom_diff] = gradient;
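Two distinct bugs are fixed in this backward pass. First, the flat offset into the row-major NCHW gradient buffer was computed as `h * height_ + w`; the correct within-plane term is `h * width_ + w`, so the old code hit the wrong row whenever the feature map was not square. Second, every iteration of the ROI loop read `bottom_rois[0..4]`, i.e. always the first ROI, because the pointer was only advanced at the wrong loop level; the fix addresses ROI `n` directly at `bottom_rois + n * 5`, each ROI row being `[batch_index, x1, y1, x2, y2]`. A standalone sketch of the offset arithmetic (hypothetical helper, not code from the patch):

```cpp
#include <cassert>

// Flat index of element (b, c, h, w) in a row-major NCHW buffer.
int nchw_offset(int b, int c, int h, int w, int C, int H, int W) {
  return ((b * C + c) * H + h) * W + w;
}

int main() {
  // H = 12, W = 8: the non-square shape used by the unit test below.
  assert(nchw_offset(0, 0, 2, 3, 1, 12, 8) == 2 * 8 + 3);  // h*W + w
  // The old h*H + w lands on a different element whenever H != W.
  assert(2 * 12 + 3 != 2 * 8 + 3);
  return 0;
}
```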
}
Expand All @@ -209,13 +207,21 @@ namespace mxnet {
namespace op {

template<>
-Operator* CreateOp<cpu>(ROIPoolingParam param) {
-  return new ROIPoolingOp<cpu>(param);
+Operator *CreateOp<cpu>(ROIPoolingParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new ROIPoolingOp<cpu, DType>(param);
+  });
+  return op;
}

// DO_BIND_DISPATCH comes from static_operator_common.h
-Operator* ROIPoolingProp::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator *ROIPoolingProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                           std::vector<int> *in_type) const {
+  std::vector<TShape> out_shape, aux_shape;
+  std::vector<int> out_type, aux_type;
+  CHECK(InferType(in_type, &out_type, &aux_type));
+  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}

DMLC_REGISTER_PARAMETER(ROIPoolingParam);
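With dtype support, operator creation now dispatches on the runtime dtype enum. MSHADOW_REAL_TYPE_SWITCH instantiates its body once with `DType` bound to the matching C++ type; conceptually, the CreateOp<cpu> above reduces to a hand-written switch like this simplified sketch (the real macro also covers float16 and aborts on unknown enum values):

```cpp
// Hypothetical, simplified equivalent of the MSHADOW_REAL_TYPE_SWITCH call.
Operator* CreateOpSketch(ROIPoolingParam param, int dtype) {
  Operator* op = NULL;
  switch (dtype) {
    case mshadow::kFloat32: {
      typedef float DType;  // body instantiated with DType = float
      op = new ROIPoolingOp<cpu, DType>(param);
      break;
    }
    case mshadow::kFloat64: {
      typedef double DType;  // body instantiated with DType = double
      op = new ROIPoolingOp<cpu, DType>(param);
      break;
    }
    default:
      LOG(FATAL) << "unknown dtype enum " << dtype;
  }
  return op;
}
```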
18 changes: 7 additions & 11 deletions src/operator/roi_pooling.cu
@@ -10,13 +10,6 @@
#include <algorithm>
#include <vector>

-#define ROIPOOLING_CUDA_CHECK(condition) \
-  /* Code block avoids redefinition of cudaError_t error */ \
-  do { \
-    cudaError_t error = condition; \
-    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
-  } while (0)
-
namespace mshadow {
namespace cuda {

@@ -117,7 +110,6 @@ inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
ROIPoolForwardKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
count, bottom_data, spatial_scale, channels, height, width,
pooled_height, pooled_width, bottom_rois, top_data, argmax_data);
-  ROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}

template<typename Dtype>
@@ -221,7 +213,6 @@ inline void ROIPoolBackward(const Tensor<gpu, 4, Dtype> &in_grad,
ROIPoolBackwardKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
count, top_diff, argmax_data, num_rois, spatial_scale, channels, height, width,
pooled_height, pooled_width, bottom_diff, bottom_rois);
-  ROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}

} // namespace cuda
@@ -251,8 +242,13 @@ namespace mxnet {
namespace op {

template<>
-Operator* CreateOp<gpu>(ROIPoolingParam param) {
-  return new ROIPoolingOp<gpu>(param);
+Operator* CreateOp<gpu>(ROIPoolingParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new ROIPoolingOp<gpu, DType>(param);
+  });
+  return op;
}

} // namespace op
} // namespace mxnet
10 changes: 10 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -1437,6 +1437,15 @@ def test_support_vector_machine_l2_svm():
grad_np = grad_np.astype(np.float32)
assert_allclose(grad_np, grad.asnumpy())

+def test_roipooling():
+    data = mx.symbol.Variable(name='data')
+    rois = mx.symbol.Variable(name='rois')
+    test = mx.symbol.ROIPooling(data=data, rois=rois, pooled_size=(6, 6), spatial_scale=1)
+
+    x1 = np.random.rand(4, 3, 12, 8)
+    x2 = np.array([[0, 1, 1, 6, 6], [2, 6, 2, 7, 11], [1, 3, 1, 5, 10], [0, 3, 3, 3, 3]])
+
+    check_numeric_gradient(test, [x1, x2], numeric_eps=1e-4, check_eps=1e-1)

if __name__ == '__main__':
test_expand_dims()
@@ -1478,3 +1487,4 @@ def test_support_vector_machine_l2_svm():
test_correlation()
test_support_vector_machine_l1_svm()
test_support_vector_machine_l2_svm()
+    test_roipooling()
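The four ROIs in `x2` reference batch images 0, 2, 1 and 0, so the backward pass must resolve each ROI's batch index independently (the per-ROI pointer handling fixed in roi_pooling.cc), and the 12x8 input is deliberately non-square, which exercises the corrected `h * width_ + w` offset. `check_numeric_gradient` then compares the analytic backward pass against finite differences.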
