-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ref and gpu implementations for ONNX op GatherND Resolves #1032
- Loading branch information
Showing
17 changed files
with
671 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,6 +109,7 @@ register_migraphx_ops( | |
flatten | ||
floor | ||
gather | ||
gathernd | ||
get_tuple_elem | ||
greater | ||
gru | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP | ||
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP | ||
|
||
#include <migraphx/check_shapes.hpp> | ||
#include <migraphx/shape_for_each.hpp> | ||
#include <migraphx/par_for.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
namespace op { | ||
|
||
struct gathernd | ||
{ | ||
int batch_dims = 0; | ||
|
||
template <class Self, class F> | ||
static auto reflect(Self& self, F f) | ||
{ | ||
return pack(f(self.batch_dims, "batch_dims")); | ||
} | ||
|
||
std::string name() const { return "gathernd"; } | ||
|
||
shape compute_shape(std::vector<shape> inputs) const | ||
{ | ||
check_shapes{inputs, *this}.has(2); | ||
auto r = inputs.front().lens().size(); | ||
auto q = inputs.back().lens().size(); | ||
auto k = inputs.back().lens().back(); | ||
if(k > r - batch_dims) | ||
{ | ||
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) + | ||
" cannot be used to access data of rank " + | ||
std::to_string(r - batch_dims)); | ||
} | ||
auto indices_lens_iter = inputs.back().lens().begin(); | ||
auto output_lens_size = q + r - k - batch_dims - 1; | ||
std::vector<std::size_t> output_lens(output_lens_size); | ||
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin()); | ||
if(k < r - batch_dims) | ||
{ | ||
auto data_lens = inputs.front().lens(); | ||
std::copy( | ||
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1); | ||
} | ||
shape output_shape{inputs.front().type(), output_lens}; | ||
return output_shape; | ||
} | ||
|
||
argument compute(const shape& output_shape, std::vector<argument> args) const | ||
{ | ||
argument result{output_shape}; | ||
visit_all(result, args[0])([&](auto output, auto data) { | ||
args[1].visit([&](auto indices) { | ||
auto indices_shape = indices.get_shape(); | ||
auto indices_shape_lens = indices_shape.lens(); | ||
auto data_shape = data.get_shape(); | ||
auto data_shape_lens = data_shape.lens(); | ||
auto k = indices_shape.lens().back(); | ||
const auto num_slice_dims = k; | ||
std::size_t num_slices = std::accumulate(indices_shape_lens.begin(), | ||
indices_shape_lens.end() - 1, | ||
1, | ||
std::multiplies<std::size_t>()); | ||
std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims, | ||
data_shape_lens.end(), | ||
1, | ||
std::multiplies<std::size_t>()); | ||
std::size_t num_batches = std::accumulate(data_shape_lens.begin(), | ||
data_shape_lens.begin() + batch_dims, | ||
1, | ||
std::multiplies<std::size_t>()); | ||
std::size_t data_batch_stride = | ||
std::accumulate(data_shape_lens.begin() + batch_dims, | ||
data_shape_lens.end(), | ||
1, | ||
std::multiplies<std::size_t>()); | ||
auto num_slices_per_batch = num_slices / num_batches; | ||
|
||
std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims); | ||
{ | ||
auto running_product = slice_size; | ||
for(std::size_t i = 0; i < num_slice_dims; ++i) | ||
{ | ||
sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product; | ||
running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i]; | ||
} | ||
} | ||
|
||
std::vector<std::size_t> input_slice_offsets(num_slices); | ||
par_for(num_slices, [&](const auto i) { | ||
std::size_t batch_idx = i / num_slices_per_batch; | ||
|
||
auto slice_indices = indices.begin() + (i * num_slice_dims); | ||
std::size_t relative_slice_offset = 0; | ||
for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) | ||
{ | ||
int64_t index = *(slice_indices + dim_idx); | ||
const std::size_t input_dim_idx = batch_dims + dim_idx; | ||
const auto input_dim = data_shape_lens[input_dim_idx]; | ||
if(index < -static_cast<int64_t>(input_dim) or | ||
index >= static_cast<int64_t>(input_dim)) | ||
MIGRAPHX_THROW("GatherND: index " + std::to_string(index) + | ||
" is out of bounds for dim of len " + | ||
std::to_string(input_dim)); | ||
if(index < 0) | ||
index += input_dim; | ||
|
||
relative_slice_offset += index * sizes_from_slice_dims[dim_idx]; | ||
} | ||
|
||
input_slice_offsets[i] = | ||
(batch_idx * data_batch_stride) + relative_slice_offset; | ||
}); | ||
|
||
par_for(num_slices * slice_size, [&](const auto i) { | ||
auto slice_offset = input_slice_offsets[i / slice_size]; | ||
output[i] = data[slice_offset + i % slice_size]; | ||
}); | ||
}); | ||
}); | ||
|
||
return result; | ||
} | ||
}; | ||
|
||
} // namespace op | ||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#include <migraphx/gpu/compiler.hpp> | ||
#include <migraphx/make_op.hpp> | ||
#include <migraphx/gpu/context.hpp> | ||
|
||
#include <migraphx/gpu/compile_hip_code_object.hpp> | ||
#include <migraphx/gpu/compile_hip.hpp> | ||
#include <migraphx/ranges.hpp> | ||
#include <migraphx/reduce_dims.hpp> | ||
#include <migraphx/stringutils.hpp> | ||
#include <migraphx/dead_code_elimination.hpp> | ||
#include <migraphx/eliminate_common_subexpression.hpp> | ||
#include <migraphx/module.hpp> | ||
#include <migraphx/pass_manager.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
namespace gpu { | ||
|
||
// NOLINTNEXTLINE | ||
static const char* const gathernd_kernel = R"__migraphx__( | ||
#include <migraphx/kernels/gathernd.hpp> | ||
#include <migraphx/kernels/basic_ops.hpp> | ||
#include <migraphx/kernels/integral_constant.hpp> | ||
#include <migraphx/kernels/generic_constant.hpp> | ||
#include <args.hpp> | ||
namespace migraphx { | ||
extern "C" { | ||
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output) | ||
{ | ||
make_tensors()(in_data, in_indices, output)([](auto&&... xs) { | ||
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS})); | ||
gathernd(xs..., settings); | ||
}); | ||
} | ||
} | ||
} // namespace migraphx | ||
)__migraphx__"; | ||
|
||
struct gathernd_compiler : compiler<gathernd_compiler> | ||
{ | ||
std::vector<std::string> names() const { return {"gathernd"}; } | ||
|
||
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const | ||
{ | ||
hip_compile_options options; | ||
auto out_s = inputs.back(); | ||
options.set_launch_params(v, compute_global_for(ctx, out_s.elements())); | ||
options.inputs = inputs; | ||
options.output = out_s; | ||
options.kernel_name = "gathernd_kernel"; | ||
options.virtual_inputs = inputs; | ||
|
||
// batch_dims | ||
assert(v.contains("batch_dims")); | ||
auto batch_dims = v.at("batch_dims").to<int64_t>(); | ||
options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims); | ||
|
||
return compile_hip_code_object(gathernd_kernel, options); | ||
} | ||
|
||
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const | ||
{ | ||
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); | ||
} | ||
}; | ||
|
||
} // namespace gpu | ||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP | ||
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP | ||
|
||
#include <migraphx/kernels/index.hpp> | ||
#include <migraphx/kernels/algorithm.hpp> | ||
|
||
namespace migraphx { | ||
|
||
template <class T> | ||
struct gathernd_settings | ||
{ | ||
T batch_dims{}; | ||
}; | ||
|
||
template <class... Ts> | ||
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs) | ||
{ | ||
return {xs...}; | ||
} | ||
|
||
template <class T, class U, class V, class Settings> | ||
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s) | ||
{ | ||
auto ind = make_index(); | ||
auto batch_dims = s.batch_dims; | ||
auto output_shape = output_t.get_shape(); | ||
auto indices_shape = indices_t.get_shape(); | ||
auto data_shape = data_t.get_shape(); | ||
|
||
auto indices_shape_lens = indices_shape.lens; | ||
auto data_shape_lens = data_shape.lens; | ||
auto num_slice_dims = indices_shape_lens.back(); | ||
std::size_t num_slices = accumulate(indices_shape_lens.begin(), | ||
indices_shape_lens.end() - 1, | ||
1, | ||
std::multiplies<std::size_t>()); | ||
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims, | ||
data_shape_lens.end(), | ||
1, | ||
std::multiplies<std::size_t>()); | ||
const std::size_t num_batches = accumulate(data_shape_lens.begin(), | ||
data_shape_lens.begin() + batch_dims, | ||
1, | ||
std::multiplies<std::size_t>()); | ||
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims, | ||
data_shape_lens.end(), | ||
1, | ||
std::multiplies<std::size_t>()); | ||
const auto num_slices_per_batch = num_slices / num_batches; | ||
|
||
ind.global_stride(output_shape.elements(), [&](auto i) { | ||
const auto* indices_ptr = indices_t.data(); | ||
const std::size_t j = i / slice_size; | ||
const std::size_t batch_idx = j / num_slices_per_batch; | ||
|
||
auto* slice_indices = indices_ptr + (j * num_slice_dims); | ||
std::size_t relative_slice_offset = 0; | ||
for(std::size_t idx = 0; idx < num_slice_dims; ++idx) | ||
{ | ||
int64_t index = slice_indices[idx]; | ||
const std::size_t input_dim_idx = batch_dims + idx; | ||
const auto input_dim = data_shape_lens[input_dim_idx]; | ||
assert(index >= -static_cast<int64_t>(input_dim) and | ||
index < static_cast<int64_t>(input_dim)); | ||
if(index < 0) | ||
index += input_dim; | ||
std::size_t size_from_slice_dims = | ||
accumulate(data_shape_lens.begin() + batch_dims + idx + 1, | ||
data_shape_lens.begin() + batch_dims + num_slice_dims, | ||
slice_size, | ||
std::multiplies<std::size_t>()); | ||
relative_slice_offset += index * size_from_slice_dims; | ||
} | ||
|
||
auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset; | ||
output_t[i] = data_t[slice_offset + i % slice_size]; | ||
}); | ||
} | ||
|
||
} // namespace migraphx | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
gathernd_batch_dims_test:� | ||
/ | ||
data | ||
indicesy"GatherND* | ||
|
||
batch_dims�gathernd_batch_dims_testZ | ||
data | ||
Z | ||
indices | ||
b | ||
y | ||
B |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
gathernd_test:q | ||
data | ||
indicesy"GatherNDgathernd_testZ | ||
data | ||
Z | ||
indices | ||
b | ||
y | ||
|
||
B | ||
|
Oops, something went wrong.