diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/model_preparer.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/model_preparer.py
index a6c1da23fd6..96768b2eef5 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/model_preparer.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/model_preparer.py
@@ -36,6 +36,8 @@
 
 # =============================================================================
 """ Implementation to automatically prepare keras models for AIMET by converting them to a functional model """
+import typing
+
 import inspect
 import logging
 from typing import Any, Dict, List, Set, Union, Optional
@@ -402,6 +404,21 @@ def _get_call_kwargs(self, layer: tf.keras.layers.Layer) -> Dict[Union[KerasTens
             return {}
         return call_kwargs
 
+    def _update_nested_values(self, values: typing.Union[typing.List, typing.Tuple], old_to_new_tensor_mapper: typing.Dict):
+        """
+        Helper function to update the nested Lists/Tuples based on the dictionary provided
+        :param values: Arbitrarily nested List/Tuple of values (leaves may be KerasTensors)
+        :param old_to_new_tensor_mapper: Mapping from old tensor name to its replacement tensor
+        :return: The same structure with every mapped KerasTensor replaced
+        """
+        if isinstance(values, tuple):
+            return tuple(self._update_nested_values(x, old_to_new_tensor_mapper) for x in values)
+        if isinstance(values, list):
+            return [self._update_nested_values(x, old_to_new_tensor_mapper) for x in values]
+        if isinstance(values, KerasTensor) and values.name in old_to_new_tensor_mapper:
+            return old_to_new_tensor_mapper[values.name]
+        return values
+
     def _update_output_tensors_in_model_layers_connections(
         self,
         layer: tf.keras.layers.Layer,
@@ -415,8 +432,53 @@ def _update_output_tensors_in_model_layers_connections(
         :param new_output_tensor: The new output tensor to update with
         :param model: The model currently being checked. Used to add model outputs
         """
-        # pylint: disable=too-many-nested-blocks
-        if layer.name != new_output_tensor.name:
+        # pylint: disable=too-many-nested-blocks, protected-access, too-many-locals, too-many-branches
+        if isinstance(new_output_tensor, list):
+            # Handling case where layer has list of output tensors
+            old_tensors = self.model_layers_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][layer.name]
+            # Updating layer_output_tensor_map to new tensors
+            self.model_layers_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][layer.name] = [tensor.name for tensor in new_output_tensor]
+
+            old_to_new_tensor_mapper = {old_tensor_name: new_output_tensor[i] for i, old_tensor_name in
+                                        enumerate(old_tensors)}
+
+            old_name_of_inputs = None
+            if layer.name in self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES]:
+                old_name_of_inputs = self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].pop(
+                    layer.name
+                )
+            for out_tensor in new_output_tensor:
+                new_name = out_tensor.name
+                if old_name_of_inputs is not None:
+                    self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].update(
+                        {new_name: old_name_of_inputs}
+                    )
+                self.model_layers_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS].update(
+                    {out_tensor.name: out_tensor}
+                )
+
+            # Update inbound_nodes with new tensor names
+            for node_name, values in self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].items():
+                if layer.name in values:
+                    # Update tensor for this entire layer
+                    tensor_list = self._flatten_list(self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS][node_name])
+                    tensor_list.extend(self._flatten_list(list(self.model_layers_connections[ModelLayerConnectionsProperties.CALL_KWARGS][node_name].values())))
+                    tensor_list = [tensor for tensor in tensor_list if isinstance(tensor, KerasTensor)]
+                    assert len(tensor_list) >= len(values), f"{node_name} has mismatched number of inbound nodes"
+                    for idx, value in enumerate(values):
+                        if value == layer.name and tensor_list[idx].name in old_to_new_tensor_mapper:
+                            values[idx] = old_to_new_tensor_mapper[tensor_list[idx].name].name
+
+            # Update call_kwargs with new tensors
+            for kwargs_dict in self.model_layers_connections[ModelLayerConnectionsProperties.CALL_KWARGS].values():
+                for key, values in kwargs_dict.items():
+                    kwargs_dict[key] = self._update_nested_values(values, old_to_new_tensor_mapper)
+
+            # Update call_args with new tensors
+            for key, args in self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS].items():
+                self.model_layers_connections[ModelLayerConnectionsProperties.CALL_ARGS][key] = self._update_nested_values(args, old_to_new_tensor_mapper)
+
+        elif layer.name != new_output_tensor.name:
             new_name = new_output_tensor.name
             old_name_of_inputs = self.model_layers_connections[ModelLayerConnectionsProperties.INBOUND_NODES].pop(
                 layer.name
@@ -598,6 +660,20 @@ def _get_keras_tensor_index(value: Any, search_list: List):
                 return idx
         return None
 
+    def _flatten_list(self, values: typing.Union[typing.List, typing.Tuple]):
+        """
+        A helper function that returns a flattened list of values in the given List or Tuple
+        :param values: List or Tuple of values
+        :return: List of flattened values
+        """
+        flat_vals = []
+        for val in values:
+            if isinstance(val, (list, tuple)):
+                flat_vals.extend(self._flatten_list(val))
+            else:
+                flat_vals.append(val)
+        return flat_vals
+
     def _handle_normal_keras_layer(self, layer: tf.keras.layers.Layer) -> KerasTensor:
         """
         Helper function to handle normal keras layers. This function will create a new output tensor for the layer
diff --git a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/utils/model_connection_utils.py b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/utils/model_connection_utils.py
index 1833b9db0f0..90cf0813c80 100644
--- a/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/utils/model_connection_utils.py
+++ b/TrainingExtensions/tensorflow/src/python/aimet_tensorflow/keras/utils/model_connection_utils.py
@@ -49,6 +49,7 @@ class ModelLayerConnectionsProperties(Enum):
     OUTPUT_TENSORS = 'output_tensors'
     CALL_ARGS = 'call_args'
     CALL_KWARGS = 'call_kwargs'
+    LAYER_OUTPUT_TENSOR_MAP = 'layer_output_tensor_map'
 
     TYPE = typing.Dict[typing.Dict[str, typing.List[str]], typing.Dict[str, typing.Union[KerasTensor, typing.List[KerasTensor]]]]
 
@@ -72,8 +73,19 @@ def get_model_layers_connection_properties(model: tf.keras.Model) -> typing.Dict
     model_layer_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS] = OrderedDict()
     model_layer_connections[ModelLayerConnectionsProperties.CALL_ARGS] = OrderedDict()
     model_layer_connections[ModelLayerConnectionsProperties.CALL_KWARGS] = OrderedDict()
+    model_layer_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP] = OrderedDict()
 
     for current_layer in model.layers:
+        # Handling the case where there's multiple outbound nodes with same outbound layer
+        output_tensors = current_layer.output
+        input_tensors = current_layer.input
+        if not isinstance(output_tensors, list):
+            output_tensors = [output_tensors]
+        if not isinstance(input_tensors, list):
+            input_tensors = [input_tensors]
+        model_layer_connections[ModelLayerConnectionsProperties.LAYER_OUTPUT_TENSOR_MAP][current_layer.name] = \
+            [tensor.name for tensor in output_tensors if hasattr(tensor, "name")]
+
         for outbound_node in current_layer.outbound_nodes:
             outbound_layers_name = outbound_node.outbound_layer.name
 
diff --git a/TrainingExtensions/tensorflow/test/python/test_batch_norm_fold_keras.py b/TrainingExtensions/tensorflow/test/python/test_batch_norm_fold_keras.py
index 431e3792281..838d5e49cac 100644
--- a/TrainingExtensions/tensorflow/test/python/test_batch_norm_fold_keras.py
+++ b/TrainingExtensions/tensorflow/test/python/test_batch_norm_fold_keras.py
@@ -1691,6 +1691,37 @@ def test_cast_op_with_same_dtypes(self, set_dtype):
         # Only two cast ops should be present as the original model has two cast ops
         assert sum(num_cast_ops) == 2
 
+    def test_split_op_model(self):
+        inp = tf.keras.layers.Input(shape=(3, 224, 224))
+        conv = tf.keras.layers.Conv2D(filters=1, kernel_size=3, strides=1, padding="valid")(inp)
+        bn = tf.keras.layers.BatchNormalization()(conv)
+        split, split2 = tf.split(bn, num_or_size_splits=2, axis=2)
+        split_3, split4 = tf.split(bn, num_or_size_splits=2, axis=2)
+        concat = tf.concat([split2, split_3], axis=1)
+        concat_2 = tf.concat([split, split4, split_3], axis=1)
+        model = tf.keras.models.Model(inputs=inp, outputs=[concat_2, concat])
+
+        dummy_inp = np.random.randn(1, 3, 224, 224)
+
+        fp32_outs = model(dummy_inp)
+
+        folded_pairs, bn_folded_model = fold_all_batch_norms(model)
+
+        bn_folded_outs = bn_folded_model(dummy_inp)
+
+        for idx, fp32_out in enumerate(fp32_outs):
+            assert np.allclose(fp32_out, bn_folded_outs[idx], atol=1e-04)
+
+        split_0_output_tensor_names = [x.name for x in bn_folded_model.layers[2].output]
+        split_1_output_tensor_names = [x.name for x in bn_folded_model.layers[3].output]
+
+        assert bn_folded_model.layers[4].input[0].name == split_0_output_tensor_names[0]
+        assert bn_folded_model.layers[4].input[1].name == split_1_output_tensor_names[1]
+        assert bn_folded_model.layers[4].input[2].name == split_1_output_tensor_names[0]
+
+        assert bn_folded_model.layers[5].input[0].name == split_0_output_tensor_names[1]
+        assert bn_folded_model.layers[5].input[1].name == split_1_output_tensor_names[0]
+
     @pytest.mark.skip("Possible Batch norms to fold is returning None?")
     def test_fold_auto_mode_with_bn_after_Conv1d_layer(self):
         input_shape = (2, 10, 32)