Skip to content

Commit

Permalink
Forego version adaption of inlined models if no nodes are from the de…
Browse files Browse the repository at this point in the history
…fault domain (#105)
  • Loading branch information
cbourjau authored Oct 5, 2023
1 parent 83bd42e commit 4c5ac05
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/spox/_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def adapt_inline(
source_version = max({v for d, v in node.opset_req if d in ("", "ai.onnx")})
target_version = target_opsets[""]

# convert_version fails if the inlined model does not import the default domain
seen_domains = {prot.domain for prot in protos}
if not seen_domains & {"", "ai.onnx"}:
return protos
if source_version != target_version:
target_model = onnx.version_converter.convert_version(
node.model, target_version
Expand Down
84 changes: 84 additions & 0 deletions tests/test_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,87 @@ def test_adapt_node_with_repeating_input_names():
c = op19.identity(a)

build({"a": a}, {"b": b, "c": c})


def test_inline_model_custom_node_only():
"""Inline a model which only consists of a custom node.
Such models do not import from the default domain.
"""
domain = "foo.ai"
node = onnx.helper.make_node("FooOp", ["a"], ["b"], domain=domain)
value_infos_input = [
onnx.helper.make_value_info(
"a", onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ("N",))
),
]
value_infos_output = [
onnx.helper.make_value_info(
"b", onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ("N",))
)
]

model = onnx.helper.make_model(
onnx.helper.make_graph(
[node],
"graph",
value_infos_input,
value_infos_output,
),
opset_imports=[onnx.helper.make_opsetid(domain, 1)],
)

# Ensure that our model is valid
onnx.checker.check_model(model, full_check=True)

(a,) = arguments(data=Tensor(numpy.str_, ("N",)))
(b,) = inline(model)(a).values()

# Add another node to the model to trigger the adaption logic
c = op18.identity(b)
build({"a": a}, {"c": c})


@pytest.mark.skip(
reason="Adapting custom nodes (including their subgraphs) is currently not supported"
)
def test_inline_model_custom_node_nested(old_squeeze: onnx.ModelProto):
"""A singleton custom node with a old standard node in its attribute."""
domain = "foo.ai"

node = onnx.helper.make_node(
"FooOp", ["a"], ["b"], domain=domain, **{"nested_graph": old_squeeze.graph}
)
value_infos_input = [
onnx.helper.make_value_info(
"a", onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, ("N",))
),
]
value_infos_output = [
onnx.helper.make_value_info(
"b", onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, ("N",))
)
]

model = onnx.helper.make_model(
onnx.helper.make_graph(
[node],
"graph",
value_infos_input,
value_infos_output,
),
opset_imports=[
onnx.helper.make_opsetid(domain, 1),
onnx.helper.make_opsetid("", 12),
],
)

# Ensure that our model is valid
onnx.checker.check_model(model, full_check=True)

(a,) = arguments(data=Tensor(numpy.float32, ("N",)))
(b,) = inline(model)(a).values()

# Add another node to the model to trigger the adaption logic
c = op18.identity(b)
build({"a": a}, {"c": c})

0 comments on commit 4c5ac05

Please sign in to comment.