Skip to content

Commit

Permalink
BREAKING_CHANGE: switch aimet_torch default API from v1 to v2 (#3689)
Browse files Browse the repository at this point in the history
Switched the default API of aimet_torch from v1 to v2.

From this commit onwards, any package or module named ``aimet_torch.<X>`` will point to ``aimet_torch.v2.<X>`` (if any).
For example,  ``aimet_torch.quantsim`` will now point to ``aimet_torch.v2.quantsim``, not ``aimet_torch.v1.quantsim``
Also, any v1-specific modules/packages will be moved under ``aimet_torch.v1`` subpackage, such as ``aimet_torch.v1.qc_quantize_op`` or ``aimet_torch.v1.tensor_quantizer``

To keep using ``aimet_torch.v1`` as the default API for backwards-compatibility with the legacy code, set environment variable "AIMET_DEFAULT_API=v1"

Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Dec 20, 2024
1 parent 77a27ad commit c18c881
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 59 deletions.
2 changes: 1 addition & 1 deletion TrainingExtensions/torch/src/python/aimet_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,7 +1307,7 @@ def _get_metadata_and_state_dict(safetensor_file_path: str) -> [dict, dict]:


def _get_default_api() -> Union[Literal["v1"], Literal["v2"]]:
default_api = os.getenv("AIMET_DEFAULT_API", "v1").lower()
default_api = os.getenv("AIMET_DEFAULT_API", "v2").lower()

if default_api not in ("v1", "v2"):
raise RuntimeError("Invalid value specified for environment variable AIMET_DEFAULT_API. "
Expand Down
132 changes: 76 additions & 56 deletions TrainingExtensions/torch/test/python/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,104 +43,124 @@
import pytest

import aimet_torch
from aimet_torch.utils import _get_default_api


def test_default_import():
if _get_default_api() == "v1":
from aimet_torch.v1 import quantsim as default_quantsim
from aimet_torch.v1.quantsim import QuantizationSimModel as default_QuantizationSimModel
from aimet_torch.v1.adaround import adaround_weight as default_adaround_weight
from aimet_torch.v1.adaround.adaround_weight import Adaround as default_Adaround
from aimet_torch.v1 import seq_mse as default_seq_mse
from aimet_torch.v1.seq_mse import apply_seq_mse as default_apply_seq_mse
from aimet_torch.v1.nn.modules import custom as default_custom
from aimet_torch.v1.nn.modules.custom import Add as default_Add
from aimet_torch.v1 import auto_quant as default_auto_quant
from aimet_torch.v1.auto_quant import AutoQuant as default_AutoQuant
from aimet_torch.v1 import quant_analyzer as default_quant_analyzer
from aimet_torch.v1.quant_analyzer import QuantAnalyzer as default_QuantAnalyzer
from aimet_torch.v1 import batch_norm_fold as default_batch_norm_fold
from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms_to_scale as default_fold_all_batch_norms_to_scale
from aimet_torch.v1 import mixed_precision as default_mixed_precision
from aimet_torch.v1.mixed_precision import choose_mixed_precision as default_choose_mixed_precision
else:
from aimet_torch.v2 import quantsim as default_quantsim
from aimet_torch.v2.quantsim import QuantizationSimModel as default_QuantizationSimModel
from aimet_torch.v2.adaround import adaround_weight as default_adaround_weight
from aimet_torch.v2.adaround.adaround_weight import Adaround as default_Adaround
from aimet_torch.v2 import seq_mse as default_seq_mse
from aimet_torch.v2.seq_mse import apply_seq_mse as default_apply_seq_mse
from aimet_torch.v2.nn.modules import custom as default_custom
from aimet_torch.v2.nn.modules.custom import Add as default_Add
from aimet_torch.v2 import auto_quant as default_auto_quant
from aimet_torch.v2.auto_quant import AutoQuant as default_AutoQuant
from aimet_torch.v2 import quant_analyzer as default_quant_analyzer
from aimet_torch.v2.quant_analyzer import QuantAnalyzer as default_QuantAnalyzer
from aimet_torch.v2 import batch_norm_fold as default_batch_norm_fold
from aimet_torch.v2.batch_norm_fold import fold_all_batch_norms_to_scale as default_fold_all_batch_norms_to_scale
from aimet_torch.v2 import mixed_precision as default_mixed_precision
from aimet_torch.v2.mixed_precision import choose_mixed_precision as default_choose_mixed_precision

"""
When: Import from aimet_torch.quantsim
Then: Import should be redirected to aimet_torch.v1.quantsim
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.quantsim
"""
from aimet_torch import quantsim
from aimet_torch.v1 import quantsim as v1_quantsim
assert quantsim.QuantizationSimModel is v1_quantsim.QuantizationSimModel
from aimet_torch import quantsim
assert quantsim.QuantizationSimModel is default_quantsim.QuantizationSimModel

from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v1.quantsim import QuantizationSimModel as v1_QuantizationSimModel
assert QuantizationSimModel is v1_QuantizationSimModel
from aimet_torch.quantsim import QuantizationSimModel
assert QuantizationSimModel is default_QuantizationSimModel

"""
When: Import from aimet_torch.adaround
Then: Import should be redirected to aimet_torch.v1.adaround
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.adaround
"""
from aimet_torch.adaround import adaround_weight
from aimet_torch.v1.adaround import adaround_weight as v1_adaround_weight
assert adaround_weight.Adaround is v1_adaround_weight.Adaround
from aimet_torch.adaround import adaround_weight
assert adaround_weight.Adaround is default_adaround_weight.Adaround

from aimet_torch.adaround.adaround_weight import Adaround
from aimet_torch.v1.adaround.adaround_weight import Adaround as v1_Adaround
assert Adaround is v1_Adaround
from aimet_torch.adaround.adaround_weight import Adaround
assert Adaround is default_Adaround

"""
When: Import from aimet_torch.seq_mse
Then: Import should be redirected to aimet_torch.v1.seq_mse
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.seq_mse
"""
from aimet_torch import seq_mse
from aimet_torch.v1 import seq_mse as v1_seq_mse
assert seq_mse.apply_seq_mse is v1_seq_mse.apply_seq_mse
from aimet_torch import seq_mse
assert seq_mse.apply_seq_mse is default_seq_mse.apply_seq_mse

from aimet_torch.seq_mse import apply_seq_mse
from aimet_torch.v1.seq_mse import apply_seq_mse as v1_apply_seq_mse
assert apply_seq_mse is v1_apply_seq_mse
from aimet_torch.seq_mse import apply_seq_mse
assert apply_seq_mse is default_apply_seq_mse

"""
When: Import from aimet_torch.nn
Then: Import should be redirected to aimet_torch.v1.nn
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.nn
"""
from aimet_torch.nn.modules import custom
from aimet_torch.v1.nn.modules import custom as v1_custom
assert custom.Add is v1_custom.Add
from aimet_torch.nn.modules import custom
assert custom.Add is default_custom.Add

from aimet_torch.nn.modules.custom import Add
from aimet_torch.v1.nn.modules.custom import Add as v1_Add
assert Add is v1_Add
from aimet_torch.nn.modules.custom import Add
assert Add is default_Add

"""
When: Import from aimet_torch.auto_quant
Then: Import should be redirected to aimet_torch.v1.auto_quant
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.auto_quant
"""
from aimet_torch import auto_quant
from aimet_torch.v1 import auto_quant as v1_auto_quant
assert auto_quant.AutoQuant is v1_auto_quant.AutoQuant
from aimet_torch import auto_quant
assert auto_quant.AutoQuant is default_auto_quant.AutoQuant

from aimet_torch.auto_quant import AutoQuant
from aimet_torch.v1.auto_quant import AutoQuant as v1_AutoQuant
assert AutoQuant is v1_AutoQuant
from aimet_torch.auto_quant import AutoQuant
assert AutoQuant is default_AutoQuant

"""
When: Import from aimet_torch.quant_analyzer
Then: Import should be redirected to aimet_torch.v1.quant_analyzer
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.quant_analyzer
"""
from aimet_torch import quant_analyzer
from aimet_torch.v1 import quant_analyzer as v1_auto_quant
assert quant_analyzer.QuantAnalyzer is v1_auto_quant.QuantAnalyzer
from aimet_torch import quant_analyzer
assert quant_analyzer.QuantAnalyzer is default_quant_analyzer.QuantAnalyzer

from aimet_torch.quant_analyzer import QuantAnalyzer
from aimet_torch.v1.quant_analyzer import QuantAnalyzer as v1_QuantAnalyzer
assert QuantAnalyzer is v1_QuantAnalyzer
from aimet_torch.quant_analyzer import QuantAnalyzer
assert QuantAnalyzer is default_QuantAnalyzer

"""
When: Import from aimet_torch.batch_norm_fold
Then: Import should be redirected to aimet_torch.v1.batch_norm_fold
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.batch_norm_fold
"""
from aimet_torch import batch_norm_fold
from aimet_torch.v1 import batch_norm_fold as v1_batch_norm_fold
assert batch_norm_fold.fold_all_batch_norms_to_scale is v1_batch_norm_fold.fold_all_batch_norms_to_scale
from aimet_torch import batch_norm_fold
assert batch_norm_fold.fold_all_batch_norms_to_scale is default_batch_norm_fold.fold_all_batch_norms_to_scale

from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale
from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms_to_scale as v1_fold_all_batch_norms_to_scale
assert fold_all_batch_norms_to_scale is v1_fold_all_batch_norms_to_scale
from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale
assert fold_all_batch_norms_to_scale is default_fold_all_batch_norms_to_scale

"""
When: Import from aimet_torch.mixed_precision
Then: Import should be redirected to aimet_torch.v1.mixed_precision
Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.mixed_precision
"""
from aimet_torch import mixed_precision
from aimet_torch.v1 import mixed_precision as v1_mixed_precision
assert mixed_precision.choose_mixed_precision is v1_mixed_precision.choose_mixed_precision
from aimet_torch import mixed_precision
assert mixed_precision.choose_mixed_precision is default_mixed_precision.choose_mixed_precision

from aimet_torch.mixed_precision import choose_mixed_precision
from aimet_torch.v1.mixed_precision import choose_mixed_precision as v1_choose_mixed_precision
assert choose_mixed_precision is v1_choose_mixed_precision
from aimet_torch.mixed_precision import choose_mixed_precision
assert choose_mixed_precision is default_choose_mixed_precision


def _get_all_modules():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from aimet_torch.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo
from aimet_common.defs import QuantizationDataType
from aimet_torch import utils
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v1.quantsim import QuantizationSimModel
from aimet_torch.save_utils import SaveUtils


Expand Down
2 changes: 1 addition & 1 deletion TrainingExtensions/torch/test/python/v1/test_bn_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import torch
from torchvision import models

from aimet_torch.batch_norm_fold import fold_given_batch_norms
from aimet_torch.v1.batch_norm_fold import fold_given_batch_norms
from ..models.test_models import TransposedConvModel
from aimet_torch.model_preparer import prepare_model
from aimet_common.defs import QuantScheme
Expand Down

0 comments on commit c18c881

Please sign in to comment.