From 4ec35e5fbe84a43c14bc9cfbcb0a9896d87cbde9 Mon Sep 17 00:00:00 2001 From: turneram <71655887+turneram@users.noreply.github.com> Date: Thu, 28 Apr 2022 19:45:16 -0500 Subject: [PATCH] Add GatherND operator (#1089) Add ref and gpu implementations for ONNX op GatherND Resolves #1032 --- src/CMakeLists.txt | 1 + src/include/migraphx/op/gathernd.hpp | 131 ++++++++++++ src/include/migraphx/operators.hpp | 1 + src/onnx/parse_generic_op.cpp | 1 + src/targets/gpu/jit/gathernd.cpp | 75 +++++++ .../include/migraphx/kernels/algorithm.hpp | 10 + .../include/migraphx/kernels/gathernd.hpp | 81 +++++++ test/onnx/gathernd_batch_dims_test.onnx | 19 ++ test/onnx/gathernd_test.onnx | 16 ++ test/onnx/gen_onnx.py | 29 +++ test/onnx/onnx_test.cpp | 25 +++ test/py/onnx_backend_test.py | 3 - test/ref_ops_test.cpp | 197 ++++++++++++++++++ test/verify/test_gathernd_batch_dims_1.cpp | 22 ++ test/verify/test_gathernd_batch_dims_2.cpp | 21 ++ test/verify/test_gathernd_default.cpp | 20 ++ .../verify/test_gathernd_negative_indices.cpp | 22 ++ 17 files changed, 671 insertions(+), 3 deletions(-) create mode 100644 src/include/migraphx/op/gathernd.hpp create mode 100644 src/targets/gpu/jit/gathernd.cpp create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp create mode 100644 test/onnx/gathernd_batch_dims_test.onnx create mode 100644 test/onnx/gathernd_test.onnx create mode 100644 test/verify/test_gathernd_batch_dims_1.cpp create mode 100644 test/verify/test_gathernd_batch_dims_2.cpp create mode 100644 test/verify/test_gathernd_default.cpp create mode 100644 test/verify/test_gathernd_negative_indices.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1463e9698f6..1b32e3d48ed 100755 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -109,6 +109,7 @@ register_migraphx_ops( flatten floor gather + gathernd get_tuple_elem greater gru diff --git a/src/include/migraphx/op/gathernd.hpp b/src/include/migraphx/op/gathernd.hpp new file mode 100644 index 00000000000..2b954787836 --- /dev/null +++ b/src/include/migraphx/op/gathernd.hpp @@ -0,0 +1,131 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP +#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct gathernd +{ + int batch_dims = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.batch_dims, "batch_dims")); + } + + std::string name() const { return "gathernd"; } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2); + auto r = inputs.front().lens().size(); + auto q = inputs.back().lens().size(); + auto k = inputs.back().lens().back(); + if(k > r - batch_dims) + { + MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) + + " cannot be used to access data of rank " + + std::to_string(r - batch_dims)); + } + auto indices_lens_iter = inputs.back().lens().begin(); + auto output_lens_size = q + r - k - batch_dims - 1; + std::vector output_lens(output_lens_size); + std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin()); + if(k < r - batch_dims) + { + auto data_lens = inputs.front().lens(); + std::copy( + data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1); + } + shape output_shape{inputs.front().type(), output_lens}; + return output_shape; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + visit_all(result, args[0])([&](auto output, auto data) { + args[1].visit([&](auto indices) { + auto indices_shape = indices.get_shape(); + auto indices_shape_lens = indices_shape.lens(); + auto data_shape = data.get_shape(); + auto data_shape_lens = data_shape.lens(); + auto k = indices_shape.lens().back(); + const auto num_slice_dims = k; + std::size_t num_slices = std::accumulate(indices_shape_lens.begin(), + indices_shape_lens.end() - 1, + 1, + std::multiplies()); + std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + std::size_t num_batches = std::accumulate(data_shape_lens.begin(), + data_shape_lens.begin() + batch_dims, + 1, + std::multiplies()); + std::size_t data_batch_stride = + std::accumulate(data_shape_lens.begin() + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + auto num_slices_per_batch = num_slices / num_batches; + + std::vector sizes_from_slice_dims(num_slice_dims); + { + auto running_product = slice_size; + for(std::size_t i = 0; i < num_slice_dims; ++i) + { + sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product; + running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i]; + } + } + + std::vector input_slice_offsets(num_slices); + par_for(num_slices, [&](const auto i) { + std::size_t batch_idx = i / num_slices_per_batch; + + auto slice_indices = indices.begin() + (i * num_slice_dims); + std::size_t relative_slice_offset = 0; + for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) + { + int64_t index = *(slice_indices + dim_idx); + const std::size_t input_dim_idx = batch_dims + dim_idx; + const auto input_dim = data_shape_lens[input_dim_idx]; + if(index < -static_cast(input_dim) or + index >= static_cast(input_dim)) + MIGRAPHX_THROW("GatherND: index " + std::to_string(index) + + " is out of bounds for dim of len " + + std::to_string(input_dim)); + if(index < 0) + index += input_dim; + + relative_slice_offset += index * sizes_from_slice_dims[dim_idx]; + } + + input_slice_offsets[i] = + (batch_idx * data_batch_stride) + relative_slice_offset; + }); + + par_for(num_slices * slice_size, [&](const auto i) { + auto slice_offset = input_slice_offsets[i / slice_size]; + output[i] = data[slice_offset + i % slice_size]; + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index b5b615a0083..04f7a63fca7 100755 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include diff --git a/src/onnx/parse_generic_op.cpp b/src/onnx/parse_generic_op.cpp index 1932ee74e10..0bfb4296f7c 100644 --- a/src/onnx/parse_generic_op.cpp +++ b/src/onnx/parse_generic_op.cpp @@ -28,6 +28,7 @@ struct parse_generic_op : op_parser {"Flatten", "flatten"}, {"Floor", "floor"}, {"Gather", "gather"}, + {"GatherND", "gathernd"}, {"Identity", "identity"}, {"IsNaN", "isnan"}, {"LeakyRelu", "leaky_relu"}, diff --git a/src/targets/gpu/jit/gathernd.cpp b/src/targets/gpu/jit/gathernd.cpp new file mode 100644 index 00000000000..0ca14186dec --- /dev/null +++ b/src/targets/gpu/jit/gathernd.cpp @@ -0,0 +1,75 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +// NOLINTNEXTLINE +static const char* const gathernd_kernel = R"__migraphx__( +#include +#include +#include +#include +#include + +namespace migraphx { + +extern "C" { + +__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output) +{ + make_tensors()(in_data, in_indices, output)([](auto&&... xs) { + auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS})); + gathernd(xs..., settings); + }); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +struct gathernd_compiler : compiler +{ + std::vector names() const { return {"gathernd"}; } + + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + hip_compile_options options; + auto out_s = inputs.back(); + options.set_launch_params(v, compute_global_for(ctx, out_s.elements())); + options.inputs = inputs; + options.output = out_s; + options.kernel_name = "gathernd_kernel"; + options.virtual_inputs = inputs; + + // batch_dims + assert(v.contains("batch_dims")); + auto batch_dims = v.at("batch_dims").to(); + options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims); + + return compile_hip_code_object(gathernd_kernel, options); + } + + compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const + { + return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp index eebea90e9ad..2b702c05612 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp @@ -21,6 +21,16 @@ struct greater } }; +template +constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op) +{ + for(; first != last; ++first) + { + init = op(std::move(init), *first); + } + return init; +} + template constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp new file mode 100644 index 00000000000..22d49ac3811 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp @@ -0,0 +1,81 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP +#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP + +#include +#include + +namespace migraphx { + +template +struct gathernd_settings +{ + T batch_dims{}; +}; + +template +constexpr gathernd_settings make_gathernd_settings(Ts... xs) +{ + return {xs...}; +} + +template +__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s) +{ + auto ind = make_index(); + auto batch_dims = s.batch_dims; + auto output_shape = output_t.get_shape(); + auto indices_shape = indices_t.get_shape(); + auto data_shape = data_t.get_shape(); + + auto indices_shape_lens = indices_shape.lens; + auto data_shape_lens = data_shape.lens; + auto num_slice_dims = indices_shape_lens.back(); + std::size_t num_slices = accumulate(indices_shape_lens.begin(), + indices_shape_lens.end() - 1, + 1, + std::multiplies()); + std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + const std::size_t num_batches = accumulate(data_shape_lens.begin(), + data_shape_lens.begin() + batch_dims, + 1, + std::multiplies()); + const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + const auto num_slices_per_batch = num_slices / num_batches; + + ind.global_stride(output_shape.elements(), [&](auto i) { + const auto* indices_ptr = indices_t.data(); + const std::size_t j = i / slice_size; + const std::size_t batch_idx = j / num_slices_per_batch; + + auto* slice_indices = indices_ptr + (j * num_slice_dims); + std::size_t relative_slice_offset = 0; + for(std::size_t idx = 0; idx < num_slice_dims; ++idx) + { + int64_t index = slice_indices[idx]; + const std::size_t input_dim_idx = batch_dims + idx; + const auto input_dim = data_shape_lens[input_dim_idx]; + assert(index >= -static_cast(input_dim) and + index < static_cast(input_dim)); + if(index < 0) + index += input_dim; + std::size_t size_from_slice_dims = + accumulate(data_shape_lens.begin() + batch_dims + idx + 1, + data_shape_lens.begin() + batch_dims + num_slice_dims, + slice_size, + std::multiplies()); + relative_slice_offset += index * size_from_slice_dims; + } + + auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset; + output_t[i] = data_t[slice_offset + i % slice_size]; + }); +} + +} // namespace migraphx +#endif diff --git a/test/onnx/gathernd_batch_dims_test.onnx b/test/onnx/gathernd_batch_dims_test.onnx new file mode 100644 index 00000000000..1b488f13d06 --- /dev/null +++ b/test/onnx/gathernd_batch_dims_test.onnx @@ -0,0 +1,19 @@ +gathernd_batch_dims_test:— +/ +data +indicesy"GatherND* + +batch_dims gathernd_batch_dims_testZ +data + + + +Z +indices +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/gathernd_test.onnx b/test/onnx/gathernd_test.onnx new file mode 100644 index 00000000000..8d6afc78bee --- /dev/null +++ b/test/onnx/gathernd_test.onnx @@ -0,0 +1,16 @@ + gathernd_test:q + +data +indicesy"GatherND gathernd_testZ +data +  + +Z +indices +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 4521856a57c..ee6f1eba829 100755 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -1666,6 +1666,35 @@ def gather_elements_axis1_test(): return ([node], [x, i], [y]) +@onnx_test +def gathernd_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2]) + i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2]) + + node = onnx.helper.make_node('GatherND', + inputs=['data', 'indices'], + outputs=['y']) + + return ([node], [x, i], [y]) + + +@onnx_test +def gathernd_batch_dims_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) + i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2]) + + node = onnx.helper.make_node( + 'GatherND', + inputs=['data', 'indices'], + outputs=['y'], + batch_dims=1, + ) + + return ([node], [x, i], [y]) + + @onnx_test def gemm_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 59c3acf457a..fed40d5a14c 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -1582,6 +1582,31 @@ TEST_CASE(gather_elements_axis1_test) EXPECT(p == prog); } +TEST_CASE(gathernd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 2}}); + mm->add_instruction(migraphx::make_op("gathernd"), l0, l1); + auto prog = optimize_onnx("gathernd_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(gathernd_batch_dims_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1}}); + int batch_dims = 1; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), l0, l1); + auto prog = optimize_onnx("gathernd_batch_dims_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(gemm_test) { migraphx::program p; diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 1e4cd76e85b..527cbf408ee 100755 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -268,9 +268,6 @@ def create_backend_test(testname=None, target_device=None): backend_test.exclude(r'test_expand_shape_model2_cpu') backend_test.exclude(r'test_expand_shape_model3_cpu') backend_test.exclude(r'test_expand_shape_model4_cpu') - backend_test.exclude(r'test_gathernd_example_float32_cpu') - backend_test.exclude(r'test_gathernd_example_int32_batch_dim1_cpu') - backend_test.exclude(r'test_gathernd_example_int32_cpu') backend_test.exclude(r'test_identity_sequence_cpu') backend_test.exclude(r'test_maxpool_2d_uint8_cpu') backend_test.exclude(r'test_negative_log_likelihood_loss_*') diff --git a/test/ref_ops_test.cpp b/test/ref_ops_test.cpp index 495e805295d..547043ccb23 100644 --- a/test/ref_ops_test.cpp +++ b/test/ref_ops_test.cpp @@ -1653,6 +1653,203 @@ TEST_CASE(gather_test) } } +TEST_CASE(gathernd_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{0, 0, 1, 1}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{0, 3}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{1, 0}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{2, 3, 0, 1}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}}; + + std::vector data_vec(2 * 3 * 1); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{1, 0, 0, 1}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 2, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2, 2}}; + + std::vector data_vec(2 * 3 * 2 * 3); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{0, 0, 0, 1, 0, 0, 0, 1}; + const int batch_dims = 1; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction( + migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{0, 1, 2, 3, 4, 5, 18, 19, 20, 21, 22, 23}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}}; + + std::vector data_vec(2 * 3 * 1 * 3); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0}; + const int batch_dims = 2; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction( + migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{0, 4, 8, 11, 13, 15}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + // k > r - batch_dims + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 3}}; + + std::vector data_vec(2 * 3 * 1 * 3); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec(2 * 3 * 3, 0); + const int batch_dims = 2; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + EXPECT(test::throws([&] { + mm->add_instruction( + migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices); + })); + } +} + +TEST_CASE(gathernd_negative_index_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{-1, 0}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{2, 3, 0, 1}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{-3, 0}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + + EXPECT(test::throws([&] { p.eval({}); })); + } +} + TEST_CASE(globalavgpool_test) { migraphx::program p; diff --git a/test/verify/test_gathernd_batch_dims_1.cpp b/test/verify/test_gathernd_batch_dims_1.cpp new file mode 100644 index 00000000000..c902f80c82f --- /dev/null +++ b/test/verify/test_gathernd_batch_dims_1.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_batch_dims_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 2, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}}; + std::vector indices{1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + int batch_dims = 1; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gathernd_batch_dims_2.cpp b/test/verify/test_gathernd_batch_dims_2.cpp new file mode 100644 index 00000000000..94b914293d7 --- /dev/null +++ b/test/verify/test_gathernd_batch_dims_2.cpp @@ -0,0 +1,21 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_batch_dims_2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}}; + std::vector indices{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + int batch_dims = 2; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gathernd_default.cpp b/test/verify/test_gathernd_default.cpp new file mode 100644 index 00000000000..020210e7c86 --- /dev/null +++ b/test/verify/test_gathernd_default.cpp @@ -0,0 +1,20 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_default : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2}}; + std::vector indices{0, 0, 1, 1}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + mm->add_instruction(migraphx::make_op("gathernd"), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gathernd_negative_indices.cpp b/test/verify/test_gathernd_negative_indices.cpp new file mode 100644 index 00000000000..e1e1f2b2796 --- /dev/null +++ b/test/verify/test_gathernd_negative_indices.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_negative_indices : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}}; + std::vector indices{-1, 0}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + int batch_dims = 1; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1); + return p; + } +};