forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorFactories.cpp
29 lines (22 loc) · 1.01 KB
/
TensorFactories.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <ATen/native/mkldnn/MKLDNNCommon.h>

#include <limits>
namespace at { namespace native {
#if AT_MKLDNN_ENABLED()
// Creates an MKL-DNN (ideep) backed tensor with the requested sizes.
//
// sizes: extent of each dimension; every entry must be representable in
//   ideep's dimension type, otherwise the call is rejected instead of
//   silently truncating the value.
// options: forwarded unchanged to new_with_itensor_mkldnn; must not carry a
//   memory format.
// optional_memory_format: must be empty; mkldnn tensors do not accept one.
//
// Fails through TORCH_CHECK when a memory format is supplied by either route
// or when a size does not fit in ideep::tensor::dims::value_type.
//
// NOTE(review): the ideep tensor is always created with data_type::f32,
// independent of options.dtype() -- confirm callers only request float.
Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !options.has_memory_format(),
      "'memory_format' argument is incompatible with mkldnn tensor");
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "'memory_format' argument is incompatible with mkldnn tensor");

  // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
  // TODO: support int64_t dims in ideep::tensor to avoid extra conversion
  //
  // The range constructor below converts element-wise, so guard against
  // values that would wrap when narrowed to ideep's dimension type.
  using ideep_dim_t = ideep::tensor::dims::value_type;
  constexpr auto kMinDim =
      static_cast<int64_t>(std::numeric_limits<ideep_dim_t>::min());
  constexpr auto kMaxDim =
      static_cast<int64_t>(std::numeric_limits<ideep_dim_t>::max());
  for (const int64_t size : sizes) {
    TORCH_CHECK(
        size >= kMinDim && size <= kMaxDim,
        "empty_mkldnn: size ", size,
        " is out of range for an mkldnn tensor dimension");
  }

  ideep::tensor::dims dst_dims(sizes.begin(), sizes.end());
  ideep::tensor it{dst_dims, ideep::tensor::data_type::f32};
  return new_with_itensor_mkldnn(std::move(it), options);
}
#else
// Stub selected by the #else branch of AT_MKLDNN_ENABLED(), i.e. when the
// build has no MKL-DNN support. It keeps the same signature as the enabled
// variant, ignores all of its arguments, and only reports an error.
Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
// No tensor is constructed here; the call unconditionally goes to AT_ERROR.
AT_ERROR("empty_mkldnn: MKL-DNN build is disabled");
}
#endif // AT_MKLDNN_ENABLED()
}}