Skip to content

Commit

Permalink
Optimize transposed concat (#3368)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Aug 21, 2024
1 parent 7ab413f commit 03c43e5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
30 changes: 24 additions & 6 deletions src/targets/gpu/jit/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,39 @@ struct concat_compiler : compiler<concat_compiler>
{
std::vector<std::string> names() const { return {"fused_concat", "concat"}; }

static std::vector<shape> normalize(std::vector<shape> inputs, std::size_t& axis)
{
auto s = inputs.back();
std::vector<std::size_t> strides(s.lens().size());
strides[axis] = 1;

inputs.push_back(shape{s.type(), s.lens(), strides});

auto result = reduce_dims(normalize_permutation(inputs));
auto rstrides = result.back().strides();
auto it = std::find_if(rstrides.begin(), rstrides.end(), [](auto x) { return x == 1; });
axis = it - rstrides.begin();
result.pop_back();
return result;
}

operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.emplace_param("-Wno-float-equal");
auto concat_axis = v.at("axis").to<std::size_t>();
options.virtual_inputs = normalize(inputs, concat_axis);
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto axis = find_fast_axis(options.virtual_inputs);
auto op_names = v.at("ops").to_vector<std::string>();
auto args = v.at("args");
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
auto nelements_per_op = options.inputs.back().elements() / op_names.size();
if(axis != concat_axis)
vec = vectorize::elements(ctx, axis, options.virtual_inputs);
auto nelements_per_op = options.virtual_inputs.back().elements() / op_names.size();
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
options.emplace_param("-Wno-float-equal");
std::vector<std::string> concat_params;
std::vector<std::string> concat_args;
for(auto i : range(op_names.size()))
Expand All @@ -105,7 +123,7 @@ struct concat_compiler : compiler<concat_compiler>
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
{"axis", std::to_string(concat_axis)}});
return compile_hip_code_object(src, options);
}

Expand Down
48 changes: 48 additions & 0 deletions test/verify/test_concat_nhwc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

// Verify test: concat along axis 1 of two parameters that share a shape with
// lens {2, 64, 56, 56} and non-standard strides, parameterized on the element
// type DType.
template <migraphx::shape::type_t DType>
struct test_concat_nhwc : verify_program<test_concat_nhwc<DType>>
{
    // Builds a program whose main module takes parameters "x" and "y" and
    // concatenates them along axis 1.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        // Strides {200704, 1, 3584, 64}: axis 1 has stride 1, i.e. it is the
        // fastest-varying dimension (64 * 56 = 3584, 64 * 56 * 56 = 200704).
        const migraphx::shape nhwc_shape{DType, {2, 64, 56, 56}, {200704, 1, 3584, 64}};

        auto lhs = main_mod->add_parameter("x", nhwc_shape);
        auto rhs = main_mod->add_parameter("y", nhwc_shape);

        const auto concat_op = migraphx::make_op("concat", {{"axis", 1}});
        main_mod->add_instruction(concat_op, lhs, rhs);
        return prog;
    }
};

// Explicitly instantiate the test for each element type to be covered:
// fp8 (e4m3fnuz), half, float and int32.
template struct test_concat_nhwc<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_concat_nhwc<migraphx::shape::half_type>;
template struct test_concat_nhwc<migraphx::shape::float_type>;
template struct test_concat_nhwc<migraphx::shape::int32_type>;

0 comments on commit 03c43e5

Please sign in to comment.