[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape
vinayakdsci committed Sep 21, 2024
1 parent c45ebb0 · commit ffa6f2e
Showing 2 changed files with 76 additions and 27 deletions.
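The core of the change is that the splat value of a ConstantOfShape `value` attribute may be serialized in the tensor's `raw_data` bytes rather than in the typed repeated fields (`float_data`, `int64_data`, ...), and the importer previously only read the typed fields. Below is a minimal sketch of that idea using only the plain onnx protobuf accessors; `ReadFloatSplat` is an illustrative helper, not this repository's code, which instead routes each case through `graph_info_.GetOptionalRawData<T>()` as the diff shows.

// Sketch only, assuming the plain onnx protobuf API: the splat can live either
// in the typed repeated field or as little-endian bytes in raw_data, so the
// importer has to check raw_data first.
#include <cstring>
#include "onnx/onnx_pb.h"

float ReadFloatSplat(const onnx::TensorProto &tp) {
  if (tp.has_raw_data()) {
    float v = 0.0f;
    // The value attribute of ConstantOfShape is a one-element tensor; when it
    // is serialized as raw_data the bytes must be reinterpreted explicitly.
    std::memcpy(&v, tp.raw_data().data(), sizeof(float));
    return v;
  }
  // Otherwise the element sits in the typed repeated field.
  return tp.float_data(0);
}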
@@ -638,8 +638,7 @@ Status NodeImporter::ImportAll() {

for (const auto &node : nodes) {
if (torch_mlir_onnx::failed(ImportNode(node))) {
return SetError("Failed to import node '" + node.name() +
"': " + "(node:\n" + node.DebugString() + "\n)");
return failure();
}
}

@@ -728,7 +727,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX node input '";
msg.append(input_name);
msg.append("'");
msg.append("': ");
msg.append(node.DebugString());
return SetError(std::move(msg));
}
input_values.push_back(found_it->second);
@@ -739,8 +739,9 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
for (auto &output_name : node.output()) {
const onnx::TypeProto *type_proto =
graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
if (!type_proto)
return failure();
if (!type_proto) {
return SetError("Failed to obtain TypeProto for tensor");
}

MlirType t = cc_.ConvertTypeProto(*type_proto);
if (mlirTypeIsNull(t))
@@ -906,38 +907,77 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
/*encoding*/ {nullptr});
};
const bool has_raw_data = tensor_proto.has_raw_data();
MlirAttribute splat_attr = {nullptr};
size_t out_size;
switch (tensor_proto.data_type()) {
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
const float *data = {0};
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
ORT_ENFORCE(data);
}
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
tensorTypeFor(mlirF32TypeGet(context_)),
has_raw_data ? data[0] : tensor_proto.float_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
const int32_t *data = {0};
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
}
splat_attr = mlirDenseElementsAttrInt32SplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
tensor_proto.int32_data(0));
has_raw_data ? data[0] : tensor_proto.int32_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
const int64_t *data = {0};
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
}
splat_attr = mlirDenseElementsAttrInt64SplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
tensor_proto.int64_data(0));
has_raw_data ? data[0] : tensor_proto.int64_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
}
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
const double *data = {0};
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
ORT_ENFORCE(data);
}
splat_attr = mlirDenseElementsAttrDoubleSplatGet(
tensorTypeFor(mlirF64TypeGet(context_)),
has_raw_data ? data[0] : tensor_proto.double_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
const uint64_t *data = {0};
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
}
splat_attr = mlirDenseElementsAttrUInt64SplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
tensor_proto.uint64_data(0));
has_raw_data ? data[0] : tensor_proto.uint64_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
// Special case: inline data is stored in uint64.
splat_attr = mlirDenseElementsAttrFloatSplatGet(
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
const uint32_t *data = {0};
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
ORT_ENFORCE(data);
}
splat_attr = mlirDenseElementsAttrUInt32SplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
tensor_proto.uint64_data(0));
has_raw_data ? data[0] : tensor_proto.float_data(0));
break;
}
}

if (mlirAttributeIsNull(splat_attr)) {
std::string message =
@@ -958,8 +998,7 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
toMlirNamedAttribute("value", splat_attr));
MlirValue result = mlirOperationGetResult(op, 0);

// Export to the nv_map.
auto inserted = nv_map_.insert(std::make_pair(name, result));
auto inserted = nv_map_.emplace(node.output(0), result);
if (!inserted.second) {
std::string msg = "Multiple nodes produced a value for '";
msg.append(name);
@@ -973,8 +1012,14 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {

Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
std::vector<int64_t> &shape) {
const onnx::TensorProto &tp =
*graph_info_.graph_viewer().GetConstantInitializer(name, false);
const onnx::TensorProto *tensor =
graph_info_.graph_viewer().GetConstantInitializer(name, false);
if (!tensor) {
return SetError(
"Could not find immediate shape tensor in graph initializers");
}
const onnx::TensorProto &tp = *tensor;

shape.clear();

// Since this is being interpreted as a shape, we only support some limited
@@ -93,6 +93,10 @@ class GraphInfo {
return nullptr;
}

std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
value_info_map() {
return value_info_map_;
}
std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
input_map() {
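For reference, the input case this commit addresses can be reproduced by serializing the ConstantOfShape value tensor through raw_data instead of float_data. A hedged sketch follows; `MakeRawSplatTensor` is an illustrative name, not part of ONNX or this repository.

// Sketch of building a TensorProto whose single element is stored only in
// raw_data; the typed float_data field stays empty, which is the layout the
// importer now has to read.
#include <cstring>
#include <string>
#include "onnx/onnx_pb.h"

onnx::TensorProto MakeRawSplatTensor(float value) {
  onnx::TensorProto t;
  t.set_data_type(onnx::TensorProto::FLOAT);
  t.add_dims(1);  // ConstantOfShape expects a one-element value tensor.
  std::string bytes(sizeof(float), '\0');
  std::memcpy(&bytes[0], &value, sizeof(float));
  t.set_raw_data(bytes);  // Element lives in raw_data only.
  return t;
}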
