Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce export feature to TensorRT JSON format #3721

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ add_library(migraphx
insert_pad.cpp
instruction.cpp
json.cpp
trt_json.cpp
layout_convolution.cpp
lexing.cpp
load_save.cpp
Expand Down
40 changes: 40 additions & 0 deletions src/include/migraphx/trt_json.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_TRT_JSON_HPP
#define MIGRAPHX_GUARD_RTGLIB_TRT_JSON_HPP

#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <cstddef>
#include <optional>
#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Serializes the main module of a program as a TensorRT-style JSON document
/// made of a "Layers" array and a "Bindings" array.
///
/// @param p      Program to export; only its main module is walked.
/// @param indent Indentation width handed to the JSON dump for pretty-printing;
///               an empty optional selects the default (compact) dump.
/// @return The JSON document as a string.
MIGRAPHX_EXPORT std::string to_trt_json_string(const migraphx::program& p,
std::optional<size_t> indent);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
6 changes: 6 additions & 0 deletions src/load_save.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <migraphx/load_save.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp>
#include <migraphx/trt_json.hpp>
#include <migraphx/msgpack.hpp>
#include <fstream>

Expand Down Expand Up @@ -94,6 +95,11 @@
std::string s = to_json_string(v);
buffer = std::vector<char>(s.begin(), s.end());
}
else if(options.format == "trt.json")
{
std::string s = to_trt_json_string(p, 4);
buffer = std::vector<char>(s.begin(), s.end());

Check warning on line 101 in src/load_save.cpp

View check run for this annotation

Codecov / codecov/patch

src/load_save.cpp#L101

Added line #L101 was not covered by tests
}
else
{
MIGRAPHX_THROW("Unknown format: " + options.format);
Expand Down
233 changes: 233 additions & 0 deletions src/trt_json.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <nlohmann/json.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/trt_json.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

using json = nlohmann::json;

// Forward declaration: the nlohmann::adl_serializer<migraphx::value> specialization
// that follows calls this overload before its definition appears later in the file.
static void value_to_json(const value& val, json& j);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

namespace nlohmann {
template <>
struct adl_serializer<migraphx::value>
{
static void to_json(json& j, const migraphx::value& val) { migraphx::value_to_json(val, j); }

Check warning on line 47 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L47

Added line #L47 was not covered by tests
};
} // namespace nlohmann

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

using json = nlohmann::json;

// Generic fallback for any held type that has no dedicated overload:
// the value is assigned to the output node unchanged.
template <class T>
static void value_to_json(const T& plain, json& node)
{
    node = plain;
}

Check warning on line 60 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L59-L60

Added lines #L59 - L60 were not covered by tests

static void value_to_json(const value::binary& x, json& j)

Check warning on line 62 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L62

Added line #L62 was not covered by tests
{
j = json::object();
j["bytes"] = std::vector<int>(x.begin(), x.end());
}

Check warning on line 66 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L64-L66

Added lines #L64 - L66 were not covered by tests

static void value_to_json(const std::vector<value>& x, json& j)

Check warning on line 68 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L68

Added line #L68 was not covered by tests
{
for(const auto& v : x)

Check warning on line 70 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L70

Added line #L70 was not covered by tests
{
if(v.get_key().empty())

Check warning on line 72 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L72

Added line #L72 was not covered by tests
{
j.push_back(v);

Check warning on line 74 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L74

Added line #L74 was not covered by tests
}
else
{
j[v.get_key()] = v.without_key();

Check warning on line 78 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L78

Added line #L78 was not covered by tests
}
}
}

Check warning on line 81 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L81

Added line #L81 was not covered by tests

// Null alternative: assigns empty braces to j, i.e. a default-constructed json
// (a JSON null in nlohmann::json).
// The first parameter is a non-const lvalue reference; the visitor in
// value_to_json(const value&, json&) passes its by-value copy, which is an lvalue.
// NOTE(review): presumably chosen so this overload wins over the generic template - confirm.
static void value_to_json(std::nullptr_t&, json& j) { j = {}; }

Check warning on line 83 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L83

Added line #L83 was not covered by tests

static void value_to_json(const value& val, json& j)

Check warning on line 85 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L85

Added line #L85 was not covered by tests
{
if(val.is_array())

Check warning on line 87 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L87

Added line #L87 was not covered by tests
{
j = json::array();

Check warning on line 89 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L89

Added line #L89 was not covered by tests
}

if(val.is_object())

Check warning on line 92 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L92

Added line #L92 was not covered by tests
{
j = json::object();

Check warning on line 94 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L94

Added line #L94 was not covered by tests
}

val.visit([&](auto v) { value_to_json(v, j); });
}

Check warning on line 98 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L97-L98

Added lines #L97 - L98 were not covered by tests

std::string type_to_trt_type_string(shape::type_t type)

Check warning on line 100 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L100

Added line #L100 was not covered by tests
{
switch(type)

Check warning on line 102 in src/trt_json.cpp

View workflow job for this annotation

GitHub Actions / tidy

4 enumeration values not explicitly handled in switch: 'fp8e4m3fnuz_type', 'fp8e4m3fn_type', 'fp8e5m2_type'... [clang-diagnostic-switch-enum,-warnings-as-errors]

Check warning on line 102 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L102

Added line #L102 was not covered by tests
{
case shape::bool_type: return "BOOL";
case shape::half_type: return "FP16";
case shape::float_type: return "FP32";
case shape::double_type: return "FP64";
case shape::uint8_type: return "UINT8";
case shape::int8_type: return "INT8";
case shape::uint16_type: return "UINT16";
case shape::int16_type: return "INT16";
case shape::uint32_type: return "UINT32";
case shape::int32_type: return "INT32";
case shape::uint64_type: return "UINT64";
case shape::int64_type: return "INT64";
case shape::bf16_type: return "BF16";
case shape::tuple_type: return "TUPLE";

Check warning on line 117 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L104-L117

Added lines #L104 - L117 were not covered by tests
// TODO fp8 types
default: MIGRAPHX_THROW("Unsupported type: " + shape::name(type));

Check warning on line 119 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L119

Added line #L119 was not covered by tests
}
}

std::string to_trt_json_string(const program& p, std::optional<size_t> indent = std::nullopt)

Check warning on line 123 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L123

Added line #L123 was not covered by tests
{
std::unordered_map<instruction_ref, std::string> ins_to_name;
std::unordered_map<std::string, unsigned int> name_count;

auto* mm = p.get_main_module();

Check warning on line 128 in src/trt_json.cpp

View workflow job for this annotation

GitHub Actions / tidy

'auto *mm' can be declared as 'const auto *mm' [readability-qualified-auto,-warnings-as-errors]

Check warning on line 128 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L128

Added line #L128 was not covered by tests

// Parameters are added as Bindings
const auto& param_names = mm->get_parameter_names();
for(const auto& param_name : param_names)
ins_to_name[mm->get_parameter(param_name)] = param_name;

Check warning on line 133 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L131-L133

Added lines #L131 - L133 were not covered by tests

json j = json::object();
auto& jlayers = j["Layers"] = json::array();

Check warning on line 136 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L135-L136

Added lines #L135 - L136 were not covered by tests

for(const auto ins : iterator_for(*mm))

Check warning on line 138 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L138

Added line #L138 was not covered by tests
{
// Skip these instructions to avoid clutter
static const std::vector<std::string> skip_instructions{
"@param", "check_context::migraphx::gpu::context", "hip::hip_allocate_memory", "load"};
if(contains(skip_instructions, ins->name()))
continue;

Check warning on line 144 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L142-L144

Added lines #L142 - L144 were not covered by tests

std::string ins_name;
if(ins->name() == "gpu::code_object")

Check warning on line 147 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L147

Added line #L147 was not covered by tests
{
ins_name =
ins->get_operator().to_value()["symbol_name"].without_key().to<std::string>();

Check warning on line 150 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L150

Added line #L150 was not covered by tests
}
// Differentiate literal broadcast from other broadcasts
else if(ins->name() == "broadcast" and ins->inputs().size() == 1 and
ins->inputs().front()->name() == "hip::hip_copy_literal")

Check warning on line 154 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L153-L154

Added lines #L153 - L154 were not covered by tests
{
ins_name = "broadcast_literal";
}
else
{
ins_name = ins->name();

Check warning on line 160 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L160

Added line #L160 was not covered by tests
}
auto count = name_count[ins_name]++;
ins_name = ins_name + "_" + std::to_string(count);

Check warning on line 163 in src/trt_json.cpp

View workflow job for this annotation

GitHub Actions / tidy

string concatenation results in allocation of unnecessary temporary strings; consider using 'operator+=' or 'string::append()' instead [performance-inefficient-string-concatenation,-warnings-as-errors]

Check warning on line 163 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L162-L163

Added lines #L162 - L163 were not covered by tests
ins_to_name[ins] = ins_name;

// We don't want to show copy literal layers to avoid clutter
if(ins->name() == "hip::hip_copy_literal")

Check warning on line 167 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L167

Added line #L167 was not covered by tests
continue;

auto jlayer = json::object();
jlayer["Name"] = ins_to_name.at(ins);
jlayer["LayerType"] = ins->name();

Check warning on line 172 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L170-L172

Added lines #L170 - L172 were not covered by tests

auto& jlayer_inputs = jlayer["Inputs"] = json::array();
for(auto input : ins->inputs())

Check warning on line 175 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L174-L175

Added lines #L174 - L175 were not covered by tests
{
if(input->name() == "load")
continue;
if(input->name() == "hip::hip_copy_literal")

Check warning on line 179 in src/trt_json.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Empty if statement. [migraphx-EmptyIfStatement]

Check warning on line 179 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L177-L179

Added lines #L177 - L179 were not covered by tests
{
// TODO add this information to the layer params
}
auto jinput = json::object();
auto name_suffix = input->name() == "@param" ? "" : "_out";

Check warning on line 184 in src/trt_json.cpp

View workflow job for this annotation

GitHub Actions / tidy

'auto name_suffix' can be declared as 'const auto *name_suffix' [readability-qualified-auto,-warnings-as-errors]
jinput["Name"] = ins_to_name.at(input) + name_suffix;

Check warning on line 185 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L183-L185

Added lines #L183 - L185 were not covered by tests
// TODO treat dynamic dims differently
jinput["Dimensions"] = input->get_shape().lens();
jinput["Format/Datatype"] = type_to_trt_type_string(input->get_shape().type());
jlayer_inputs.push_back(jinput);

Check warning on line 189 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L187-L189

Added lines #L187 - L189 were not covered by tests
}

auto& jlayer_outputs = jlayer["Outputs"] = json::array();
auto joutput = json::object();
joutput["Name"] = ins_name + "_out";
joutput["Dimensions"] = ins->get_shape().lens();
joutput["Format/Datatype"] = type_to_trt_type_string(ins->get_shape().type());
jlayer_outputs.push_back(joutput);

Check warning on line 197 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L192-L197

Added lines #L192 - L197 were not covered by tests

auto val = ins->get_operator().to_value();

Check warning on line 199 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L199

Added line #L199 was not covered by tests
static const std::vector<std::string> skip_keys{"code_object",
"shape",
"expected_inputs",
"output",
"symbol_name",
"literal",
"bytes",
"data",
"solution_object"};
for(const auto& v : val)

Check warning on line 209 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L208-L209

Added lines #L208 - L209 were not covered by tests
{
if(not contains(skip_keys, v.get_key()))

Check warning on line 211 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L211

Added line #L211 was not covered by tests
{
jlayer[v.get_key()] = v.without_key();

Check warning on line 213 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L213

Added line #L213 was not covered by tests
}
}

jlayers.push_back(jlayer);

Check warning on line 217 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L217

Added line #L217 was not covered by tests
}

// Bindings indicate inputs and outputs
auto& jbindings = j["Bindings"] = json::array();
for(const auto& param_name : param_names)

Check warning on line 222 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L221-L222

Added lines #L221 - L222 were not covered by tests
{
jbindings.push_back(param_name);

Check warning on line 224 in src/trt_json.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Consider using std::copy algorithm instead of a raw loop. [useStlAlgorithm]

Check warning on line 224 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L224

Added line #L224 was not covered by tests
}
// Bind return as output
jbindings.push_back(jlayers.back()["Outputs"][0]["Name"]);

Check warning on line 227 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L227

Added line #L227 was not covered by tests

return indent ? j.dump(*indent) : j.dump();

Check warning on line 229 in src/trt_json.cpp

View check run for this annotation

Codecov / codecov/patch

src/trt_json.cpp#L229

Added line #L229 was not covered by tests
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Loading