Keras Model Preparer subclass model fixes #2489

Merged (3 commits, Oct 9, 2023)

Changes from all commits:

Keras model preparer module

@@ -61,7 +61,7 @@

# pylint: disable=wrong-import-position
from aimet_tensorflow.keras.utils.model_connection_utils import ModelLayerConnections, ModelLayerConnectionsProperties
-from aimet_tensorflow.keras.utils.model_transform_utils import replace_separable_conv_with_depthwise_pointwise
+from aimet_tensorflow.keras.utils.model_transform_utils import replace_separable_conv_with_depthwise_pointwise, replace_relu6_with_relu
from aimet_common.utils import AimetLogger

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.ModelPreparer)
@@ -165,7 +165,7 @@ def _get_class_names_in_model(model: Union[tf.keras.Model, tf.keras.layers.Layer
:param model: the 'layer' or 'model' to get the class name
:return: A set containing the class name
"""
-    return {regex_for_camel_case_to_snake_case.sub("_", model.__class__.__name__).lower()}
+    return {regex_for_camel_case_to_snake_case.sub("_", model.name).lower()}


def _is_nested_layer(layer: tf.keras.layers.Layer) -> bool:
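One way to read the switch from model.__class__.__name__ to model.name (an illustration only, not wording from the PR): every instance of a subclassed layer shares the same class name, while Keras assigns each instance its own layer name. The sketch below uses a hypothetical layer class, and the camel-to-snake regex shown is a common pattern that may differ from AIMET's regex_for_camel_case_to_snake_case.

import re
import tensorflow as tf

# Common camel-case-to-snake-case pattern; AIMET's regex may differ.
camel_to_snake = re.compile(r"(?<!^)(?=[A-Z])")

class MyResidualBlock(tf.keras.layers.Layer):  # hypothetical nested (subclassed) layer
    pass

block_a, block_b = MyResidualBlock(), MyResidualBlock()

# Old key: both instances collapse to the same class-derived name.
print(camel_to_snake.sub("_", block_a.__class__.__name__).lower())  # my_residual_block
print(camel_to_snake.sub("_", block_b.__class__.__name__).lower())  # my_residual_block

# New key: the instance name keeps repeated blocks distinct.
print(camel_to_snake.sub("_", block_a.name).lower())  # e.g. my_residual_block
print(camel_to_snake.sub("_", block_b.name).lower())  # e.g. my_residual_block_1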
@@ -610,7 +610,8 @@ def prepare_model(original_model: tf.keras.Model,
_logger.debug("Model does not contain any nested layers. "
"Returning original model after going through 'replace_separable_conv_with_depthwise_pointwise.")
model_to_return, _ = replace_separable_conv_with_depthwise_pointwise(original_model)
-return original_model
+model_to_return, _ = replace_relu6_with_relu(model_to_return)
+return model_to_return

_logger.debug("Preparing model for AIMET. Original model architecture")
original_model.summary(print_fn=_logger.debug)
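A hedged usage sketch of the path this hunk changes: a model with no nested layers is no longer returned untouched; the transformed copy produced by replace_separable_conv_with_depthwise_pointwise and replace_relu6_with_relu is returned instead. The import path and the choice of ReLU(max_value=6.0) as the layer replace_relu6_with_relu targets are assumptions, not taken from the PR.

import tensorflow as tf
from aimet_tensorflow.keras.model_preparer import prepare_model  # assumed import path

# A plain functional model: no nested layers, but it contains a SeparableConv2D
# and a ReLU6-style activation.
inputs = tf.keras.Input(shape=(16, 16, 3))
x = tf.keras.layers.SeparableConv2D(8, 3, padding="same")(inputs)
outputs = tf.keras.layers.ReLU(max_value=6.0)(x)  # assumed to be what replace_relu6_with_relu targets
model = tf.keras.Model(inputs, outputs)

prepared = prepare_model(model)

# Per this hunk, the transformed copy is returned (previously `original_model` was
# returned here, discarding the separable-conv replacement), so the SeparableConv2D
# should now be split into DepthwiseConv2D + Conv2D.
assert not any(isinstance(layer, tf.keras.layers.SeparableConv2D) for layer in prepared.layers)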
@@ -626,14 +627,15 @@
K.clear_session() # To avoid name conflicts
model_to_return = tf.keras.models.clone_model(prepared_model)

-# Extra prepare step to replace Separable Conv's with Depthwise Pointwise pattern if the prepared model
-# had any in the original models nested layers.
-model_to_return, _ = replace_separable_conv_with_depthwise_pointwise(model_to_return)

model_to_return.summary(print_fn=_logger.debug)

# Copying over weights from original model to functional model
_logger.debug("Final class_names: %s", class_names)
_set_functional_models_weights(original_model, model_to_return, class_names)

+# Extra prepare step to replace Separable Conv's with Depthwise Pointwise pattern if the prepared model
+# had any in the original models nested layers.
+model_to_return, _ = replace_separable_conv_with_depthwise_pointwise(model_to_return)
+model_to_return, _ = replace_relu6_with_relu(model_to_return)

return model_to_return
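Condensing the hunk above, the tail of prepare_model now behaves roughly as sketched below (a paraphrase of the diff, wrapped in a hypothetical helper so it reads as a unit; the names it calls are the module-level functions shown in this file). The layer substitutions run only after the original weights are copied into the functional clone, presumably because the weight copy is keyed to the original layer structure that the substitutions rearrange.

def _finalize_prepared_model(original_model, prepared_model, class_names):
    # Hypothetical helper; the body condenses the order of operations in the diff above.
    K.clear_session()  # To avoid name conflicts
    model_to_return = tf.keras.models.clone_model(prepared_model)
    model_to_return.summary(print_fn=_logger.debug)

    # 1. Copy the original model's weights into the functional clone first.
    _set_functional_models_weights(original_model, model_to_return, class_names)

    # 2. Only then apply the structural substitutions (SeparableConv2D -> depthwise + pointwise,
    #    ReLU6 -> ReLU) and return the transformed model.
    model_to_return, _ = replace_separable_conv_with_depthwise_pointwise(model_to_return)
    model_to_return, _ = replace_relu6_with_relu(model_to_return)
    return model_to_return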

Keras model transform utilities (the replace_separable_conv_with_depthwise_pointwise replacement)
@@ -250,7 +250,7 @@ def replacement(self, match_layer):
pointwise_layer = layers.Conv2D(
match_layer_config["filters"],
kernel_size=1, # Always 1 as per Keras source code
-    strides=match_layer_config["strides"],
+    strides=1,
padding="valid", # Always valid as per Keras source code
data_format=match_layer_config["data_format"],
dilation_rate=(1, 1), # Always (1, 1) as per Keras source code
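The rationale for the strides fix, as I read it: tf.keras.layers.SeparableConv2D applies its strides in the depthwise stage only, so the equivalent depthwise-plus-pointwise pattern must keep the 1x1 pointwise Conv2D at stride 1; reusing the separable layer's strides would downsample twice. A minimal equivalence check (an illustration assuming TF 2.x, not code from this PR):

import numpy as np
import tensorflow as tf

x = np.random.rand(1, 16, 16, 3).astype("float32")

sep = tf.keras.layers.SeparableConv2D(8, kernel_size=3, strides=2, padding="same", use_bias=False)
_ = sep(x)  # build the layer so weights exist
depthwise_kernel, pointwise_kernel = sep.get_weights()

# Equivalent pattern: the depthwise stage carries the stride, the pointwise stage does not.
depthwise = tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=2, padding="same", use_bias=False)
pointwise = tf.keras.layers.Conv2D(8, kernel_size=1, strides=1, padding="valid", use_bias=False)
_ = pointwise(depthwise(x))  # build
depthwise.set_weights([depthwise_kernel])
pointwise.set_weights([pointwise_kernel])

np.testing.assert_allclose(sep(x).numpy(), pointwise(depthwise(x)).numpy(), atol=1e-5)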

Keras model preparer unit tests
@@ -478,12 +478,16 @@ def test_keras_text_classification_example_model_to_functional():
def test_non_nested_layered_model():
original_model = conv_functional()
random_input = np.random.rand(32, *original_model.input_shape[1:])
-_ = original_model(random_input)
+orig_output = original_model(random_input)

functional_model = prepare_model(original_model)

-assert original_model == functional_model, \
-    "Prepare model did not give back the original model. This model does not need to be prepared."
+functional_model_output = functional_model(random_input)
+model_weights_in_correct_order = _get_original_models_weights_in_functional_model_order(
+    original_model, functional_model, class_names=set())
+
+compare_weights(model_weights_in_correct_order, functional_model.get_weights())
+np.testing.assert_array_equal(orig_output.numpy(), functional_model_output.numpy())

def test_multi_output():
class TestMultiOut(tf.keras.layers.Layer):
@@ -522,9 +526,13 @@ def test_multi_output_only_lambda():

original_model = tf.keras.Model(inputs=encoder_input, outputs=out)
random_input = np.random.rand(1, *original_model.input_shape[1:])
-_ = original_model(random_input)
+orig_output = original_model(random_input)

functional_model = prepare_model(original_model)

-assert functional_model == original_model, "The original model does not contain any nested layers. \
-    The original model should be returned."
+functional_model_output = functional_model(random_input)
+model_weights_in_correct_order = _get_original_models_weights_in_functional_model_order(
+    original_model, functional_model, class_names=set())
+
+compare_weights(model_weights_in_correct_order, functional_model.get_weights())
+np.testing.assert_array_equal(orig_output.numpy(), functional_model_output.numpy())