#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/OpaqueTensorImpl.h>
#include <c10/core/Allocator.h>

#if AT_MKLDNN_ENABLED()
#include <ideep.hpp>

namespace at { namespace native {

/**
 * `IntrusivePtrTargetWrapper` wraps a custom storage handle of a tensor
 * (as template param) and inherits `c10::intrusive_ptr_target` so that it
 * can be used with `c10::intrusive_ptr`.
 *
 * It currently only supports wrapping the custom handle by:
 * - Constructing with an existing custom handle by copy/move constructor.
 *
 * See `OpaqueTensorImpl::opaque_handle_`.
 *
 * NOTE: if this is generally useful we may want to move this to its own header.
 */
template <typename T>
struct CAFFE2_API IntrusivePtrTargetWrapper : c10::intrusive_ptr_target {
 private:
  T target_;

 public:
  IntrusivePtrTargetWrapper() = delete;
  IntrusivePtrTargetWrapper(const T& target): target_(target) {}
  IntrusivePtrTargetWrapper(T&& target): target_(std::move(target)) {}

  T& get_target() {
    return target_;
  }
};
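
// A minimal usage sketch of the wrapper (`VecWrapper` and `handle` below are
// illustrative names, not part of this file):
//
//   using VecWrapper = IntrusivePtrTargetWrapper<std::vector<int>>;
//   auto handle = c10::make_intrusive<VecWrapper>(std::vector<int>{1, 2, 3});
//   handle->get_target().push_back(4);

// An MKLDNN tensor is an `OpaqueTensorImpl` whose opaque handle is an
// intrusive_ptr to a wrapped `ideep::tensor`: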
using IDeepTensorWrapper = IntrusivePtrTargetWrapper<ideep::tensor>;
using IDeepTensorWrapperPtr = c10::intrusive_ptr<IDeepTensorWrapper>;
using MKLDNNTensorImpl = OpaqueTensorImpl<IDeepTensorWrapperPtr>;
using MKLDNNTensor = Tensor;

Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options) {
  // NOTE: ideep::tensor reports its dims as int32_t, but Tensor sizes need
  // int64_t, hence the conversion below.
  // TODO: support int64_t dims in ideep::tensor to avoid the extra conversion.
  auto dims = it.get_dims();
  IDeepTensorWrapperPtr handle = c10::make_intrusive<IDeepTensorWrapper>(std::move(it));
  return detail::make_tensor<MKLDNNTensorImpl>(
    DispatchKeySet(DispatchKey::MkldnnCPU),
    options.dtype(), options.device(), handle,
    std::vector<int64_t>(dims.begin(), dims.end()));
}
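
// A minimal usage sketch, assuming an f32 `ideep::tensor` named `y` produced
// elsewhere (the variable names are illustrative):
//
//   Tensor out = new_with_itensor_mkldnn(
//       std::move(y),
//       TensorOptions().dtype(ScalarType::Float).device(DeviceType::CPU));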

ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
  TORCH_CHECK(mkldnn_tensor.is_mkldnn(),
              "itensor_from_mkldnn expects MKL-DNN tensor input");
  TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
  MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
  return mklimpl->unsafe_opaque_handle()->get_target();
}
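
// A minimal round-trip sketch (`dense` and `m` are illustrative names):
//
//   Tensor dense = at::randn({2, 3});           // CPU float, strided
//   Tensor m = dense.to_mkldnn();               // convert to MKLDNN layout
//   ideep::tensor& it = itensor_from_mkldnn(m);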

// NOTE: the returned ideep::tensor wraps the dense tensor's storage without
// copying it, so the caller must keep `tensor` alive for as long as the
// returned view is in use.
ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
  TORCH_CHECK(
      tensor.device().type() == DeviceType::CPU,
      "itensor_view_from_dense expects CPU tensor input");
  TORCH_CHECK(
      tensor.layout() == Layout::Strided,
      "itensor_view_from_dense expects dense tensor input");
  TORCH_CHECK(tensor.scalar_type() == ScalarType::Float,
              "itensor_view_from_dense expects float tensor input");
  TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
  return {{{tensor.sizes().cbegin(), tensor.sizes().cend()},
           ideep::tensor::data_type::f32},
          tensor.template data_ptr<float>()};
}
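
// A minimal usage sketch (`x` and `view` are illustrative names):
//
//   Tensor x = at::randn({4, 4});               // CPU float, strided
//   ideep::tensor view = itensor_view_from_dense(x);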

}} // namespace at::native

#endif // AT_MKLDNN_ENABLED()