Skip to content

Commit

Permalink
Change default quant scheme in aimet_torch.v2 QuantizationSimModel (#3687)
Browse files Browse the repository at this point in the history

Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Dec 20, 2024
1 parent a583285 commit 77a27ad
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ def __init__(self, # pylint: disable=too-many-arguments, too-many-locals, too-ma
if not quant_scheme:
old_default = QuantScheme.post_training_tf_enhanced
new_default = QuantScheme.training_range_learning_with_tf_init
msg = _red(f"The default value of 'quant_scheme' will change from '{old_default}' "
f"to '{new_default}' in the later versions. "
"If you wish to maintain the legacy behavior in the future, "
msg = _red(f"The default value of 'quant_scheme' has changed from '{old_default}' "
f"to '{new_default}' since aimet-torch==2.0.0. "
"If you wish to maintain the legacy default behavior, "
f"please explicitly pass 'quant_scheme={old_default}'")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
quant_scheme = old_default
quant_scheme = new_default

if rounding_mode:
if rounding_mode == 'nearest':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v1.quantsim import OnnxExportApiArgs
from aimet_torch.v1.qc_quantize_op import QcQuantizeWrapper
from aimet_common.defs import QuantScheme
from aimet_torch.utils import get_layer_by_name

from ..models_.models_to_test import (
Expand Down Expand Up @@ -371,8 +372,8 @@ def test_json_interchangeable(self):
model = resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)

sim_v1 = QuantizationSimModelV1(model, dummy_input)
sim_v2 = QuantizationSimModel(model, dummy_input)
sim_v1 = QuantizationSimModelV1(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced)
sim_v2 = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced)

sim_v1.compute_encodings(lambda model, _: model(dummy_input), None)
sim_v2.compute_encodings(lambda model, _: model(dummy_input), None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def test_save_and_load_gbbq(self):

assert qsim.model.fc.param_quantizers['weight'].get_max()[0][0] == old_max
out3 = qsim.model(dummy_input)
assert torch.equal(out1, out3)
assert torch.allclose(out1, out3)

qsim.model.fc.weight = torch.nn.Parameter(torch.randn(old_weight.shape))
qsim.compute_encodings(lambda m, _: m(dummy_input_2), None)
Expand All @@ -793,7 +793,7 @@ def test_save_and_load_gbbq(self):
qsim.load_encodings(os.path.join(temp_dir, 'exported_encodings_torch.encodings'))

out4 = qsim.model(dummy_input)
assert torch.equal(out1, out4)
assert torch.allclose(out1, out4)


def test_quantsim_with_unused_modules(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_export_per_layer_stats_histogram(self):
input_shape = (1, 3, 32, 32)
dummy_input = torch.randn(*input_shape)
model = TinyModel().eval()
sim = QuantizationSimModel(model, dummy_input)
sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced)
sim.compute_encodings(evaluate, dummy_input)
forward_pass_callback = CallbackFunc(calibrate, dummy_input)
eval_callback = CallbackFunc(evaluate, dummy_input)
Expand Down Expand Up @@ -276,7 +276,8 @@ def test_export_per_layer_stats_histogram_per_channel(self):
input_shape = (1, 3, 32, 32)
dummy_input = torch.randn(*input_shape)
model = TinyModel().eval()
sim = QuantizationSimModel(model, dummy_input, config_file=os.path.join(tmp_dir, "quantsim_config.json"))
sim = QuantizationSimModel(model, dummy_input, config_file=os.path.join(tmp_dir, "quantsim_config.json"),
quant_scheme=QuantScheme.post_training_tf_enhanced)
sim.compute_encodings(evaluate, dummy_input)
forward_pass_callback = CallbackFunc(calibrate, dummy_input)
eval_callback = CallbackFunc(evaluate, dummy_input)
Expand Down

0 comments on commit 77a27ad

Please sign in to comment.