From ffa6f2e1e0db7eca1e2fa967f53e644c01fde339 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Sat, 21 Sep 2024 11:58:35 +0530 Subject: [PATCH] [IREE EP][Importer] Fix IR import for onnx.ConstantOfShape --- .../torch-mlir-import-onnx/OnnxImporter.cpp | 99 ++++++++++++++----- .../torch-mlir-import-onnx/OnnxImporter.h | 4 + 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp index 0ba67cf33fd4c..b28054b4cf3b4 100644 --- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp +++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp @@ -638,8 +638,7 @@ Status NodeImporter::ImportAll() { for (const auto &node : nodes) { if (torch_mlir_onnx::failed(ImportNode(node))) { - return SetError("Failed to import node '" + node.name() + - "': " + "(node:\n" + node.DebugString() + "\n)"); + return failure(); } } @@ -728,7 +727,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) { if (found_it == nv_map_.end()) { std::string msg = "Non topologically produced ONNX node input '"; msg.append(input_name); - msg.append("'"); + msg.append("': "); + msg.append(node.DebugString()); return SetError(std::move(msg)); } input_values.push_back(found_it->second); @@ -739,8 +739,9 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) { for (auto &output_name : node.output()) { const onnx::TypeProto *type_proto = graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto(); - if (!type_proto) - return failure(); + if (!type_proto) { + return SetError("Failed to obtain TypeProto for tensor"); + } MlirType t = cc_.ConvertTypeProto(*type_proto); if (mlirTypeIsNull(t)) @@ -906,38 +907,77 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { return mlirRankedTensorTypeGet(shape.size(), shape.data(), 
element_type, /*encoding*/ {nullptr}); }; + const bool has_raw_data = tensor_proto.has_raw_data(); MlirAttribute splat_attr = {nullptr}; + size_t out_size; switch (tensor_proto.data_type()) { - case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: + case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: { + const float *data = {0}; + if (has_raw_data) { + data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size); + ORT_ENFORCE(data); + } splat_attr = mlirDenseElementsAttrFloatSplatGet( - tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0)); + tensorTypeFor(mlirF32TypeGet(context_)), + has_raw_data ? data[0] : tensor_proto.float_data(0)); break; - case onnx::TensorProto::DataType::TensorProto_DataType_INT32: - splat_attr = mlirDenseElementsAttrFloatSplatGet( + } + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: { + const int32_t *data = {0}; + if (has_raw_data) { + data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size); + ORT_ENFORCE(data); + } + splat_attr = mlirDenseElementsAttrInt32SplatGet( tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)), - tensor_proto.int32_data(0)); + has_raw_data ? data[0] : tensor_proto.int32_data(0)); break; - case onnx::TensorProto::DataType::TensorProto_DataType_INT64: - splat_attr = mlirDenseElementsAttrFloatSplatGet( + } + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: { + const int64_t *data = {0}; + if (has_raw_data) { + data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size); + ORT_ENFORCE(data); + } + splat_attr = mlirDenseElementsAttrInt64SplatGet( tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)), - tensor_proto.int64_data(0)); + has_raw_data ? 
data[0] : tensor_proto.int64_data(0)); break; - case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: - splat_attr = mlirDenseElementsAttrFloatSplatGet( - tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0)); + } + case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: { + const double *data = {0}; + if (has_raw_data) { + data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size); + ORT_ENFORCE(data); + } + splat_attr = mlirDenseElementsAttrDoubleSplatGet( + tensorTypeFor(mlirF64TypeGet(context_)), + has_raw_data ? data[0] : tensor_proto.double_data(0)); break; - case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: - splat_attr = mlirDenseElementsAttrFloatSplatGet( + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: { + const uint64_t *data = {0}; + if (has_raw_data) { + data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size); + ORT_ENFORCE(data); + } + splat_attr = mlirDenseElementsAttrUInt64SplatGet( tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)), - tensor_proto.uint64_data(0)); + has_raw_data ? data[0] : tensor_proto.uint64_data(0)); break; - case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: - // Special case: inline data is stored in uint64. - splat_attr = mlirDenseElementsAttrFloatSplatGet( + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: { + const uint32_t *data = {0}; + if (has_raw_data) { + data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size); + ORT_ENFORCE(data); + } + splat_attr = mlirDenseElementsAttrUInt32SplatGet( tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)), - tensor_proto.uint64_data(0)); + has_raw_data ? 
data[0] : tensor_proto.uint64_data(0)); break; } } if (mlirAttributeIsNull(splat_attr)) { std::string message = @@ -958,8 +998,7 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { toMlirNamedAttribute("value", splat_attr)); MlirValue result = mlirOperationGetResult(op, 0); - // Export to the nv_map. - auto inserted = nv_map_.insert(std::make_pair(name, result)); + auto inserted = nv_map_.emplace(node.output(0), result); if (!inserted.second) { std::string msg = "Multiple nodes produced a value for '"; msg.append(name); @@ -973,8 +1012,14 @@ Status NodeImporter::GetImmediateShapeTensor(const std::string &name, std::vector<int64_t> &shape) { - const onnx::TensorProto &tp = - *graph_info_.graph_viewer().GetConstantInitializer(name, false); + const onnx::TensorProto *tensor = + graph_info_.graph_viewer().GetConstantInitializer(name, false); + if (!tensor) { + return SetError( + "Could not find immediate shape tensor in graph initializers"); + } + const onnx::TensorProto &tp = *tensor; + shape.clear(); // Since this is being interpreted as a shape, we only support some limited diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h index 733aefbda2583..e916da6b8a7d3 100644 --- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h +++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h @@ -93,6 +93,10 @@ class GraphInfo { return nullptr; } + std::unordered_map<std::string_view, const onnx::ValueInfoProto &> & + value_info_map() { + return value_info_map_; + } std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; } std::unordered_map<std::string_view, const onnx::ValueInfoProto &> & input_map() {