From 3a4e6a3d0e767d694096ff0f7722baac8632c95e Mon Sep 17 00:00:00 2001
From: Harshita Mangal <63610745+quic-mangal@users.noreply.github.com>
Date: Tue, 10 Oct 2023 11:23:46 -0700
Subject: [PATCH] Add export logic for adaround encodings (#2499)

Signed-off-by: Harshita Mangal
---
 Docs/api_docs/onnx_adaround.rst               |  47 +++++
 Docs/api_docs/onnx_quantization.rst           |   1 +
 Docs/onnx_code_examples/adaround.py           |  68 +++++++
 NightlyTests/onnx/test_adaround.py            | 168 ++++++++++++++++++
 .../aimet_onnx/adaround/adaround_weight.py    |  37 ++--
 .../src/python/aimet_onnx/qc_quantize_op.py   |  61 ++++---
 .../onnx/src/python/aimet_onnx/quantsim.py    |  39 +++-
 .../onnx/test/python/test_adaround_weight.py  |   9 +-
 8 files changed, 387 insertions(+), 43 deletions(-)
 create mode 100644 Docs/api_docs/onnx_adaround.rst
 create mode 100644 Docs/onnx_code_examples/adaround.py
 create mode 100644 NightlyTests/onnx/test_adaround.py

diff --git a/Docs/api_docs/onnx_adaround.rst b/Docs/api_docs/onnx_adaround.rst
new file mode 100644
index 00000000000..f91ed98c688
--- /dev/null
+++ b/Docs/api_docs/onnx_adaround.rst
@@ -0,0 +1,47 @@
+:orphan:
+
+.. _api-onnx-adaround:
+
+==================================
+AIMET ONNX AdaRound API
+==================================
+
+User Guide Link
+================
+To learn more about this technique, please see :ref:`AdaRound`
+
+Top-level API
+=============
+.. autofunction:: aimet_onnx.adaround.adaround_weight.Adaround.apply_adaround
+
+
+Adaround Parameters
+===================
+.. autoclass:: aimet_onnx.adaround.adaround_weight.AdaroundParameters
+    :members:
+
+
+Code Example - Adaptive Rounding (AdaRound)
+===========================================
+
+This example shows how to use AIMET to perform Adaptive Rounding (AdaRound).
+
+**Required imports**
+
+.. literalinclude:: ../onnx_code_examples/adaround.py
+    :language: python
+    :lines: 42-43
+
+**User should write this function to pass calibration data**
+
+
+.. literalinclude:: ../onnx_code_examples/adaround.py
+    :language: python
+    :pyobject: pass_calibration_data
+
+
+**Apply Adaround**
+
+.. literalinclude:: ../onnx_code_examples/adaround.py
+    :language: python
+    :pyobject: apply_adaround_example
diff --git a/Docs/api_docs/onnx_quantization.rst b/Docs/api_docs/onnx_quantization.rst
index b3908c8e9d2..8ad83340230 100644
--- a/Docs/api_docs/onnx_quantization.rst
+++ b/Docs/api_docs/onnx_quantization.rst
@@ -5,3 +5,4 @@ AIMET ONNX Quantization APIs
 AIMET Quantization for ONNX Models provides the following functionality.
    - :ref:`Quantization Simulation API`: Allows ability to simulate inference on quantized hardware
    - :ref:`Cross-Layer Equalization API`: Post-training quantization technique to equalize layer parameters
+   - :ref:`Adaround API`: Post-training quantization technique to optimize rounding of weight tensors
diff --git a/Docs/onnx_code_examples/adaround.py b/Docs/onnx_code_examples/adaround.py
new file mode 100644
index 00000000000..f453c8ef846
--- /dev/null
+++ b/Docs/onnx_code_examples/adaround.py
@@ -0,0 +1,68 @@
+# /usr/bin/env python3.6
+# -*- mode: python -*-
+# =============================================================================
+#  @@-COPYRIGHT-START-@@
+#
+#  Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#
+#  1. Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#
+#  2. Redistributions in binary form must reproduce the above copyright notice,
+#     this list of conditions and the following disclaimer in the documentation
+#     and/or other materials provided with the distribution.
+#
+#  3. Neither the name of the copyright holder nor the names of its contributors
+#     may be used to endorse or promote products derived from this software
+#     without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+#  POSSIBILITY OF SUCH DAMAGE.
+#
+#  SPDX-License-Identifier: BSD-3-Clause
+#
+#  @@-COPYRIGHT-END-@@
+# =============================================================================
+# pylint: skip-file
+
+""" AdaRound code example to be used for documentation generation. """
+
+from aimet_onnx.adaround.adaround_weight import AdaroundParameters, Adaround
+from aimet_onnx.quantsim import QuantizationSimModel
+
+
+def pass_calibration_data(session, args):
+    """
+    The user is expected to write this function based on their data set; it should pass
+    representative calibration samples through the given session. This is not a working
+    function and is provided only as a guideline.
+
+    :param session: ONNX runtime session created for the model being calibrated
+    :param args: optional user arguments (forward_pass_callback_args)
+    """
+
+
+def apply_adaround_example(model, dataloader):
+    """
+    Example code to run AdaRound
+
+    :param model: ONNX model to apply AdaRound on
+    :param dataloader: data loader yielding calibration input batches
+    """
+    params = AdaroundParameters(data_loader=dataloader, num_batches=1, default_num_iterations=5,
+                                forward_fn=pass_calibration_data,
+                                forward_pass_callback_args=None)
+    ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy')
+
+    sim = QuantizationSimModel(ada_rounded_model,
+                               default_param_bw=8,
+                               default_activation_bw=8, use_cuda=True)
+    sim.set_and_freeze_param_encodings('./dummy.encodings')
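Note for readers of this patch: a filled-in version of pass_calibration_data might look like the sketch below. This is an editor's sketch, not part of the change; calibration_data_loader and the input name 'input' are illustrative assumptions, mirroring the nightly test added later in this patch.

def pass_calibration_data(session, args):
    # Feed representative samples through the ONNX runtime session. The
    # outputs are discarded; only the range statistics that the quantize
    # ops collect during these forward passes matter.
    for inputs in calibration_data_loader:  # assumed: iterable of numpy batches
        session.run(None, {'input': inputs})
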
diff --git a/NightlyTests/onnx/test_adaround.py b/NightlyTests/onnx/test_adaround.py
new file mode 100644
index 00000000000..d9afdcc767a
--- /dev/null
+++ b/NightlyTests/onnx/test_adaround.py
@@ -0,0 +1,168 @@
+# /usr/bin/env python3.8
+# -*- mode: python -*-
+# =============================================================================
+#  @@-COPYRIGHT-START-@@
+#
+#  Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#
+#  1. Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#
+#  2. Redistributions in binary form must reproduce the above copyright notice,
+#     this list of conditions and the following disclaimer in the documentation
+#     and/or other materials provided with the distribution.
+#
+#  3. Neither the name of the copyright holder nor the names of its contributors
+#     may be used to endorse or promote products derived from this software
+#     without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+#  POSSIBILITY OF SUCH DAMAGE.
+#
+#  SPDX-License-Identifier: BSD-3-Clause
+#
+#  @@-COPYRIGHT-END-@@
+# =============================================================================
+import os
+from packaging import version
+import json
+import numpy as np
+import pytest
+import torch
+from onnx import load_model
+from onnxruntime.quantization.onnx_quantizer import ONNXModel
+from torchvision import models
+
+from aimet_onnx.utils import make_dummy_input
+from aimet_common.defs import QuantScheme
+from aimet_onnx.quantsim import QuantizationSimModel
+from torch_utils import get_cifar10_data_loaders, train_cifar10
+from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession
+
+from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters
+
+WORKING_DIR = '/tmp/quantsim'
+
+image_size = 32
+batch_size = 64
+num_workers = 4
+
+
+def model_eval_onnx(session, val_loader):
+    """
+    Evaluate an ONNX runtime session on the validation data
+
+    :param session: onnxruntime inference session to be evaluated
+    :param val_loader: data loader that iterates over the validation data
+    :return: top_1_accuracy on validation data
+    """
+
+    corr = 0
+    total = 0
+    for batch in val_loader:
+        x, y = batch[0].numpy(), batch[1].numpy()
+        in_tensor = {'input': x}
+        out = session.run(None, in_tensor)[0]
+        corr += np.sum(np.argmax(out, axis=1) == y)
+        total += x.shape[0]
+    print(f'Accuracy: {corr / total}')
+    return corr / total
+
+
+class TestAdaroundAcceptance:
+    """ Acceptance test for AIMET ONNX """
+    @pytest.mark.cuda
+    def test_adaround(self):
+        if version.parse(torch.__version__) >= version.parse("1.13"):
+            np.random.seed(0)
+            torch.manual_seed(0)
+
+            model = get_model()
+
+            data_loader = dataloader()
+            dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
+            sess = build_session(model)
+            out_before_ada = sess.run(None, dummy_input)
+
+            def callback(session, args):
+                in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)}
+                session.run(None, in_tensor)
+
+            params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5,
+                                        forward_fn=callback, forward_pass_callback_args=None)
+            ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy')
+            sess = build_session(ada_rounded_model)
+            out_after_ada = sess.run(None, dummy_input)
+            assert not np.array_equal(out_before_ada[0], out_after_ada[0])
+
+            with open('./dummy.encodings') as json_file:
+                encoding_data = json.load(json_file)
+
+            sim = QuantizationSimModel(ada_rounded_model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
+                                       default_param_bw=8, default_activation_bw=8, use_cuda=True)
+            sim.set_and_freeze_param_encodings('./dummy.encodings')
+            sim.compute_encodings(callback, None)
+            assert sim.qc_quantize_op_dict['fc.weight'].encodings[0].delta == \
+                   encoding_data['fc.weight'][0]['scale']
+
+
+def get_model():
+    model = models.resnet18(pretrained=False, num_classes=10)
+    if torch.cuda.is_available():
+        device = torch.device('cuda:0')
+        model.to(device)
+
+    torch.onnx.export(model, torch.rand(batch_size, 3, 32, 32).cuda(), './resnet18.onnx',
+                      training=torch.onnx.TrainingMode.EVAL,
+                      input_names=['input'], output_names=['output'],
+                      dynamic_axes={
+                          'input': {0: 'batch_size'},
+                          'output': {0: 'batch_size'},
+                      }
+                      )
+
+    onnx_model = ONNXModel(load_model('./resnet18.onnx'))
+    return onnx_model
+
+
+def dataloader():
+    class DataLoader:
+        """
+        Example of a data loader that can be used for running AdaRound
+        """
+        def __init__(self, batch_size: int):
+            """
+            :param batch_size: batch size for data loader
+            """
+            self.batch_size = batch_size
+
+        def __iter__(self):
+            """ Iterates over the dataset, yielding one dummy input batch """
+            dummy_input = np.random.rand(1, 3, 32, 32).astype(np.float32)
+            yield dummy_input
+
+        def __len__(self):
+            return 4
+
+    dummy_dataloader = DataLoader(batch_size=2)
+    return dummy_dataloader
+
+
+def build_session(model):
+    """
+    Build and return onnxruntime inference session
+
+    :param model: ONNX model to build an inference session for
+    """
+    sess_options = SessionOptions()
+    sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
+    session = InferenceSession(
+        path_or_bytes=model.model.SerializeToString(),
+        sess_options=sess_options,
+        providers=['CPUExecutionProvider'],
+    )
+    return session
\ No newline at end of file
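The assertions above read './dummy.encodings' back as plain JSON. For orientation, the loaded structure looks roughly like the sketch below. This is editor-supplied with made-up values; the field set matches what set_and_freeze_param_encodings consumes later in this patch, with one list entry per channel (a single entry for per-tensor quantization).

# Illustrative result of json.load on the exported '.encodings' file
encoding_data = {
    'fc.weight': [
        {
            'bitwidth': 8,           # integer bit-width used for the parameter
            'dtype': 'int',          # 'int' or 'float'
            'is_symmetric': 'True',  # stored as a string in the JSON file
            'scale': 0.0039,         # quantization step size (delta)
            'offset': -128,
            'min': -0.5,
            'max': 0.49,
        }
    ],
    # ... one key per enabled parameter quantizer, e.g. 'onnx::Conv_43'
}
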
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py
index 1968e5b4045..856318a4e05 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/adaround/adaround_weight.py
@@ -41,7 +41,7 @@
 import copy
 import os
 import shutil
-import contextlib
+import json
 from typing import Tuple, Dict, List, Callable
 from onnx import onnx_pb
 from tqdm import tqdm
@@ -301,27 +301,28 @@ def _create_param_to_tensor_quantizer_dict(quant_sim: QuantizationSimModel) -> D
     def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: QuantizationSimModel):
         """
         Save Adadrounded module's parameter encodings to JSON file
+
         :param path: path where to store param encodings
         :param filename_prefix: filename to store exported weight encodings in JSON format
         :param quant_sim: QunatSim that contains the model and Adaround tensor quantizers
         """
-
-    @classmethod
-    def _update_param_encodings_dict(cls, quant_module, name: str, param_encodings: Dict):
-        """
-        Add module's weight parameter encodings to dictionary to be used for exporting encodings
-        :param quant_module: quant module
-        :param name: name of module
-        :param param_encodings: Dictionary of param encodings
-        """
-
-    @staticmethod
-    def _create_encodings_dict_for_quantizer(quantizer) -> List[Dict]:
-        """
-        Return encodings for given qunatizer
-        :param quantizer: Tensor quantizer associated with module's param
-        :return: Dictionary containing encodings
-        """
+        # pylint: disable=protected-access
+        def update_encoding_dict_entry(encoding_dict: Dict, op_name: str):
+            qc_quantize_op = quant_sim.qc_quantize_op_dict[op_name]
+            encoding_dict[op_name] = []
+            for encoding in qc_quantize_op.encodings:
+                encoding_dict[op_name].append(QuantizationSimModel._create_encoding_dict(encoding, qc_quantize_op))
+
+        param_encodings = {}
+        for name in quant_sim.param_names:
+            if quant_sim.qc_quantize_op_dict[name].enabled:
+                update_encoding_dict_entry(param_encodings, name)
+
+        # export encodings to JSON file
+        os.makedirs(os.path.abspath(path), exist_ok=True)
+        encoding_file_path = os.path.join(path, filename_prefix + '.encodings')
+        with open(encoding_file_path, 'w') as encoding_fp:
+            json.dump(param_encodings, encoding_fp, sort_keys=True, indent=4)
 
     @staticmethod
     def _override_param_bitwidth(quant_sim: QuantizationSimModel,
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
index 8804344322c..558809baa93 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
@@ -87,6 +87,7 @@ def __init__(self, quant_info: libquant_info.QcQuantizeInfo,
         self.quant_info = quant_info
         self.quant_scheme = quant_scheme
         self.rounding_mode = rounding_mode
+        self._is_encoding_frozen = False
         self._tensor_quantizer = None
         self.set_tensor_quantizer(self._build_tensor_quantizer())
         self.op_mode = op_mode
@@ -97,6 +98,14 @@ def __init__(self, quant_info: libquant_info.QcQuantizeInfo,
         self._data_type = QuantizationDataType.int
         self.tensor_quantizer_params = tensor_quantizer_params
 
+    def is_encoding_frozen(self) -> bool:
+        """ Returns is_encoding_frozen var """
+        return self._is_encoding_frozen
+
+    def freeze_encodings(self):
+        """ Sets encodings to frozen """
+        self._is_encoding_frozen = True
+
     def enable_per_channel_quantization(self):
         """
         Enables per channel quantization for qc_quantize_op
@@ -133,10 +142,11 @@ def data_type(self, data_type: QuantizationDataType):
 
         :param data_type: Quantization data type
         """
-        self._data_type = data_type
-        self.quant_info.isIntDataType = False
-        if data_type == QuantizationDataType.int:
-            self.quant_info.isIntDataType = True
+        if not self._is_encoding_frozen:
+            self._data_type = data_type
+            self.quant_info.isIntDataType = False
+            if data_type == QuantizationDataType.int:
+                self.quant_info.isIntDataType = True
 
     def _build_tensor_quantizer(self):
         return libpymo.TensorQuantizer(MAP_QUANT_SCHEME_TO_PYMO[self.quant_scheme],
@@ -181,7 +191,8 @@ def use_symmetric_encodings(self, use_symmetric_encodings: bool):
         Sets the useSymmetricEncoding attribute of the nodes QcQuantizeInfo object
         :param use_symmetric_encodings: True if the node is to use symmetric encodings
         """
-        self.quant_info.useSymmetricEncoding = use_symmetric_encodings
+        if not self._is_encoding_frozen:
+            self.quant_info.useSymmetricEncoding = use_symmetric_encodings
 
     @property
     def use_strict_symmetric(self) -> bool:
@@ -279,20 +290,22 @@ def reset_encoding_stats(self):
         """
         reset the stats of tensor quantizer
         """
-        encodings = []
-        for tensor_quantizer in self._tensor_quantizer:
-            encoding = libpymo.TfEncoding()
-            encoding.bw = self.bitwidth
-            encodings.append(encoding)
-            tensor_quantizer.resetEncodingStats()
-        self.encodings = encodings
+        if not self._is_encoding_frozen:
+            encodings = []
+            for tensor_quantizer in self._tensor_quantizer:
+                encoding = libpymo.TfEncoding()
+                encoding.bw = self.bitwidth
+                encodings.append(encoding)
+                tensor_quantizer.resetEncodingStats()
+            self.encodings = encodings
 
     def set_bitwidth(self, bitwidth: int):
         """
         Set bitwidth for quantization
         """
-        self.bitwidth = bitwidth
-        self.reset_encoding_stats()
+        if not self._is_encoding_frozen:
+            self.bitwidth = bitwidth
+            self.reset_encoding_stats()
 
     def set_quant_scheme(self, quant_scheme: QuantScheme):
         """
@@ -309,12 +322,14 @@ def compute_encodings(self) -> libpymo.TfEncoding:
         """
         Compute and return encodings of each tensor quantizer
         """
-        if self.enabled:
-            encodings = []
-            for tensor_quantizer in self._tensor_quantizer:
-                encodings.append(tensor_quantizer.computeEncoding(self.bitwidth, self.use_symmetric_encodings))
-            self.encodings = encodings
-        else:
-            encodings = None
-
-        return encodings
+        if not self._is_encoding_frozen:
+            if self.enabled:
+                encodings = []
+                for tensor_quantizer in self._tensor_quantizer:
+                    encodings.append(tensor_quantizer.computeEncoding(self.bitwidth, self.use_symmetric_encodings))
+                self.encodings = encodings
+            else:
+                encodings = None
+
+            return encodings
+        return None
diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
index 57b2bc0e90c..c232e7d84c7 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -394,10 +394,12 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
                 qc_op.op_mode = OpMode.updateStats
             else:
                 qc_op.op_mode = OpMode.oneShotQuantizeDequantize
+            if qc_op.is_encoding_frozen():
+                qc_op.op_mode = OpMode.quantizeDequantize
 
         forward_pass_callback(self.session, forward_pass_callback_args)
         for op_name, qc_op in self.qc_quantize_op_dict.items():
-            if qc_op.data_type == QuantizationDataType.int:
+            if qc_op.data_type == QuantizationDataType.int and not qc_op.is_encoding_frozen():
                 qc_op.compute_encodings()
             qc_op.op_mode = OpMode.quantizeDequantize
 
@@ -482,6 +484,41 @@ def export(self, path: str, filename_prefix: str):
         self.remove_quantization_nodes()
         self.model.save_model_to_file(os.path.join(path, filename_prefix) + '.onnx')
 
+    def set_and_freeze_param_encodings(self, encoding_path: str):
+        """
+        Set and freeze parameter encodings from encodings JSON file
+
+        :param encoding_path: path from where to load parameter encodings file
+        """
+
+        def _create_libpymo_encodings(encoding):
+            libpymo_encodings = []
+            for enc_val in encoding:
+                enc = libpymo.TfEncoding()
+                enc.bw, enc.delta, enc.max, enc.min, enc.offset = enc_val['bitwidth'], enc_val['scale'], enc_val['max'], \
+                                                                  enc_val['min'], enc_val['offset']
+                libpymo_encodings.append(enc)
+            return libpymo_encodings
+
+        # Load encodings file
+        with open(encoding_path) as json_file:
+            encodings = json.load(json_file)
+
+        for quantizer_name in encodings:
+            if quantizer_name in self.qc_quantize_op_dict:
+                libpymo_encodings = _create_libpymo_encodings(encodings[quantizer_name])
+                self.qc_quantize_op_dict[quantizer_name].load_encodings(libpymo_encodings)
+                self.qc_quantize_op_dict[quantizer_name].bitwidth = encodings[quantizer_name][0]['bitwidth']
+                dtype = QuantizationDataType.float
+                if encodings[quantizer_name][0]['dtype'] == 'int':
+                    dtype = QuantizationDataType.int
+                self.qc_quantize_op_dict[quantizer_name].data_type = dtype
+                is_symmetric = False
+                if encodings[quantizer_name][0]['is_symmetric'] == 'True':
+                    is_symmetric = True
+                self.qc_quantize_op_dict[quantizer_name].use_symmetric_encodings = is_symmetric
+                self.qc_quantize_op_dict[quantizer_name].freeze_encodings()
+
 
 def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str):
     """
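Taken together, the frozen-quantizer logic and the new set_and_freeze_param_encodings method support the usage pattern sketched below (editor's sketch, not part of the patch; model, dummy_input and callback are as in the nightly test above):

sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf,
                           default_param_bw=8, default_activation_bw=8, use_cuda=True)

# Loading freezes the parameter encodings: their bitwidth, data type and
# symmetry can no longer be modified, and compute_encodings() keeps the
# frozen ops in quantizeDequantize mode instead of re-estimating them.
sim.set_and_freeze_param_encodings('./dummy.encodings')

# Fresh encodings are computed only for activation quantizers and any
# parameter quantizers that were not frozen.
sim.compute_encodings(callback, None)
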
diff --git a/TrainingExtensions/onnx/test/python/test_adaround_weight.py b/TrainingExtensions/onnx/test/python/test_adaround_weight.py
index f3db55a5231..2d55d83da64 100644
--- a/TrainingExtensions/onnx/test/python/test_adaround_weight.py
+++ b/TrainingExtensions/onnx/test/python/test_adaround_weight.py
@@ -36,7 +36,8 @@
 #  @@-COPYRIGHT-END-@@
 # =============================================================================
-""" Unit tests for Adaround Activation Sampler """
+""" Unit tests for Adaround Weights """
+import json
 from packaging import version
 import numpy as np
 import torch
@@ -70,6 +71,12 @@ def callback(session, args):
         out_after_ada = sess.run(None, dummy_input)
         assert not np.array_equal(out_before_ada[0], out_after_ada[0])
 
+        with open('./dummy.encodings') as json_file:
+            encoding_data = json.load(json_file)
+
+        param_keys = list(encoding_data.keys())
+        assert 'onnx::Conv_43' in param_keys
+
 def dataloader():
     class DataLoader:
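Finally, the producer side of the file used throughout this patch: apply_adaround now writes the encodings that set_and_freeze_param_encodings consumes. A minimal end-to-end sketch (editor-supplied; data_loader, pass_calibration_data and the 'dummy' prefix are the same illustrative names as in the docs example above):

params = AdaroundParameters(data_loader=data_loader, num_batches=1,
                            default_num_iterations=5,
                            forward_fn=pass_calibration_data,
                            forward_pass_callback_args=None)

# Besides returning the weight-adjusted model, apply_adaround now also
# writes '<path>/<prefix>.encodings' ('./dummy.encodings' here) via
# _export_encodings_to_json; that file then feeds the freeze-and-simulate
# flow shown after the quantsim.py hunks above.
ada_model = Adaround.apply_adaround(model, params, './', 'dummy')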