From d08d87f5a0e701dbc46d396395de0968ff98f71c Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Mon, 15 Aug 2016 01:44:18 -0700 Subject: [PATCH] broadcast_mask for rnn (#3016) --- mshadow | 2 +- src/operator/broadcast_mask_op-inl.h | 95 ++++++++++++++++++++++++++++ src/operator/broadcast_mask_op.cc | 8 +++ src/operator/broadcast_mask_op.cu | 8 +++ 4 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 src/operator/broadcast_mask_op-inl.h create mode 100644 src/operator/broadcast_mask_op.cc create mode 100644 src/operator/broadcast_mask_op.cu diff --git a/mshadow b/mshadow index 787ee960bbbe..db4c01523e8d 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 787ee960bbbed517cabbf41fdd8becf8de372203 +Subproject commit db4c01523e8d95277eae3bb52eb12260b46d6e03 diff --git a/src/operator/broadcast_mask_op-inl.h b/src/operator/broadcast_mask_op-inl.h new file mode 100644 index 000000000000..8f012922e1da --- /dev/null +++ b/src/operator/broadcast_mask_op-inl.h @@ -0,0 +1,95 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file broadcast_mask_op-inl.h + * \brief + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_BROADCAST_MASK_OP_INL_H_ +#define MXNET_OPERATOR_BROADCAST_MASK_OP_INL_H_ + +#include +#include "./operator_common.h" + + +#if defined(__CUDACC__) +#define XPU gpu +#else +#define XPU cpu +#endif + +namespace mxnet { +namespace op { + +inline TShape ElementwiseMaskShape_(const TShape& lhs, + const TShape& rhs, + const EnvArguments& env) { + CHECK(lhs.ndim() > 1 && rhs.ndim() == 1) << + "source tensor should be 2D or more and mask should be 1D"; + CHECK_EQ(lhs[0], rhs[0]) << "The first dimention of inputs should be same"; + return TShape(lhs); +} + +template +void ElementwiseMaskForward_(const TBlob& lhs, + const TBlob& rhs, + const EnvArguments& env, + TBlob *ret, + OpReqType req, + RunContext ctx) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(ret->type_flag_, lhs.type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(ret->type_flag_, rhs.type_flag_) + << "Binary function only support input/output with the same type"; + CHECK(lhs.shape_.ndim() > 1 && rhs.shape_.ndim() == 1 && + lhs.shape_[0] == rhs.shape_[0]) << + "the first ndim of lhs and rhs must be equal, lhs should be 2D or more and rhs shoube be 1D" + " shape of lhs=" << lhs.shape_ << " shape of rhs=" << rhs.shape_; + MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { + mshadow::Tensor out = ret->FlatTo2D(s); + ASSIGN_DISPATCH(out, req, + // TODO(bing): swap because requirement of inplace, change mshadow later + mask(rhs.get(s), lhs.FlatTo2D(s))); + }); + return; +} + +template +void ElementwiseMaskBackward_(const OutputGrad& out_grad, + const Input0& lhs, + const Input1& rhs, + const EnvArguments& env, + TBlob* lhs_grad, + TBlob* rhs_grad, + OpReqType req_lhs_grad, + OpReqType req_rhs_grad, + RunContext ctx) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(lhs_grad->type_flag_, DType, { + mshadow::Tensor mout_grad = out_grad.data.FlatTo2D(s); + mshadow::Tensor mlhs_grad = lhs_grad->FlatTo2D(s); + mshadow::Tensor mrhs_data = rhs.data.get(s); + ASSIGN_DISPATCH(mlhs_grad, req_lhs_grad, + // TODO(bing): swap because requirement of inplace, change mshadow later + mask(mrhs_data, mout_grad)); + }); + return; +} + + +MXNET_REGISTER_SIMPLE_OP(element_mask, XPU) +.set_shape_function(ElementwiseMaskShape_) +.set_function(XPU::kDevMask, ElementwiseMaskForward_, kInplaceLhsOut, kRegisterSymbolic) +.set_gradient(XPU::kDevMask, ElementwiseMaskBackward_, kInplaceOutLhs) +.describe("rhs elmentwise mask lhs with broadcast"); + +} // namespace op +} // namespace mxnet + + +#endif // MXNET_OPERATOR_BROADCAST_MASK_OP_INL_H_ + diff --git a/src/operator/broadcast_mask_op.cc b/src/operator/broadcast_mask_op.cc new file mode 100644 index 000000000000..a32f57e81be7 --- /dev/null +++ b/src/operator/broadcast_mask_op.cc @@ -0,0 +1,8 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file broadcast_mask_op.cc + * \brief + * \author Bing Xu +*/ +#include "./broadcast_mask_op-inl.h" + diff --git a/src/operator/broadcast_mask_op.cu b/src/operator/broadcast_mask_op.cu new file mode 100644 index 000000000000..822458687452 --- /dev/null +++ b/src/operator/broadcast_mask_op.cu @@ -0,0 +1,8 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file broadcast_mask_op.cu + * \brief + * \author Bing Xu +*/ +#include "./broadcast_mask_op-inl.h" +