Commit 56b1daf

Resolve test failures
Signed-off-by: Kyunggeun Lee <[email protected]>
quic-kyunggeu committed Dec 20, 2024
1 parent a0f2162 commit 56b1daf
Showing 3 changed files with 8 additions and 6 deletions.
@@ -55,6 +55,7 @@
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v1.quantsim import OnnxExportApiArgs
from aimet_torch.v1.qc_quantize_op import QcQuantizeWrapper
+from aimet_common.defs import QuantScheme
from aimet_torch.utils import get_layer_by_name

from ..models_.models_to_test import (
@@ -371,8 +372,8 @@ def test_json_interchangeable(self):
model = resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)

-sim_v1 = QuantizationSimModelV1(model, dummy_input)
-sim_v2 = QuantizationSimModel(model, dummy_input)
+sim_v1 = QuantizationSimModelV1(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced)
+sim_v2 = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced)

sim_v1.compute_encodings(lambda model, _: model(dummy_input), None)
sim_v2.compute_encodings(lambda model, _: model(dummy_input), None)
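For context on the change above: the v1 and v2 QuantizationSimModel constructors presumably no longer share the same default quantization scheme, so the test pins both to post_training_tf_enhanced explicitly before comparing their exported encodings. A minimal sketch of the pinned v2 setup (assuming torchvision's resnet18, as in the test):

    import torch
    from torchvision.models import resnet18
    from aimet_common.defs import QuantScheme
    from aimet_torch.v2.quantsim import QuantizationSimModel

    model = resnet18().eval()
    dummy_input = torch.randn(1, 3, 224, 224)

    # Pin the scheme explicitly instead of relying on the constructor default.
    sim = QuantizationSimModel(model, dummy_input,
                               quant_scheme=QuantScheme.post_training_tf_enhanced)
    sim.compute_encodings(lambda m, _: m(dummy_input), None)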
@@ -783,7 +783,7 @@ def test_save_and_load_gbbq(self):

assert qsim.model.fc.param_quantizers['weight'].get_max()[0][0] == old_max
out3 = qsim.model(dummy_input)
-assert torch.equal(out1, out3)
+assert torch.allclose(out1, out3)

qsim.model.fc.weight = torch.nn.Parameter(torch.randn(old_weight.shape))
qsim.compute_encodings(lambda m, _: m(dummy_input_2), None)
@@ -793,7 +793,7 @@
qsim.load_encodings(os.path.join(temp_dir, 'exported_encodings_torch.encodings'))

out4 = qsim.model(dummy_input)
-assert torch.equal(out1, out4)
+assert torch.allclose(out1, out4)


def test_quantsim_with_unused_modules(self):
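For context on the torch.equal → torch.allclose changes above: torch.equal demands bit-exact, same-shape equality, while torch.allclose tolerates small floating-point drift (default rtol=1e-05, atol=1e-08), which is the appropriate check when outputs are regenerated after an encoding save/load round trip. A quick self-contained illustration:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 1e-7  # tiny drift, as can occur when encodings are reloaded

    print(torch.equal(a, b))     # False: requires exact element-wise equality
    print(torch.allclose(a, b))  # True: within default rtol=1e-05, atol=1e-08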
@@ -236,7 +236,7 @@ def test_export_per_layer_stats_histogram(self):
input_shape = (1, 3, 32, 32)
dummy_input = torch.randn(*input_shape)
model = TinyModel().eval()
-sim = QuantizationSimModel(model, dummy_input)
+sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced)
sim.compute_encodings(evaluate, dummy_input)
forward_pass_callback = CallbackFunc(calibrate, dummy_input)
eval_callback = CallbackFunc(evaluate, dummy_input)
@@ -276,7 +276,8 @@ def test_export_per_layer_stats_histogram_per_channel(self):
input_shape = (1, 3, 32, 32)
dummy_input = torch.randn(*input_shape)
model = TinyModel().eval()
-sim = QuantizationSimModel(model, dummy_input, config_file=os.path.join(tmp_dir, "quantsim_config.json"))
+sim = QuantizationSimModel(model, dummy_input, config_file=os.path.join(tmp_dir, "quantsim_config.json"),
+                           quant_scheme=QuantScheme.post_training_tf_enhanced)
sim.compute_encodings(evaluate, dummy_input)
forward_pass_callback = CallbackFunc(calibrate, dummy_input)
eval_callback = CallbackFunc(evaluate, dummy_input)
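As in the first file, these two histogram-export tests now pass quant_scheme explicitly, presumably because exporting per-layer stats histograms relies on the histogram-based calibration that post_training_tf_enhanced performs. For reference, QuantScheme in aimet_common.defs exposes (at least in recent AIMET releases) these commonly used members:

    from aimet_common.defs import QuantScheme

    QuantScheme.post_training_tf            # min/max calibration
    QuantScheme.post_training_tf_enhanced   # SQNR/histogram-based calibration
    QuantScheme.training_range_learning_with_tf_init
    QuantScheme.training_range_learning_with_tf_enhanced_init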
