From cc35274e6dd8878792ee92edf3bbfc388ea5fec4 Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Wed, 20 Nov 2024 18:43:34 +0200 Subject: [PATCH] Remove initializer in node adaption logic --- src/spox/_adapt.py | 8 -------- tests/test_adapt.py | 13 +++++++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/spox/_adapt.py b/src/spox/_adapt.py index 32017dab..bdd3cb28 100644 --- a/src/spox/_adapt.py +++ b/src/spox/_adapt.py @@ -4,7 +4,6 @@ import warnings from typing import Optional -import numpy as np import onnx import onnx.version_converter @@ -14,7 +13,6 @@ from ._node import Node from ._schemas import SCHEMAS from ._scope import Scope -from ._utils import from_array from ._var import Var @@ -43,11 +41,6 @@ def adapt_node( ) for key, var in node.outputs.get_vars().items() ] - initializers = [ - from_array(var._value, name) # type: ignore - for name, var in node.inputs.get_vars().items() - if isinstance(var._value, np.ndarray) - ] except ValueError: return None @@ -57,7 +50,6 @@ def adapt_node( "spox__singleton_adapter_graph", list(input_info.values()), output_info, - initializers, ), opset_imports=[onnx.helper.make_operatorsetid("", source_version)], ) diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 3b0884ad..e552110c 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -15,6 +15,7 @@ from spox import Tensor, Var, argument, build, inline from spox._attributes import AttrInt64s from spox._fields import BaseAttributes, BaseInputs, BaseOutputs +from spox._future import initializer from spox._graph import arguments, results from spox._node import OpType from spox._standard import StandardNode @@ -157,6 +158,18 @@ def test_adapt_node_with_repeating_input_names(): build({"a": a}, {"b": b, "c": c}) +def test_adapt_node_initializer(): + init_data = [1.0, 2.0, 3.0] + + a = argument(Tensor(np.float32, ("N",))) + b = initializer(init_data, np.float32) + c = op18.equal(a, b) + d = op19.identity(a) + + model = build({"a": a}, {"b": b, "c": c, "d": d}) + np.testing.assert_allclose(model.graph.initializer[0].float_data, init_data) + + def test_inline_model_custom_node_only(): """Inline a model which only consists of a custom node.