Skip to content

Commit

Permalink
Hacky fix for dictionary output with tf 2.14 (#933)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw authored Sep 20, 2023
1 parent ff60e34 commit aa270a2
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions keras_core/ops/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
from keras_core.backend.config import backend
from keras_core.ops.operation import Operation
from keras_core.utils.nest import pack_sequence_as

Expand Down Expand Up @@ -46,10 +47,21 @@ class Function(Operation):
def __init__(self, inputs, outputs, name=None):
super().__init__(name=name)

if backend() == "tensorflow":
# Temporary work around for
# https://github.com/keras-team/keras-core/issues/931
# This stop tensorflow from wrapping tf.function output in a
# _DictWrapper object.
_self_setattr_tracking = getattr(
self, "_self_setattr_tracking", True
)
self._self_setattr_tracking = False
self._inputs_struct = tree.map_structure(lambda x: x, inputs)
self._outputs_struct = tree.map_structure(lambda x: x, outputs)
self._inputs = tree.flatten(inputs)
self._outputs = tree.flatten(outputs)
if backend() == "tensorflow":
self._self_setattr_tracking = _self_setattr_tracking

(nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
self._inputs, self._outputs
Expand Down

0 comments on commit aa270a2

Please sign in to comment.