-
Notifications
You must be signed in to change notification settings - Fork 393
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add export logic for adaround encodings (#2499)
Signed-off-by: Harshita Mangal <[email protected]>
- Loading branch information
1 parent
60fbf1f
commit 3a4e6a3
Showing
8 changed files
with
387 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
:orphan: | ||
|
||
.. _api-onnx-adaround: | ||
|
||
================================== | ||
AIMET ONNX AdaRound API | ||
================================== | ||
|
||
User Guide Link | ||
================ | ||
To learn more about this technique, please see :ref:`AdaRound<ug-adaround>` | ||
|
||
Top-level API | ||
============= | ||
.. autofunction:: aimet_onnx.adaround.adaround_weight.Adaround.apply_adaround | ||
|
||
|
||
Adaround Parameters | ||
=================== | ||
.. autoclass:: aimet_onnx.adaround.adaround_weight.AdaroundParameters | ||
:members: | ||
|
||
|
||
Code Example - Adaptive Rounding (AdaRound) | ||
=========================================== | ||
|
||
This example shows how to use AIMET to perform Adaptive Rounding (AdaRound). | ||
|
||
**Required imports** | ||
|
||
.. literalinclude:: ../onnx_code_examples/adaround.py | ||
:language: python | ||
:lines: 42-43 | ||
|
||
**User should write this function to pass calibration data** | ||
|
||
|
||
.. literalinclude:: ../onnx_code_examples/adaround.py | ||
:language: python | ||
:pyobject: pass_calibration_data | ||
|
||
|
||
**Apply Adaround** | ||
|
||
.. literalinclude:: ../onnx_code_examples/adaround.py | ||
:language: python | ||
:pyobject: apply_adaround_example |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# /usr/bin/env python3.6 | ||
# -*- mode: python -*- | ||
# ============================================================================= | ||
# @@-COPYRIGHT-START-@@ | ||
# | ||
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions are met: | ||
# | ||
# 1. Redistributions of source code must retain the above copyright notice, | ||
# this list of conditions and the following disclaimer. | ||
# | ||
# 2. Redistributions in binary form must reproduce the above copyright notice, | ||
# this list of conditions and the following disclaimer in the documentation | ||
# and/or other materials provided with the distribution. | ||
# | ||
# 3. Neither the name of the copyright holder nor the names of its contributors | ||
# may be used to endorse or promote products derived from this software | ||
# without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
# POSSIBILITY OF SUCH DAMAGE. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# @@-COPYRIGHT-END-@@ | ||
# ============================================================================= | ||
# pylint: skip-file | ||
|
||
""" AdaRound code example to be used for documentation generation. """ | ||
|
||
from aimet_onnx.adaround.adaround_weight import AdaroundParameters, Adaround | ||
from aimet_onnx.quantsim import QuantizationSimModel | ||
|
||
|
||
def pass_calibration_data(model): | ||
""" | ||
The User of the QuantizationSimModel API is expected to write this function based on their data set. | ||
This is not a working function and is provided only as a guideline. | ||
:param model: | ||
""" | ||
|
||
|
||
def apply_adaround_example(model, dataloader): | ||
""" | ||
Example code to run adaround | ||
""" | ||
params = AdaroundParameters(data_loader=dataloader, num_batches=1, default_num_iterations=5, | ||
forward_fn=pass_calibration_data, | ||
forward_pass_callback_args=None) | ||
ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy') | ||
|
||
sim = QuantizationSimModel(ada_rounded_model, | ||
default_param_bw=8, | ||
default_activation_bw=8, use_cuda=True) | ||
sim.set_and_freeze_param_encodings('./dummy.encodings') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# /usr/bin/env python3.8 | ||
# -*- mode: python -*- | ||
# ============================================================================= | ||
# @@-COPYRIGHT-START-@@ | ||
# | ||
# Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions are met: | ||
# | ||
# 1. Redistributions of source code must retain the above copyright notice, | ||
# this list of conditions and the following disclaimer. | ||
# | ||
# 2. Redistributions in binary form must reproduce the above copyright notice, | ||
# this list of conditions and the following disclaimer in the documentation | ||
# and/or other materials provided with the distribution. | ||
# | ||
# 3. Neither the name of the copyright holder nor the names of its contributors | ||
# may be used to endorse or promote products derived from this software | ||
# without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
# POSSIBILITY OF SUCH DAMAGE. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# @@-COPYRIGHT-END-@@ | ||
# ============================================================================= | ||
import os | ||
from packaging import version | ||
import json | ||
import numpy as np | ||
import pytest | ||
import torch | ||
from onnx import load_model | ||
from onnxruntime.quantization.onnx_quantizer import ONNXModel | ||
from torchvision import models | ||
|
||
from aimet_onnx.utils import make_dummy_input | ||
from aimet_common.defs import QuantScheme | ||
from aimet_onnx.quantsim import QuantizationSimModel | ||
from torch_utils import get_cifar10_data_loaders, train_cifar10 | ||
from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession | ||
|
||
from aimet_onnx.adaround.adaround_weight import Adaround, AdaroundParameters | ||
|
||
WORKING_DIR = '/tmp/quantsim' | ||
|
||
image_size = 32 | ||
batch_size = 64 | ||
num_workers = 4 | ||
|
||
|
||
def model_eval_onnx(session, val_loader): | ||
""" | ||
:param model: model to be evaluated | ||
:param early_stopping_iterations: if None, data loader will iterate over entire validation data | ||
:return: top_1_accuracy on validation data | ||
""" | ||
|
||
corr = 0 | ||
total = 0 | ||
for (i, batch) in enumerate(val_loader): | ||
x, y = batch[0].numpy(), batch[1].numpy() | ||
in_tensor = {'input': x} | ||
out = session.run(None, in_tensor)[0] | ||
corr += np.sum(np.argmax(out, axis=1) == y) | ||
total += x.shape[0] | ||
print(f'Accuracy: {corr / total}') | ||
return corr / total | ||
|
||
|
||
class TestAdaroundAcceptance: | ||
""" Acceptance test for AIMET ONNX """ | ||
@pytest.mark.cuda | ||
def test_adaround(self): | ||
if version.parse(torch.__version__) >= version.parse("1.13"): | ||
np.random.seed(0) | ||
torch.manual_seed(0) | ||
|
||
model = get_model() | ||
|
||
data_loader = dataloader() | ||
dummy_input = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} | ||
sess = build_session(model) | ||
out_before_ada = sess.run(None, dummy_input) | ||
def callback(session, args): | ||
in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} | ||
session.run(None, in_tensor) | ||
|
||
params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=5, forward_fn=callback, | ||
forward_pass_callback_args=None) | ||
ada_rounded_model = Adaround.apply_adaround(model, params, './', 'dummy') | ||
sess = build_session(ada_rounded_model) | ||
out_after_ada = sess.run(None, dummy_input) | ||
assert not np.array_equal(out_before_ada[0], out_after_ada[0]) | ||
|
||
with open('./dummy.encodings') as json_file: | ||
encoding_data = json.load(json_file) | ||
|
||
sim = QuantizationSimModel(ada_rounded_model, dummy_input, quant_scheme=QuantScheme.post_training_tf, default_param_bw=8, | ||
default_activation_bw=8, use_cuda=True) | ||
sim.set_and_freeze_param_encodings('./dummy.encodings') | ||
sim.compute_encodings(callback, None) | ||
assert sim.qc_quantize_op_dict['fc.weight'].encodings[0].delta == encoding_data['fc.weight'][0]['scale'] | ||
|
||
def get_model(): | ||
model = models.resnet18(pretrained=False, num_classes=10) | ||
if torch.cuda.is_available(): | ||
device = torch.device('cuda:0') | ||
model.to(device) | ||
|
||
torch.onnx.export(model, torch.rand(batch_size, 3, 32, 32).cuda(), './resnet18.onnx', | ||
training=torch.onnx.TrainingMode.EVAL, | ||
input_names=['input'], output_names=['output'], | ||
dynamic_axes={ | ||
'input': {0: 'batch_size'}, | ||
'output': {0: 'batch_size'}, | ||
} | ||
) | ||
|
||
onnx_model = ONNXModel(load_model('./resnet18.onnx')) | ||
return onnx_model | ||
|
||
def dataloader(): | ||
class DataLoader: | ||
""" | ||
Example of a Dataloader which can be used for running AMPv2 | ||
""" | ||
def __init__(self, batch_size: int): | ||
""" | ||
:param batch_size: batch size for data loader | ||
""" | ||
self.batch_size = batch_size | ||
|
||
def __iter__(self): | ||
"""Iterates over dataset""" | ||
dummy_input = np.random.rand(1, 3, 32, 32).astype(np.float32) | ||
yield dummy_input | ||
|
||
def __len__(self): | ||
return 4 | ||
|
||
dummy_dataloader = DataLoader(batch_size=2) | ||
return dummy_dataloader | ||
|
||
def build_session(model): | ||
""" | ||
Build and return onnxruntime inference session | ||
:param providers: providers to execute onnxruntime | ||
""" | ||
sess_options = SessionOptions() | ||
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL | ||
session = InferenceSession( | ||
path_or_bytes=model.model.SerializeToString(), | ||
sess_options=sess_options, | ||
providers=['CPUExecutionProvider'], | ||
) | ||
return session |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.