From 9488f6a137277c4a5fed7fec3cacde0a046c80ab Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Tue, 19 Nov 2024 20:34:46 +0200 Subject: [PATCH] Fix adapt node --- src/spox/_adapt.py | 21 ++++++++++++++------- tests/test_adapt.py | 13 +++++++++++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/spox/_adapt.py b/src/spox/_adapt.py index 8c95b2d..ff84817 100644 --- a/src/spox/_adapt.py +++ b/src/spox/_adapt.py @@ -9,10 +9,11 @@ from ._attributes import AttrGraph from ._inline import _Inline -from ._internal_op import _InternalNode +from ._internal_op import _Initializer, _InternalNode from ._node import Node from ._schemas import SCHEMAS from ._scope import Scope +from ._utils import from_array from ._var import VarInfo @@ -30,16 +31,21 @@ def adapt_node( # By using a dictionary we ensure that we only have a single # ValueInfo per (possibly repeated) input name. input_info = { - var_names[var]: var.unwrap_type()._to_onnx_value_info( - var_names[var], _traceback_name=f"adapt-input {key}" + var_names[var_info]: var_info.unwrap_type()._to_onnx_value_info( + var_names[var_info], _traceback_name=f"adapt-input {key}" ) - for key, var in node.inputs.get_var_infos().items() + for key, var_info in node.inputs.get_var_infos().items() } output_info = [ - var.unwrap_type()._to_onnx_value_info( - var_names[var], _traceback_name=f"adapt-output {key}" + var_info.unwrap_type()._to_onnx_value_info( + var_names[var_info], _traceback_name=f"adapt-output {key}" ) - for key, var in node.outputs.get_var_infos().items() + for key, var_info in node.outputs.get_var_infos().items() + ] + initializers = [ + from_array(var_info._op.attrs.get_fields()["value"].value, name) # type: ignore + for name, var_info in node.inputs.get_var_infos().items() + if isinstance(var_info._op, _Initializer) ] except ValueError: return None @@ -50,6 +56,7 @@ def adapt_node( "spox__singleton_adapter_graph", list(input_info.values()), output_info, + initializers, ), opset_imports=[onnx.helper.make_operatorsetid("", source_version)], ) diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 25c4764..a90e077 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -15,6 +15,7 @@ from spox import Tensor, Var, argument, build, inline from spox._attributes import AttrInt64s from spox._fields import BaseAttributes, BaseInputs, BaseOutputs +from spox._future import initializer from spox._graph import arguments, results from spox._node import OpType from spox._standard import StandardNode @@ -163,6 +164,18 @@ def test_adapt_node_with_repeating_input_names(): build({"a": a}, {"b": b, "c": c}) +def test_adapt_node_initializer(): + init_data = [1.0, 2.0, 3.0] + + a = argument(Tensor(np.float32, ("N",))) + b = initializer(init_data, np.float32) + c = op18.equal(a, b) + d = op19.identity(a) + + model = build({"a": a}, {"b": b, "c": c, "d": d}) + np.testing.assert_allclose(model.graph.initializer[0].float_data, init_data) + + def test_inline_model_custom_node_only(): """Inline a model which only consists of a custom node.