Commit 595f6b6

Squashed and rebased changes allowing ONNX QuantSim to support models with ops that have multiple consumers.

Signed-off-by: Ashvin Kumar <[email protected]>
quic-ashvkuma committed Oct 18, 2023
1 parent d46b53c commit 595f6b6
Showing 3 changed files with 92 additions and 15 deletions.
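
For context, the graph shape this commit targets is a connected-graph tensor with more than one consumer, including the case where one of the extra edges terminates at a model output rather than at another op. Below is a minimal illustrative sketch of such a graph, built with onnx.helper; the names are invented and the snippet is not part of this commit.

import onnx
from onnx import helper, TensorProto

# "mid" is produced by Relu, consumed twice by Add, and also exposed as a graph
# output: a branching tensor of the kind this commit is concerned with (illustrative).
inp = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, 4])
mid = helper.make_tensor_value_info('mid', TensorProto.FLOAT, [1, 4])
out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [1, 4])

relu = helper.make_node('Relu', ['input'], ['mid'], name='relu')
add = helper.make_node('Add', ['mid', 'mid'], ['out'], name='add')

graph = helper.make_graph([relu, add], 'branch_to_output', [inp], [mid, out])
onnx.checker.check_model(helper.make_model(graph))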
@@ -483,18 +483,25 @@ def _link_branch_op_to_multiple_ops(self, branch_op: Op, product_list: list,
        # 4: Remove product from original op to child in child's inputs
        # 5: Remove product from self._products
        for product in product_list:
-           assert len(product.consumers) == 1
-           # item 1
-           branch_op_product.add_consumer(product.consumers[0])
-           # item 2
-           assert len(product.tensor_dict.keys()) == 1
-           branch_op_product.tensor_dict[product.consumers[0]] = product.tensor_dict[product.consumers[0]]
-           # items 3 and 4
-           # replace the old product with the new branch product, in the same index as the old product
-           index = product.consumers[0].inputs.index(product)
-           product.consumers[0].inputs[index] = branch_op_product
-           # item 5
-           del self._products[product.name]
+           assert len(product.consumers) <= 1
+
+           if len(product.consumers) == 1:
+               # item 1
+               branch_op_product.add_consumer(product.consumers[0])
+               # item 2
+               assert len(product.tensor_dict.keys()) == 1
+               branch_op_product.tensor_dict[product.consumers[0]] = product.tensor_dict[product.consumers[0]]
+               # items 3 and 4
+               # replace the old product with the new branch product, in the same index as the old product
+               index = product.consumers[0].inputs.index(product)
+               product.consumers[0].inputs[index] = branch_op_product
+               # item 5
+               del self._products[product.name]
+           else:
+               for output in self.model.graph.output:
+                   if output.name in product.name:
+                       del self._products[product.name]
+                       self._create_link_for_output_product(output.name, branch_op.name)

        self._products[branch_op_product.name] = branch_op_product

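A note on the new else branch above: when a branch product has no op consumer, the code scans the model's graph outputs and re-links any output whose name appears inside the product's name. A hedged, standalone illustration of that substring match follows; the names are invented and the snippet is not part of the commit.

# Illustrative only: mimics the `if output.name in product.name` check from the
# else branch, using made-up names.
graph_output_names = ['output_mul', 'output_add']
product_name = 'Relu_to_output_mul'
matched = [name for name in graph_output_names if name in product_name]
assert matched == ['output_mul']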
58 changes: 58 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -1049,6 +1049,19 @@ def multi_input_model():
    model = ONNXModel(load_model('./model_multi_input.onnx'))
    return model

+def multi_output_model():
+    model = MultipleOutputModel()
+    sample_input = np.random.rand(128, 3, 32, 32).astype(np.float32)
+
+    onnx_filename = "/tmp/dummy_model_multiple_outputs.onnx"
+    input_names = ["input"]
+    output_names = ["output_mul", "output_add"]
+    torch.onnx.export(model, torch.as_tensor(sample_input), onnx_filename, verbose=True, input_names=input_names,
+                      output_names=output_names)
+
+    model = ONNXModel(load_model(onnx_filename))
+    return model
+
def transposed_conv_model():
    x = torch.randn(10, 10, 4, 4, requires_grad=True)
    model = TransposedConvModel()
@@ -1417,6 +1430,51 @@ def forward(self, x):

        return x

+class MultipleOutputModel(SingleResidual):
+    """
+    Model with multiple outputs, derived from SingleResidual
+    """
+    def __init__(self):
+        super().__init__()
+        # change padding size to 0; onnxruntime only supports pooling when the input size divides evenly by the output size
+        self.conv4 = torch.nn.Conv2d(32, 8, kernel_size=2, stride=2, padding=0, bias=True)
+        self.fc2 = torch.nn.Linear(10, 3)
+        # remove bn layers, since non-4-dim param tensors are not currently supported
+        del self.bn1
+        del self.bn2
+
+    def forward(self, inputs):
+        x = self.conv1(inputs)
+        # TODO: restore bn layer once non-4-dim param tensors are supported
+        # x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.maxpool(x)
+
+        # Save the output of MaxPool as residual.
+        residual = x
+
+        x = self.conv2(x)
+        # TODO: restore bn layer once non-4-dim param tensors are supported
+        # x = self.bn2(x)
+        x = self.relu2(x)
+        x = self.conv3(x)
+
+        # Add the residual
+        # AdaptiveAvgPool2d is used to get the desired dimension before adding.
+        residual = self.conv4(residual)
+        residual = self.ada(residual)
+        x += residual
+        x = self.relu3(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+        y = self.fc2(x)
+
+        return x, y
+
+
def _convert_to_onnx_no_fold(model: torch.nn.Module, dummy_input, filename='./temp_model.onnx'):
    torch.onnx.export(model.eval(),
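As a quick sanity check of the new helper, the exported file should expose both declared outputs. A hedged usage sketch, using plain onnx and independent of AIMET internals; it is not part of the commit.

import onnx

# multi_output_model() exports to this path before wrapping the file in ONNXModel
exported = onnx.load('/tmp/dummy_model_multiple_outputs.onnx')
assert [o.name for o in exported.graph.output] == ['output_mul', 'output_add']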
18 changes: 15 additions & 3 deletions TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -50,8 +50,7 @@
from aimet_onnx.qc_quantize_op import OpMode
from aimet_onnx.utils import make_dummy_input
from models.models_for_tests import SingleResidual
-from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model
-
+from models.models_for_tests import build_dummy_model, single_residual_model, BNAfterConv, multi_input_with_constant_model, multi_output_model

class DummyModel(SingleResidual):
    """
@@ -97,7 +96,6 @@ def forward(self, inputs):

        return x

-
class TestQuantSim:
    """Tests for QuantizationSimModel"""
    def test_insert_quantize_op_nodes(self):
@@ -428,9 +426,23 @@ def callback(session, args):

        assert np.allclose(out2, out3)

+
    def test_model_with_constants(self):
        model = multi_input_with_constant_model()

        sim = QuantizationSimModel(model)
        assert sim.qc_quantize_op_dict['13'].enabled == True
        assert sim.qc_quantize_op_dict['7'].enabled == True
+
+
+    def test_multiple_output_quantsim(self):
+        model = multi_output_model()
+        sample_input = np.random.rand(128, 3, 32, 32).astype(np.float32)
+
+        sim = QuantizationSimModel(model=model,
+                                   quant_scheme=QuantScheme.post_training_tf_enhanced,
+                                   default_activation_bw=8,
+                                   default_param_bw=8)
+        sim.session.run(None, {'input': sample_input})
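The new test only runs a forward pass through the sim's onnxruntime session. A possible follow-on is sketched below, under the assumption that aimet_onnx's QuantizationSimModel exposes compute_encodings with the callback(session, args) signature seen earlier in this test file; it is not part of the commit.

def calibrate(session, args):
    # single calibration pass over the dummy batch
    session.run(None, {'input': args})

sim.compute_encodings(calibrate, sample_input)
quantized_outputs = sim.session.run(None, {'input': sample_input})
assert len(quantized_outputs) == 2  # output_mul and output_add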