diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py
index 73a89974e50..8f3d93e2e1e 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/meta/connectedgraph.py
@@ -483,18 +483,25 @@ def _link_branch_op_to_multiple_ops(self, branch_op: Op, product_list: list,
         # 4: Remove product from original op to child in child's inputs
         # 5: Remove product from self._products
         for product in product_list:
-            assert len(product.consumers) == 1
-            # item 1
-            branch_op_product.add_consumer(product.consumers[0])
-            # item 2
-            assert len(product.tensor_dict.keys()) == 1
-            branch_op_product.tensor_dict[product.consumers[0]] = product.tensor_dict[product.consumers[0]]
-            # items 3 and 4
-            # replace the old product with the new branch product, in the same index as the old product
-            index = product.consumers[0].inputs.index(product)
-            product.consumers[0].inputs[index] = branch_op_product
-            # item 5
-            del self._products[product.name]
+            assert len(product.consumers) <= 1
+
+            if len(product.consumers) == 1:
+                # item 1
+                branch_op_product.add_consumer(product.consumers[0])
+                # item 2
+                assert len(product.tensor_dict.keys()) == 1
+                branch_op_product.tensor_dict[product.consumers[0]] = product.tensor_dict[product.consumers[0]]
+                # items 3 and 4
+                # replace the old product with the new branch product, in the same index as the old product
+                index = product.consumers[0].inputs.index(product)
+                product.consumers[0].inputs[index] = branch_op_product
+                # item 5
+                del self._products[product.name]
+            else:
+                for output in self.model.graph.output:
+                    if output.name in product.name:
+                        del self._products[product.name]
+                        self._create_link_for_output_product(output.name, branch_op.name)
 
         self._products[branch_op_product.name] = branch_op_product
diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
index 724b4a95126..23e978a18f6 100644
--- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py
+++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -1049,6 +1049,19 @@ def multi_input_model():
     model = ONNXModel(load_model('./model_multi_input.onnx'))
     return model
 
+def multi_output_model():
+    model = MultipleOutputModel()
+    sample_input = np.random.rand(128, 3, 32, 32).astype(np.float32)
+
+    onnx_filename = "/tmp/dummy_model_multiple_outputs.onnx"
+    input_names = ["input"]
+    output_names = ["output_mul", "output_add"]
+    torch.onnx.export(model, torch.as_tensor(sample_input), onnx_filename, verbose=True, input_names=input_names,
+                      output_names=output_names)
+
+    model = ONNXModel(load_model(onnx_filename))
+    return model
+
 def transposed_conv_model():
     x = torch.randn(10, 10, 4, 4, requires_grad=True)
     model = TransposedConvModel()
@@ -1417,6 +1430,51 @@ def forward(self, x):
         return x
 
 
+class MultipleOutputModel(SingleResidual):
+    """
+    Model with multiple outputs, derived from SingleResidual
+    """
+    def __init__(self):
+        super().__init__()
+        # change padding to 0; onnxruntime requires the pooling input size to be divisible by the output size
+        self.conv4 = torch.nn.Conv2d(32, 8, kernel_size=2, stride=2, padding=0, bias=True)
+        self.fc2 = torch.nn.Linear(10, 3)
+        # remove bn layers, since non-4-dim param tensors are not currently supported
+        del self.bn1
+        del self.bn2
+
+    def forward(self, inputs):
+        x = self.conv1(inputs)
+        # TODO
+        # bn layer removed, since non-4-dim param tensors are not currently supported
+        # x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.maxpool(x)
+
+        # Save the output of MaxPool as residual.
+        residual = x
+
+        x = self.conv2(x)
+        # TODO
+        # bn layer removed, since non-4-dim param tensors are not currently supported
+        # x = self.bn2(x)
+        x = self.relu2(x)
+        x = self.conv3(x)
+
+        # Add the residual
+        # AdaptiveAvgPool2d is used to get the desired dimension before adding.
+        residual = self.conv4(residual)
+        residual = self.ada(residual)
+        x += residual
+        x = self.relu3(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+        y = self.fc2(x)
+
+        return x, y
+
 def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
     torch.onnx.export(model.eval(),
diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py
index 3bdc489d1e8..509d669337e 100644
--- a/TrainingExtensions/onnx/test/python/test_quantsim.py
+++ b/TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -50,8 +50,7 @@
 from aimet_onnx.qc_quantize_op import OpMode
 from aimet_onnx.utils import make_dummy_input
 from models.models_for_tests import SingleResidual
-from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model
-
+from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model, multi_output_model
 
 class DummyModel(SingleResidual):
     """
@@ -97,7 +96,6 @@ def forward(self, inputs):
 
         return x
 
-
 class TestQuantSim:
     """Tests for QuantizationSimModel"""
     def test_insert_quantize_op_nodes(self):
@@ -428,9 +426,23 @@ def callback(session, args):
 
         assert np.allclose(out2, out3)
 
+
+
     def test_model_with_constants(self):
         model = multi_input_with_constant_model()
         sim = QuantizationSimModel(model)
         assert sim.qc_quantize_op_dict['13'].enabled == True
         assert sim.qc_quantize_op_dict['7'].enabled == True
+
+
+    def test_multiple_output_quantsim(self):
+        model = multi_output_model()
+        sample_input = np.random.rand(128, 3, 32, 32).astype(np.float32)
+
+        sim = QuantizationSimModel(model=model,
+                                   quant_scheme=QuantScheme.post_training_tf_enhanced,
+                                   default_activation_bw=8,
+                                   default_param_bw=8)
+        sim.session.run(None, {'input': sample_input})
+
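Context for reviewers: the new else branch handles a branch op whose product feeds a graph output and therefore has no op consumers. MultipleOutputModel produces exactly this shape, since its first returned tensor is also consumed by fc2. Below is a minimal, self-contained sketch of the flow the new test exercises; it is illustrative and not part of the patch. TinyTwoOutput, the /tmp path, and the aimet_common.defs / onnxruntime.quantization.onnx_model import paths are assumptions; only the QuantizationSimModel and sim.session.run usage mirrors test_multiple_output_quantsim.

# Minimal sketch (not part of the patch). Assumes aimet_onnx is installed and
# that ONNXModel / QuantScheme come from the import paths below, as in the tests.
import numpy as np
import torch
from onnx import load_model
from onnxruntime.quantization.onnx_model import ONNXModel
from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel

class TinyTwoOutput(torch.nn.Module):
    """Hypothetical stand-in for MultipleOutputModel: the first returned tensor
    is also consumed by fc2, so its producer becomes a branch op whose product
    feeds a graph output -- the case the new else branch covers."""
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 4)
        self.fc2 = torch.nn.Linear(4, 2)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        return h, self.fc2(h)  # h is both a graph output and an input to fc2

sample_input = np.random.rand(1, 8).astype(np.float32)
torch.onnx.export(TinyTwoOutput().eval(), torch.as_tensor(sample_input),
                  "/tmp/tiny_two_output.onnx",
                  input_names=["input"], output_names=["out_a", "out_b"])

sim = QuantizationSimModel(model=ONNXModel(load_model("/tmp/tiny_two_output.onnx")),
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_activation_bw=8,
                           default_param_bw=8)
# Before this change, ConnectedGraph construction hit the single-consumer assert
# for the product feeding out_a; with the fix both outputs run through quantsim.
out_a, out_b = sim.session.run(None, {"input": sample_input})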