Skip to content

Commit

Permalink
Updated model connections dict to work for multi-output tensors for a layer. (#3704)
Browse files Browse the repository at this point in the history

Signed-off-by: Sayanta Mukherjee <[email protected]>
  • Loading branch information
quic-ssayanta authored Dec 31, 2024
1 parent c8c0e24 commit b2dab1a
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
# =============================================================================

""" Implementation to automatically prepare keras models for AIMET by converting them to a functional model """
import typing

import inspect
import logging
from typing import Any, Dict, List, Set, Union, Optional
Expand Down Expand Up @@ -402,6 +404,21 @@ def _get_call_kwargs(self, layer: tf.keras.layers.Layer) -> Dict[Union[KerasTens
return {}
return call_kwargs

def _update_nested_values(self, values: typing.List | typing.Tuple, old_to_new_tensor_mapper: typing.Dict):
"""
Helper function to update the nested Lists/Tuples based on the dictionary provided
:param values: List/Tuple of values
:param old_to_new_tensor_mapper: Update dictionary
:return: Any
"""
if isinstance(values, typing.Tuple):
return tuple(self._update_nested_values(x, old_to_new_tensor_mapper) for x in values)
if isinstance(values, typing.List):
return [self._update_nested_values(x, old_to_new_tensor_mapper) for x in values]
if isinstance(values, KerasTensor) and values.name in old_to_new_tensor_mapper:
return old_to_new_tensor_mapper[values.name]
return values

def _update_output_tensors_in_model_layers_connections(
self,
layer: tf.keras.layers.Layer,
Expand All @@ -415,8 +432,53 @@ def _update_output_tensors_in_model_layers_connections(
:param new_output_tensor: The new output tensor to update with
:param model: The model currently being checked. Used to add model outputs
"""
# pylint: disable=too-many-nested-blocks
if layer.name != new_output_tensor.name:
# pylint: disable=too-many-nested-blocks, disable=protected-access, too-many-locals, too-many-branches
if isinstance(new_output_tensor, List):
# Handling case where layer has list of output tensors
old_tensors = self.model_layers_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][layer.name]
# Updating layer_output_tensor_map to new tensors
self.model_layers_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][layer.name] = [tensor.name for tensor in new_output_tensor]

old_to_new_tensor_mapper = {old_tensor_name: new_output_tensor[i] for i, old_tensor_name in
enumerate(old_tensors)}

old_name_of_inputs = None
if layer.name in self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES]:
old_name_of_inputs = self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].pop(
layer.name
)
for out_tensor in new_output_tensor:
new_name = out_tensor.name
if old_name_of_inputs is not None:
self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].update(
{new_name: old_name_of_inputs}
)
self.model_layers_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS].update(
{out_tensor.name: out_tensor}
)

# Update inbound_nodes with new tensor names
for node_name, values in self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].items():
if layer.name in values:
# Update tensor for this entire layer
tensor_list = self._flatten_list(self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS][node_name])
tensor_list.extend(self._flatten_list(list(self.model_layers_connections[ModelLayerConnectionsProperties.CALL_KWARGS][node_name].values())))
tensor_list = [tensor for tensor in tensor_list if isinstance(tensor, KerasTensor)]
assert len(tensor_list) >= len(values), f"{node_name} has mismatched number of inbound nodes"
for idx, value in enumerate(values):
if value == layer.name and tensor_list[idx].name in old_to_new_tensor_mapper:
values[idx] = old_to_new_tensor_mapper[tensor_list[idx].name].name

# Update call_kwargs with new tensors
for _, kwargs_dict in self.model_layers_connections[ModelLayerConnectionsProperties.CALL_KWARGS].items():
for key, values in kwargs_dict.items():
kwargs_dict[key] = self._update_nested_values(values, old_to_new_tensor_mapper)

# Update call_args with new tensors
for key, args in self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS].items():
self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS][key] = self._update_nested_values(args, old_to_new_tensor_mapper)

elif layer.name != new_output_tensor.name:
new_name = new_output_tensor.name
old_name_of_inputs = self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].pop(
layer.name
Expand Down Expand Up @@ -598,6 +660,20 @@ def _get_keras_tensor_index(value: Any, search_list: List):
return idx
return None

def _flatten_list(self, values: typing.List | typing.Tuple):
"""
A helper function that returns flatten list of values in the given List or Tuple
:param values: List or Tuple of values
:return: List of flattened values
"""
flat_vals = []
for val in values:
if isinstance(val, (typing.List, typing.Tuple)):
flat_vals.extend(self._flatten_list(val))
else:
flat_vals.append(val)
return flat_vals

def _handle_normal_keras_layer(self, layer: tf.keras.layers.Layer) -> KerasTensor:
"""
Helper function to handle normal keras layers. This function will create a new output tensor for the layer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class ModelLayerConnectionsProperties(Enum):
OUTPUT_TENSORS = 'output_tensors'
CALL_ARGS = 'call_args'
CALL_KWARGS = 'call_kwargs'
LAYER_OUTPUT_TENSOR_MAP = 'layer_output_tensor_map'
TYPE = typing.Dict[typing.Dict[str, typing.List[str]],
typing.Dict[str, typing.Union[KerasTensor, typing.List[KerasTensor]]]]

Expand All @@ -72,8 +73,19 @@ def get_model_layers_connection_properties(model: tf.keras.Model) -> typing.Dict
model_layer_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS] = OrderedDict()
model_layer_connections[ModelLayerConnectionsProperties.CALL_ARGS] = OrderedDict()
model_layer_connections[ModelLayerConnectionsProperties.CALL_KWARGS] = OrderedDict()
model_layer_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP] = OrderedDict()

for current_layer in model.layers:
# Handling the case where there's multiple outbound nodes with same outbound layer
output_tensors = current_layer.output
input_tensors = current_layer.input
if not isinstance(output_tensors, typing.List):
output_tensors = [output_tensors]
if not isinstance(input_tensors, typing.List):
input_tensors = [input_tensors]
model_layer_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][current_layer.name] = \
[tensor.name for tensor in output_tensors if hasattr(tensor, "name")]

for outbound_node in current_layer.outbound_nodes:
outbound_layers_name = outbound_node.outbound_layer.name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1691,6 +1691,37 @@ def test_cast_op_with_same_dtypes(self, set_dtype):
# Only two cast ops should be present as the original model has two cast ops
assert sum(num_cast_ops) == 2

def test_split_op_model(self):
    """Verify BN folding keeps numerics and connectivity intact when a layer feeds multi-output tf.split ops."""
    # Build a model where one BN output feeds two separate tf.split ops whose
    # outputs are cross-wired into two concat model outputs.
    model_input = tf.keras.layers.Input(shape=(3, 224, 224))
    conv_out = tf.keras.layers.Conv2D(filters=1, kernel_size=3, strides=1, padding="valid")(model_input)
    bn_out = tf.keras.layers.BatchNormalization()(conv_out)
    first_half, second_half = tf.split(bn_out, num_or_size_splits=2, axis=2)
    third_piece, fourth_piece = tf.split(bn_out, num_or_size_splits=2, axis=2)
    merged = tf.concat([second_half, third_piece], axis=1)
    merged_wide = tf.concat([first_half, fourth_piece, third_piece], axis=1)
    model = tf.keras.models.Model(inputs=model_input, outputs=[merged_wide, merged])

    random_input = np.random.randn(1, 3, 224, 224)

    baseline_outputs = model(random_input)

    _, bn_folded_model = fold_all_batch_norms(model)

    folded_outputs = bn_folded_model(random_input)

    # Numerics must match per model output after folding
    for baseline_out, folded_out in zip(baseline_outputs, folded_outputs):
        assert np.allclose(baseline_out, folded_out, atol=1e-04)

    # Connectivity: each concat input must still point at the correct split output
    split_0_names = [tensor.name for tensor in bn_folded_model.layers[2].output]
    split_1_names = [tensor.name for tensor in bn_folded_model.layers[3].output]

    assert bn_folded_model.layers[4].input[0].name == split_0_names[0]
    assert bn_folded_model.layers[4].input[1].name == split_1_names[1]
    assert bn_folded_model.layers[4].input[2].name == split_1_names[0]

    assert bn_folded_model.layers[5].input[0].name == split_0_names[1]
    assert bn_folded_model.layers[5].input[1].name == split_1_names[0]

@pytest.mark.skip("Possible Batch norms to fold is returning None?")
def test_fold_auto_mode_with_bn_after_Conv1d_layer(self):
input_shape = (2, 10, 32)
Expand Down

0 comments on commit b2dab1a

Please sign in to comment.