diff --git a/keras_core/ops/function.py b/keras_core/ops/function.py index b3f029efc..8662b12f4 100644 --- a/keras_core/ops/function.py +++ b/keras_core/ops/function.py @@ -4,6 +4,7 @@ from keras_core.api_export import keras_core_export from keras_core.backend import KerasTensor +from keras_core.backend.config import backend from keras_core.ops.operation import Operation from keras_core.utils.nest import pack_sequence_as @@ -46,10 +47,21 @@ class Function(Operation): def __init__(self, inputs, outputs, name=None): super().__init__(name=name) + if backend() == "tensorflow": + # Temporary work around for + # https://github.com/keras-team/keras-core/issues/931 + # This stop tensorflow from wrapping tf.function output in a + # _DictWrapper object. + _self_setattr_tracking = getattr( + self, "_self_setattr_tracking", True + ) + self._self_setattr_tracking = False self._inputs_struct = tree.map_structure(lambda x: x, inputs) self._outputs_struct = tree.map_structure(lambda x: x, outputs) self._inputs = tree.flatten(inputs) self._outputs = tree.flatten(outputs) + if backend() == "tensorflow": + self._self_setattr_tracking = _self_setattr_tracking (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph( self._inputs, self._outputs