forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
meta_tensor.cpp
35 lines (29 loc) · 1.15 KB
/
meta_tensor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <gtest/gtest.h>
#include <ATen/MetaFunctions.h>
#include <torch/torch.h>
#include <vector>
TEST(MetaTensorTest, MetaDeviceApi) {
  // Two real CPU inputs whose shapes broadcast together: {4} against {3, 4}.
  const auto lhs_cpu = at::ones({4}, at::kFloat);
  const auto rhs_cpu = at::ones({3, 4}, at::kFloat);

  // Copy both operands onto the meta device first; at::add() on meta inputs
  // is expected to yield a meta-device result.
  const auto lhs_meta = lhs_cpu.to(c10::kMeta);
  const auto rhs_meta = rhs_cpu.to(c10::kMeta);
  const auto result = at::add(lhs_meta, rhs_meta);

  // .to() returns new tensors; the original inputs must still live on CPU.
  ASSERT_EQ(lhs_cpu.device(), c10::kCPU);
  ASSERT_EQ(rhs_cpu.device(), c10::kCPU);
  ASSERT_EQ(result.device(), c10::kMeta);

  // The broadcast output shape is still computed even though no data exists.
  const c10::IntArrayRef actual_sizes = result.sizes();
  const std::vector<int64_t> expected_sizes{3, 4};
  ASSERT_EQ(actual_sizes, expected_sizes);
}
TEST(MetaTensorTest, MetaNamespaceApi) {
  // Two real CPU inputs whose shapes broadcast together: {4} against {3, 4}.
  const auto lhs_cpu = at::ones({4}, at::kFloat);
  const auto rhs_cpu = at::ones({3, 4}, at::kFloat);

  // Functions in the at::meta:: namespace accept non-meta inputs (CPU
  // tensors here) directly and return a meta tensor, so no explicit
  // .to(c10::kMeta) conversion is needed beforehand.
  const auto result = at::meta::add(lhs_cpu, rhs_cpu);

  // The inputs are untouched; only the result lives on the meta device.
  ASSERT_EQ(lhs_cpu.device(), c10::kCPU);
  ASSERT_EQ(rhs_cpu.device(), c10::kCPU);
  ASSERT_EQ(result.device(), c10::kMeta);

  // The result carries the broadcast output shape.
  const c10::IntArrayRef actual_sizes = result.sizes();
  const std::vector<int64_t> expected_sizes{3, 4};
  ASSERT_EQ(actual_sizes, expected_sizes);
}