Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add contrib groupnorm #3678

Draft
wants to merge 7 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions src/onnx/parse_groupnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,36 +32,72 @@

struct parse_groupnorm : op_parser<parse_groupnorm>
{
std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }
std::vector<op_desc> operators() const
{
return {{"GroupNormalization", "GroupNormalization"}, {"GroupNorm", "Contrib_GroupNorm"}};
}

instruction_ref parse(const op_desc& /*opd*/,
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
bool is_contrib = (!opd.op_name.compare("Contrib_GroupNorm"));

Check warning on line 45 in src/onnx/parse_groupnorm.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Use 'not' instead of ! [UseNamedLogicOperator]

Check warning on line 45 in src/onnx/parse_groupnorm.cpp

View workflow job for this annotation

GitHub Actions / tidy

implicit conversion 'int' -> 'bool' [readability-implicit-bool-conversion,-warnings-as-errors]

Check warning on line 45 in src/onnx/parse_groupnorm.cpp

View workflow job for this annotation

GitHub Actions / tidy

do not use 'compare' to test equality of strings; use the string equality operator instead [readability-string-compare,-warnings-as-errors]

float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
size_t num_groups;
if(contains(info.attributes, "num_groups"))
if(contains(info.attributes, "num_groups") or contains(info.attributes, "groups"))
{
num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
if (is_contrib)

Check warning on line 55 in src/onnx/parse_groupnorm.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: inconclusive: Found duplicate branches for 'if' and 'else'. [duplicateBranch]
{
num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
}
else
num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
}
else
{
MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
}

bool is_nhwc = false;
if(is_contrib)
{ // default state for GroupNorm Contrib op
is_nhwc = true;
if(contains(info.attributes, "channels_last") and is_contrib)

Check warning on line 71 in src/onnx/parse_groupnorm.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Condition 'is_contrib' is always true [knownConditionTrueFalse]

Check warning on line 71 in src/onnx/parse_groupnorm.cpp

View workflow job for this annotation

GitHub Actions / tidy

redundant condition 'is_contrib' [bugprone-redundant-branch-condition,-warnings-as-errors]
{
is_nhwc = parser.parse_value(info.attributes.at("channels_last")).at<size_t>();

Check warning on line 73 in src/onnx/parse_groupnorm.cpp

View workflow job for this annotation

GitHub Actions / tidy

implicit conversion 'unsigned long' -> 'bool' [readability-implicit-bool-conversion,-warnings-as-errors]
}
}

bool silu_activation = false;
if(contains(info.attributes, "activation") and is_contrib)
{
silu_activation = (1 == parser.parse_value(info.attributes.at("activation")).at<size_t>());
}
else if(is_contrib)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: activation must be available");
}

if(args.size() != 3)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
}

auto x = args.at(0);
auto scale = args.at(1);
auto bias = args.at(2);
// Adjust chanels from NHWC-> NCHW if last channel is set for contrib op
auto x = args.at(0);
if(is_nhwc and is_contrib)
{
x = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), x);
}

auto scale = args.at(1); //gamma in the GroupNorm contrib case
auto bias = args.at(2); //beta in the GroupNorm contrib case

auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
Expand Down Expand Up @@ -120,7 +156,20 @@
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = info.add_instruction(make_op("add"), scaled, bias_bcast);
return info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);
auto output = info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);

if(silu_activation)
{
// SiLU activation is just out = x * sigmoid(x)
auto sigmoid = info.add_instruction(make_op("sigmoid"), output);
output = info.add_instruction(make_op("mul"), output, sigmoid);
}
// Convert to NCHW -> NHWC for contrib GroupNorm
if(is_nhwc and is_contrib)
{
output = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), output);
}
return output;
}
};

Expand Down
72 changes: 72 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4742,6 +4742,78 @@ def group_norm_invalid_bias_shape_test():
return group_norm_test([1, 4, 3, 3], [2], [3], [1, 4, 3, 3], 2)


def group_norm_contrib_test(x_dims,
gamma_dims,
beta_dims,
y_dims,
num_groups,
activation=0,
channels_last=0,
eps_value=1e-5,
dtype=TensorProto.FLOAT):
x = helper.make_tensor_value_info('x', dtype, x_dims)
gamma = helper.make_tensor_value_info('gamma', dtype, gamma_dims)
beta = helper.make_tensor_value_info('beta', dtype, beta_dims)
y = helper.make_tensor_value_info('y', dtype, y_dims)

node = onnx.helper.make_node('GroupNorm',
inputs=['x', 'gamma', 'beta'],
outputs=['y'],
activation=activation,
channels_last=channels_last,
num_groups=num_groups,
epsilon=eps_value)

return ([node], [x, gamma, beta], [y])


@onnx_test()
def group_norm_contrib_3d_test():
return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 0, 0)


@onnx_test()
def group_norm_contrib_silu_3d_test():
return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 1, 0)


@onnx_test()
def group_norm_contrib_channels_last_3d_test():
return group_norm_contrib_test([1, 4, 2], [2], [2], [1, 4, 2], 2, 0, 1)


@onnx_test()
def group_norm_contrib_no_activation_attr_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 2])
gamma = helper.make_tensor_value_info('gamma', TensorProto.FLOAT, [2])
beta = helper.make_tensor_value_info('beta', TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4, 2])

node = onnx.helper.make_node('GroupNorm',
inputs=['x', 'gamma', 'Beta'],
outputs=['y'],
channels_last=0,
num_groups=2)

return ([node], [x, gamma, beta], [y])


@onnx_test()
def group_norm_contrib_no_num_groups_attr_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4, 2])
gamma = helper.make_tensor_value_info('gamma', TensorProto.FLOAT, [2])
beta = helper.make_tensor_value_info('beta', TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4, 2])

node = onnx.helper.make_node('GroupNorm',
inputs=['x', 'gamma', 'Beta'],
outputs=['y'],
activation=0,
channels_last=0)

return ([node], [x, gamma, beta], [y])


@onnx_test()
def group_query_attention_test():
qkv = helper.make_tensor_value_info('qkv', TensorProto.FLOAT16,
Expand Down
Binary file added test/onnx/group_norm_contrib_3d_test.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added test/onnx/group_norm_contrib_silu_3d_test.onnx
Binary file not shown.
8 changes: 5 additions & 3 deletions test/onnx/include/onnx_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,16 @@
const std::vector<int64_t>& reshape_dims,
const std::vector<int64_t>& reduce_axes,
const float eps_value = 1e-5f,
const migraphx::shape::type_t dtype = migraphx::shape::float_type)
const migraphx::shape::type_t dtype = migraphx::shape::float_type,
const std::string param1_name = "scale",

Check warning on line 124 in test/onnx/include/onnx_test_utils.hpp

View workflow job for this annotation

GitHub Actions / tidy

the const qualified parameter 'param1_name' is copied for each invocation; consider making it a reference [performance-unnecessary-value-param,-warnings-as-errors]

Check warning on line 124 in test/onnx/include/onnx_test_utils.hpp

View workflow job for this annotation

GitHub Actions / tidy

the const qualified parameter 'param1_name' is copied for each invocation; consider making it a reference [performance-unnecessary-value-param,-warnings-as-errors]

Check warning on line 124 in test/onnx/include/onnx_test_utils.hpp

View workflow job for this annotation

GitHub Actions / tidy

the const qualified parameter 'param1_name' is copied for each invocation; consider making it a reference [performance-unnecessary-value-param,-warnings-as-errors]
const std::string param2_name = "bias")

Check warning on line 125 in test/onnx/include/onnx_test_utils.hpp

View workflow job for this annotation

GitHub Actions / tidy

the const qualified parameter 'param2_name' is copied for each invocation; consider making it a reference [performance-unnecessary-value-param,-warnings-as-errors]

Check warning on line 125 in test/onnx/include/onnx_test_utils.hpp

View workflow job for this annotation

GitHub Actions / tidy

the const qualified parameter 'param2_name' is copied for each invocation; consider making it a reference [performance-unnecessary-value-param,-warnings-as-errors]

Check warning on line 125 in test/onnx/include/onnx_test_utils.hpp

View workflow job for this annotation

GitHub Actions / tidy

the const qualified parameter 'param2_name' is copied for each invocation; consider making it a reference [performance-unnecessary-value-param,-warnings-as-errors]
{
migraphx::program p;
auto* mm = p.get_main_module();

auto x = mm->add_parameter("x", {dtype, input_dims});
auto scale = mm->add_parameter("scale", {dtype, scale_dims});
auto bias = mm->add_parameter("bias", {dtype, bias_dims});
auto scale = mm->add_parameter(param1_name, {dtype, scale_dims});
auto bias = mm->add_parameter(param2_name, {dtype, bias_dims});

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(group_norm_contrib_no_activation_err_test)
{
EXPECT(test::throws([&] { read_onnx("group_norm_contrib_no_activation_attr_test.onnx"); }));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(group_norm_contrib_no_num_groups_err_test)
{
EXPECT(test::throws([&] { read_onnx("group_norm_contrib_no_num_groups_attr_test.onnx"); }));
}
34 changes: 34 additions & 0 deletions test/onnx/parse/group_norm_contrib_3d_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(group_norm_contrib_3d_test)
{
migraphx::program p = make_group_norm(
{1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::float_type, "gamma", "beta");
auto prog = optimize_onnx("group_norm_contrib_3d_test.onnx");
EXPECT(p == prog);
}
35 changes: 35 additions & 0 deletions test/onnx/parse/group_norm_contrib_channels_last_3d_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(group_norm_contrib_channels_last_3d_test)
{
migraphx::program p = make_group_norm(
{1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::float_type, "gamma", "beta");

auto prog = optimize_onnx("group_norm_contrib_channels_last_3d_test.onnx");
EXPECT(p == prog);
}
35 changes: 35 additions & 0 deletions test/onnx/parse/group_norm_contrib_silu_3d_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(group_norm_contrib_silu_3d_test)
{
migraphx::program p = make_group_norm(
{1, 4, 2}, {2}, {2}, {1, 2, 2, 2}, {2, 3}, 1e-5f, migraphx::shape::float_type, "gamma", "beta");

auto prog = optimize_onnx("group_norm_contrib_silu_3d_test.onnx");
EXPECT(p == prog);
}
Loading