Add capability to compute encodings for multiple QuantSims #2507

Merged
104 changes: 79 additions & 25 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -38,11 +38,12 @@

""" Implementation for simulating models running on Quantized hardware """
# pylint: disable=too-many-lines
import contextlib
import os
import io
import copy
import pickle
from typing import Tuple, List, Union, Dict, Callable, Optional
from typing import Tuple, List, Union, Dict, Callable, Optional, Any
from collections.abc import Iterable
import json
import torch
@@ -285,24 +286,15 @@ def print_quantizer_state(stream, quantizer, prefix_string):

return stream.getvalue()

def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
@staticmethod
def prepare_sim_for_compute_encodings(sim: 'QuantizationSimModel'):
"""
Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for
Range Learning

:param forward_pass_callback: A callback function that simply runs forward passes on the model. This callback
function should use representative data for the forward pass, so the calculated encodings work for all
data samples. This callback internally chooses the number of data samples it wants to use for calculating
encodings.
:param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
the user to determine the type of this parameter. E.g. could be simply an integer representing the number
of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
If set to None, forward_pass_callback will be invoked with no parameters.
:return: None
Prepare QuantSim for compute encodings. Resets encodings for each quantizable layer and sets mode to Analysis.

:param sim: QuantSim to prepare
"""

quantized_layers = self._get_qc_quantized_layers(self.model)
# pylint: disable=protected-access
quantized_layers = sim._get_qc_quantized_layers(sim.model)

for _, layer in quantized_layers:
# Clear stats and encodings if they are present
@@ -313,29 +305,58 @@ def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):

for _, layer in quantized_layers:
# call only when quant scheme is percentile
if self._quant_scheme == QuantScheme.post_training_percentile:
layer.set_percentile_value(self._percentile_value)
if sim._quant_scheme == QuantScheme.post_training_percentile:
layer.set_percentile_value(sim._percentile_value)

# Run forward iterations so we can collect statistics to compute the appropriate encodings
with utils.in_eval_mode(self.model), torch.no_grad():
_ = forward_pass_callback(self.model, forward_pass_callback_args)
@staticmethod
def compute_layer_encodings_for_sim(sim: 'QuantizationSimModel'):
"""
Compute encodings for each quantizable layer in sim after forward pass has been called.

:param sim: QuantSim to compute encodings for
"""
# pylint: disable=protected-access
quantized_layers = sim._get_qc_quantized_layers(sim.model)
# Get the computed per-layer encodings and log them
for name, layer in quantized_layers:
layer.compute_encoding()

# Before we return we set the mode to active - meaning ready for quantize/de-quantize
# for layers with valid_encoding, otherwise we set to pass through
if isinstance(layer, QcQuantizeRecurrent):
self.set_mode_for_recurrent_module(layer, name)

sim.set_mode_for_recurrent_module(layer, name)
else:
# By default we want to set the Quantization wrappers to ACTIVE mode
layer.set_mode(QcQuantizeOpMode.ACTIVE)

self.replace_wrappers_for_quantize_dequantize()
sim.replace_wrappers_for_quantize_dequantize()

self._clamp_transformer_attention_mask_encoding()
sim._clamp_transformer_attention_mask_encoding()

def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
"""
Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for
Range Learning

:param forward_pass_callback: A callback function that simply runs forward passes on the model. This callback
function should use representative data for the forward pass, so the calculated encodings work for all
data samples. This callback internally chooses the number of data samples it wants to use for calculating
encodings.
:param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
the user to determine the type of this parameter. E.g. could be simply an integer representing the number
of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
If set to None, forward_pass_callback will be invoked with no parameters.
:return: None

"""

QuantizationSimModel.prepare_sim_for_compute_encodings(self)

# Run forward iterations so we can collect statistics to compute the appropriate encodings
with utils.in_eval_mode(self.model), torch.no_grad():
_ = forward_pass_callback(self.model, forward_pass_callback_args)

QuantizationSimModel.compute_layer_encodings_for_sim(self)

@classmethod
def set_mode_for_recurrent_module(cls, layer: QcQuantizeRecurrent, name: str):
@@ -1915,3 +1936,36 @@ def has_valid_encodings(qc_quantize_op: Union[QcQuantizeWrapper, QcQuantizeRecur
return True

return False


def compute_encodings_for_sims(sim_list: List[QuantizationSimModel], forward_pass_callback: Callable,
forward_pass_callback_args: Any):
"""
Compute encodings for a list of QuantSims.

:param sim_list: List of QuantSims to compute encodings for.
:param forward_pass_callback: A callback function that simply runs forward passes on the models. This callback
function should use representative data for the forward pass, so the calculated encodings work for all
data samples. This callback internally chooses the number of data samples it wants to use for calculating
encodings.
The callback expects exactly two inputs:
- List of models which are involved in the forward pass. The models are taken directly from calling
sim.model for each sim in sim_list, passed in the same order in which the sims appear in sim_list.
- Forward pass callback args
:param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
the user to determine the type of this parameter. E.g. could be simply an integer representing the number
of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
If set to None, forward_pass_callback will be invoked with no parameters.
"""
ctx_managers = [torch.no_grad()]
for sim in sim_list:
ctx_managers.append(utils.in_eval_mode(sim.model))
QuantizationSimModel.prepare_sim_for_compute_encodings(sim)

with contextlib.ExitStack() as stack:
for mgr in ctx_managers:
stack.enter_context(mgr)
_ = forward_pass_callback([sim.model for sim in sim_list], forward_pass_callback_args)

for sim in sim_list:
QuantizationSimModel.compute_layer_encodings_for_sim(sim)
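
For reference, a minimal usage sketch of the new `compute_encodings_for_sims` API follows. The two toy models, their shapes, and the variable names are illustrative assumptions, not part of this change; the only requirement is that the callback runs representative forward passes through every `sim.model` it receives, in the order the sims were passed in.

```python
import torch
from aimet_torch.quantsim import QuantizationSimModel, compute_encodings_for_sims

# Hypothetical float models: model_b consumes the output of model_a.
model_a = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
model_b = torch.nn.Sequential(torch.nn.Conv2d(8, 4, 3), torch.nn.ReLU()).eval()

dummy_a = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    dummy_b = model_a(dummy_a)

sim_a = QuantizationSimModel(model_a, dummy_a)
sim_b = QuantizationSimModel(model_b, dummy_b)

def forward_pass_callback(model_list, _):
    # model_list holds sim_a.model and sim_b.model, in the same order as the
    # sims passed to compute_encodings_for_sims.
    out_a = model_list[0](dummy_a)
    return model_list[1](out_a)

# Prepares each sim, runs the callback under torch.no_grad() with every
# sim.model in eval mode, then computes per-layer encodings for each sim.
compute_encodings_for_sims([sim_a, sim_b], forward_pass_callback, None)
```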
44 changes: 43 additions & 1 deletion TrainingExtensions/torch/test/python/test_quantizer.py
@@ -66,7 +66,7 @@
StaticGridQuantWrapper, QcQuantizeOpMode, LearnedGridQuantWrapper, enable_recompute, no_recompute
from aimet_torch.qc_quantize_recurrent import QcQuantizeRecurrent
from aimet_torch.quantsim import QuantizationSimModel, check_accumulator_overflow, load_encodings_to_sim, \
has_valid_encodings
has_valid_encodings, compute_encodings_for_sims
from aimet_torch.quantsim_straight_through_grad import compute_dloss_by_dx

from models import test_models
@@ -2764,6 +2764,48 @@ def test_save_encodings_to_json(self):
assert len(encodings['activation_encodings']) == 14
assert len(encodings['param_encodings']) == 5

def test_compute_encodings_for_multiple_sims(self):
class SecondModel(torch.nn.Module):
def __init__(self, const_inp_shape):
super(SecondModel, self).__init__()
self.add = elementwise_ops.Add()
self.sub = elementwise_ops.Subtract()
self.batchnorm = torch.nn.BatchNorm1d(10)
self.const_tensor = torch.randn(const_inp_shape)

def forward(self, inp, inp2):
x = self.add(inp, self.const_tensor)
x = self.batchnorm(x)
x = self.sub(x, inp2)
return x

model = ModelWithTwoInputsOneToAdd()
model.eval()
dummy_input = (torch.rand(32, 1, 100, 100), torch.rand(32, 10, 22, 22))
model_1_out = model(*dummy_input)
model_2 = SecondModel(model_1_out.shape)
model_2.eval()
dummy_input_2 = torch.randn(model_1_out.shape)
sim1 = QuantizationSimModel(model, dummy_input)
sim2 = QuantizationSimModel(model_2, (dummy_input_2, dummy_input_2))

def forward_pass_callback(model_list, _):
x = model_list[0](*dummy_input)
x = model_list[1](x, dummy_input_2)
return x

sim2.model.train()
running_mean = sim2.model.batchnorm.running_mean.clone().detach()
compute_encodings_for_sims([sim1, sim2], forward_pass_callback, None)

# Check that even though sim2 was in training mode prior to compute encodings, it was placed in eval mode
# during compute encodings, and that it was placed back to training mode afterwards.
assert sim2.model.training
assert torch.equal(running_mean, sim2.model.batchnorm.running_mean)
assert sim1.model.conv1_a.output_quantizers[0].encoding is not None
assert sim2.model.add.input_quantizers[0].encoding is not None
assert sim2.model.add.input_quantizers[1].encoding is not None

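The helpers used internally by `compute_encodings_for_sims` are also exposed as staticmethods, so the same calibration flow can be driven by hand when more control over the forward passes is needed. Below is a minimal sketch of that manual flow; the `calibrate` helper and `data_loader` are hypothetical names, not part of this change.

```python
import torch
from aimet_torch.quantsim import QuantizationSimModel

def calibrate(sim: QuantizationSimModel, data_loader):
    # Reset any existing stats/encodings and put the quantizers into analysis mode.
    QuantizationSimModel.prepare_sim_for_compute_encodings(sim)

    # Run representative data through the sim model to collect statistics.
    with torch.no_grad():
        sim.model.eval()
        for batch in data_loader:
            sim.model(batch)

    # Compute per-layer encodings and switch the quantizers back to active mode.
    QuantizationSimModel.compute_layer_encodings_for_sim(sim)
```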

class TestQuantizationSimLearnedGrid:
