Updated model connections dict to work for multi-output tensors for a layer #3704

Merged
merged 1 commit on Dec 31, 2024
@@ -36,6 +36,8 @@
# =============================================================================

""" Implementation to automatically prepare keras models for AIMET by converting them to a functional model """
import typing

import inspect
import logging
from typing import Any, Dict, List, Set, Union, Optional
@@ -402,6 +404,21 @@ def _get_call_kwargs(self, layer: tf.keras.layers.Layer) -> Dict[Union[KerasTens
return {}
return call_kwargs

def _update_nested_values(self, values: typing.List | typing.Tuple, old_to_new_tensor_mapper: typing.Dict):
"""
Helper function to update nested Lists/Tuples of values based on the provided old-to-new tensor mapping
:param values: List/Tuple of values
:param old_to_new_tensor_mapper: Dictionary mapping old tensor names to new tensors
:return: The values with any mapped KerasTensors replaced by their new tensors
"""
if isinstance(values, typing.Tuple):
return tuple(self._update_nested_values(x, old_to_new_tensor_mapper) for x in values)
if isinstance(values, typing.List):
return [self._update_nested_values(x, old_to_new_tensor_mapper) for x in values]
if isinstance(values, KerasTensor) and values.name in old_to_new_tensor_mapper:
return old_to_new_tensor_mapper[values.name]
return values

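For reference, below is a minimal standalone sketch of the recursion pattern used by _update_nested_values, with a stand-in FakeTensor class in place of KerasTensor; everything in it is illustrative and not part of the diff.

# Illustrative sketch only: same recursion pattern as _update_nested_values,
# using a stand-in class instead of KerasTensor.
class FakeTensor:
    def __init__(self, name):
        self.name = name

def update_nested_values(values, old_to_new_tensor_mapper):
    if isinstance(values, tuple):
        return tuple(update_nested_values(v, old_to_new_tensor_mapper) for v in values)
    if isinstance(values, list):
        return [update_nested_values(v, old_to_new_tensor_mapper) for v in values]
    if isinstance(values, FakeTensor) and values.name in old_to_new_tensor_mapper:
        return old_to_new_tensor_mapper[values.name]
    return values

old = FakeTensor("split/output_0")
new = FakeTensor("split_1/output_0")
remapped = update_nested_values([old, (old, 2)], {old.name: new})
# Every occurrence of the old tensor is swapped out; other values pass through.
assert remapped[0] is new and remapped[1][0] is new and remapped[1][1] == 2
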
def _update_output_tensors_in_model_layers_connections(
self,
layer: tf.keras.layers.Layer,
@@ -415,8 +432,53 @@ def _update_output_tensors_in_model_layers_connections(
:param new_output_tensor: The new output tensor to update with
:param model: The model currently being checked. Used to add model outputs
"""
# pylint: disable=too-many-nested-blocks
if layer.name != new_output_tensor.name:
# pylint: disable=too-many-nested-blocks, protected-access, too-many-locals, too-many-branches
if isinstance(new_output_tensor, List):
# Handling case where layer has list of output tensors
old_tensors = self.model_layers_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][layer.name]
# Updating layer_output_tensor_map to new tensors
self.model_layers_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][layer.name] = [tensor.name for tensor in new_output_tensor]

old_to_new_tensor_mapper = {old_tensor_name: new_output_tensor[i] for i, old_tensor_name in
enumerate(old_tensors)}

old_name_of_inputs = None
if layer.name in self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES]:
old_name_of_inputs = self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].pop(
layer.name
)
for out_tensor in new_output_tensor:
new_name = out_tensor.name
if old_name_of_inputs is not None:
self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].update(
{new_name: old_name_of_inputs}
)
self.model_layers_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS].update(
{out_tensor.name: out_tensor}
)

# Update inbound_nodes with new tensor names
for node_name, values in self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].items():
if layer.name in values:
# Update tensor for this entire layer
tensor_list = self._flatten_list(self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS][node_name])
tensor_list.extend(self._flatten_list(list(self.model_layers_connections[ModelLayerConnectionsProperties.CALL_KWARGS][node_name].values())))
tensor_list = [tensor for tensor in tensor_list if isinstance(tensor, KerasTensor)]
assert len(tensor_list) >= len(values), f"{node_name} has mismatched number of inbound nodes"
for idx, value in enumerate(values):
if value == layer.name and tensor_list[idx].name in old_to_new_tensor_mapper:
values[idx] = old_to_new_tensor_mapper[tensor_list[idx].name].name

# Update call_kwargs with new tensors
for _, kwargs_dict in self.model_layers_connections[ModelLayerConnectionsProperties.CALL_KWARGS].items():
for key, values in kwargs_dict.items():
kwargs_dict[key] = self._update_nested_values(values, old_to_new_tensor_mapper)

# Update call_args with new tensors
for key, args in self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS].items():
self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS][key] = self._update_nested_values(args, old_to_new_tensor_mapper)

elif layer.name != new_output_tensor.name:
new_name = new_output_tensor.name
old_name_of_inputs = self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].pop(
layer.name
@@ -598,6 +660,20 @@ def _get_keras_tensor_index(value: Any, search_list: List):
return idx
return None

def _flatten_list(self, values: typing.List | typing.Tuple):
"""
A helper function that returns a flattened list of the values in the given List or Tuple
:param values: List or Tuple of values
:return: List of flattened values
"""
flat_vals = []
for val in values:
if isinstance(val, (typing.List, typing.Tuple)):
flat_vals.extend(self._flatten_list(val))
else:
flat_vals.append(val)
return flat_vals

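A self-contained sketch of the same flattening behaviour, illustrative only and not part of the diff:

# Illustrative sketch only: flattens arbitrarily nested lists/tuples,
# mirroring what _flatten_list does before the caller filters for KerasTensors.
def flatten_list(values):
    flat_vals = []
    for val in values:
        if isinstance(val, (list, tuple)):
            flat_vals.extend(flatten_list(val))
        else:
            flat_vals.append(val)
    return flat_vals

assert flatten_list([1, (2, [3, 4]), 5]) == [1, 2, 3, 4, 5]
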
def _handle_normal_keras_layer(self, layer: tf.keras.layers.Layer) -> KerasTensor:
"""
Helper function to handle normal keras layers. This function will create a new output tensor for the layer
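Taken together, the new multi-output branch above maps each old output-tensor name of the rebuilt layer to its newly created tensor and then rewrites the inbound-node references of downstream layers. A simplified sketch with plain dicts and made-up tensor names, not the actual AIMET data structures:

# Illustrative sketch only: rewrite downstream inbound-node references from the
# rebuilt layer's name to the names of its new output tensors.
layer_name = "split"
old_output_names = ["split/output_0", "split/output_1"]      # as in LAYER_OUTPUT_TENSOR_MAP
new_output_names = ["split_1/output_0", "split_1/output_1"]  # names of the recreated tensors
old_to_new = dict(zip(old_output_names, new_output_names))

# Which producing layer each downstream input refers to, and which of its
# tensors is actually consumed at that argument position.
inbound_nodes = {"concat": ["split", "split"]}
consumed_tensors = {"concat": ["split/output_1", "split/output_0"]}

for node_name, producers in inbound_nodes.items():
    for idx, producer in enumerate(producers):
        tensor_name = consumed_tensors[node_name][idx]
        if producer == layer_name and tensor_name in old_to_new:
            producers[idx] = old_to_new[tensor_name]

print(inbound_nodes)  # {'concat': ['split_1/output_1', 'split_1/output_0']}
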
@@ -49,6 +49,7 @@ class ModelLayerConnectionsProperties(Enum):
OUTPUT_TENSORS = 'output_tensors'
CALL_ARGS = 'call_args'
CALL_KWARGS = 'call_kwargs'
LAYER_OUTPUT_TENSOR_MAP = 'layer_output_tensor_map'
TYPE = typing.Dict[typing.Dict[str, typing.List[str]],
typing.Dict[str, typing.Union[KerasTensor, typing.List[KerasTensor]]]]

@@ -72,8 +73,19 @@ def get_model_layers_connection_properties(model: tf.keras.Model) -> typing.Dict
model_layer_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS] = OrderedDict()
model_layer_connections[ModelLayerConnectionsProperties.CALL_ARGS] = OrderedDict()
model_layer_connections[ModelLayerConnectionsProperties.CALL_KWARGS] = OrderedDict()
model_layer_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP] = OrderedDict()

for current_layer in model.layers:
# Handle the case where there are multiple outbound nodes with the same outbound layer
output_tensors = current_layer.output
input_tensors = current_layer.input
if not isinstance(output_tensors, typing.List):
output_tensors = [output_tensors]
if not isinstance(input_tensors, typing.List):
input_tensors = [input_tensors]
model_layer_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][current_layer.name] = \
[tensor.name for tensor in output_tensors if hasattr(tensor, "name")]

for outbound_node in current_layer.outbound_nodes:
outbound_layers_name = outbound_node.outbound_layer.name

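The per-layer output map is needed because an op such as tf.split shows up as a single layer whose output attribute is a list of KerasTensors rather than a single tensor. A small sketch of the normalization above, assuming TensorFlow 2.x; the model here is illustrative only:

# Illustrative sketch only: a tf.split op in a functional model becomes one
# layer whose .output is a list of KerasTensors, so outputs are normalized
# to a list before their names are recorded.
import tensorflow as tf

inp = tf.keras.layers.Input(shape=(4, 8))
left, right = tf.split(inp, num_or_size_splits=2, axis=1)
model = tf.keras.Model(inputs=inp, outputs=[left, right])

for layer in model.layers:
    outputs = layer.output
    if not isinstance(outputs, list):
        outputs = [outputs]
    print(layer.name, [t.name for t in outputs if hasattr(t, "name")])
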
@@ -1691,6 +1691,37 @@ def test_cast_op_with_same_dtypes(self, set_dtype):
# Only two cast ops should be present as the original model has two cast ops
assert sum(num_cast_ops) == 2

def test_split_op_model(self):
inp = tf.keras.layers.Input(shape=(3, 224, 224))
conv = tf.keras.layers.Conv2D(filters=1, kernel_size=3, strides=1, padding="valid")(inp)
bn = tf.keras.layers.BatchNormalization()(conv)
split, split2 = tf.split(bn, num_or_size_splits=2, axis=2)
split_3, split4 = tf.split(bn, num_or_size_splits=2, axis=2)
concat = tf.concat([split2, split_3], axis=1)
concat_2 = tf.concat([split, split4, split_3], axis=1)
model = tf.keras.models.Model(inputs=inp, outputs=[concat_2, concat])

dummy_inp = np.random.randn(1, 3, 224, 224)

fp32_outs = model(dummy_inp)

folded_pairs, bn_folded_model = fold_all_batch_norms(model)

bn_folded_outs = bn_folded_model(dummy_inp)

for idx, fp32_out in enumerate(fp32_outs):
assert np.allclose(fp32_out, bn_folded_outs[idx], atol=1e-04)

split_0_output_tensor_names = [x.name for x in bn_folded_model.layers[2].output]
split_1_output_tensor_names = [x.name for x in bn_folded_model.layers[3].output]

assert bn_folded_model.layers[4].input[0].name == split_0_output_tensor_names[0]
assert bn_folded_model.layers[4].input[1].name == split_1_output_tensor_names[1]
assert bn_folded_model.layers[4].input[2].name == split_1_output_tensor_names[0]

assert bn_folded_model.layers[5].input[0].name == split_0_output_tensor_names[1]
assert bn_folded_model.layers[5].input[1].name == split_1_output_tensor_names[0]

@pytest.mark.skip("Possible Batch norms to fold is returning None?")
def test_fold_auto_mode_with_bn_after_Conv1d_layer(self):
input_shape = (2, 10, 32)