Skip to content

Commit

Permalink
Only propagate produced values (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Oct 23, 2023
1 parent 44388d5 commit 67dae6c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 15 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ Change log
0.9.3 (Unreleased)
------------------

**Bug fix**
**Bug fixes**

- Address missing Value Infos when building singleton model for shape inference.
- Fix issue where Value Propagation failure prevents model creation/inlining.


0.9.2 (2023-10-20)
Expand Down
1 change: 1 addition & 0 deletions src/spox/_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def propagate_values(self) -> Dict[str, _value_prop.PropValueType]:
return {
f"outputs_{k}": unwrap_feed(var.unwrap_type(), output_feed[o.name]).value
for k, (o, var) in enumerate(zip(self.graph.output, self.outputs.outputs))
if o.name in output_feed
}

def to_onnx(
Expand Down
13 changes: 3 additions & 10 deletions src/spox/_standard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Module implementing a base for standard ONNX operators, which use the functionality of ONNX node-level inference."""
import logging
from typing import TYPE_CHECKING, Callable, Dict, Tuple

import numpy
Expand Down Expand Up @@ -175,15 +174,9 @@ def propagate_values_onnx(self) -> Dict[str, PropValueType]:
for var in self.inputs.get_vars().values()
if var._value
}
try:
output_feed = run(model, input_feed)
except Exception as e:
logging.debug(
f"Value propagation in {self.get_op_repr()} on backend "
f"{_value_prop._VALUE_PROP_BACKEND} failed with - "
f"{type(e).__name__}: {e}"
)
output_feed = {}

output_feed = run(model, input_feed)

results = {
scope.var[str(name)]
._which_output: unwrap_feed(scope.var[str(name)].unwrap_type(), result)
Expand Down
20 changes: 16 additions & 4 deletions src/spox/_value_prop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
import logging
import warnings
from dataclasses import dataclass
from typing import Dict, List, Union
Expand Down Expand Up @@ -155,8 +156,12 @@ def _run_reference_implementation(
try:
session = onnx.reference.ReferenceEvaluator(model)
output_feed = dict(zip(session.output_names, session.run(None, input_feed)))
except NotImplementedError:
except Exception as e:
# Give up on value propagation if an implementation is missing.
logging.debug(
f"Value propagation in {model} on the ONNX reference implementation failed with - "
f"{type(e).__name__}: {e}"
)
return {}
return output_feed

Expand All @@ -169,9 +174,16 @@ def _run_onnxruntime(
# Silence possible warnings during execution (especially constant folding)
options = onnxruntime.SessionOptions()
options.log_severity_level = 3
session = onnxruntime.InferenceSession(model.SerializeToString(), options)
output_names = [output.name for output in session.get_outputs()]
output_feed = dict(zip(output_names, session.run(None, input_feed)))
try:
session = onnxruntime.InferenceSession(model.SerializeToString(), options)
output_names = [output.name for output in session.get_outputs()]
output_feed = dict(zip(output_names, session.run(None, input_feed)))
except Exception as e:
logging.debug(
f"Value propagation in {model} on the onnxruntime failed with - "
f"{type(e).__name__}: {e}"
)
return {}
return output_feed


Expand Down
49 changes: 49 additions & 0 deletions tests/test_value_propagation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy
import onnx
import pytest

import spox
import spox._future
import spox.opset.ai.onnx.ml.v3 as ml
import spox.opset.ai.onnx.v17 as op
from spox import Var, _type_system
Expand All @@ -8,6 +12,16 @@
from spox._value_prop import ORTValue, PropValue


@pytest.fixture(
params=[
spox._future.ValuePropBackend.ONNXRUNTIME,
spox._future.ValuePropBackend.REFERENCE,
]
)
def value_prop_backend(request):
return request.param


def dummy_var(typ=None, value=None):
"""Function for creating a ``var`` without an operator but with a type and value."""
return Var(None, typ, value) # type: ignore
Expand Down Expand Up @@ -158,3 +172,38 @@ def test_propagated_value_does_not_alias_dtype():
x = numpy.iinfo(numpy.int64).max + 1
# Without the explicit astype(uint64), x actually ends up being ulonglong
assert_equal_value(op.const(x), numpy.array(x).astype(numpy.uint64))


def test_value_propagation_does_not_fail_on_unseen_opsets(value_prop_backend):
spox._future.set_value_prop_backend(value_prop_backend)

model_input = [onnx.helper.make_tensor_value_info("X", elem_type=8, shape=("X",))]
model_output = [
onnx.helper.make_tensor_value_info("y", elem_type=8, shape=("y", "max_words"))
]

nodes = [
onnx.helper.make_node(
"RandomNode",
inputs=["X"],
outputs=["y"],
domain="com.hello",
)
]

graph = onnx.helper.make_graph(
nodes,
"RandomNode",
model_input,
model_output,
)

model = onnx.helper.make_model(
graph,
opset_imports=[
onnx.helper.make_opsetid("", 18),
onnx.helper.make_opsetid("com.hello", 1),
],
)

spox.inline(model)(X=op.const(["Test Test"], dtype=numpy.str_))

0 comments on commit 67dae6c

Please sign in to comment.