diff --git a/TrainingExtensions/torch/src/python/aimet_torch/_base/batch_norm_fold.py b/TrainingExtensions/torch/src/python/aimet_torch/_base/batch_norm_fold.py
index f795dcfd92..67dcf2cb4a 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/_base/batch_norm_fold.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/_base/batch_norm_fold.py
@@ -38,12 +38,10 @@
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Union, Dict, Iterable, Set, Any

-import numpy as np
 import torch
 import torch.nn
 from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d

-import aimet_common.libpymo as libpymo
 from aimet_common.batch_norm_fold import batch_norm_fold, expand_shape_to_4d
 from aimet_common.bias_correction import ConvBnPatternHandler, CONV_OP_TYPES, LINEAR_OP_TYPES, BN_OP_TYPES
 from aimet_common.graph_pattern_matcher import PatternType
@@ -69,61 +67,13 @@
 BatchNormType = Union[BatchNorm1d, BatchNorm2d]
 _supported_batchnorms = BatchNormType.__args__

-# Temporary flag to flip underlying implementation. This flag will be removed in the future releases.
-USE_PYTHON_IMPL = True
-
 class _BatchNormFoldingNotSupported(RuntimeError):
     pass

 class BatchNormFoldBase(ABC):
     """Handles batch norm folding logic"""

-    @staticmethod
-    def _call_mo_batch_norm_fold(weight: torch.Tensor,
-                                 bias: torch.Tensor,
-                                 bn: BatchNormType,
-                                 fold_backward: bool):
-        """
-        Calls C++ batch norm folding API.
-
-        :param weight: Weight or scale tensor to fold BN into.
-        :param bias: Bias tensor to fold BN into.
-        :param bn: Batch Norm layer
-        :param fold_backward: True if BatchNorm comes after Conv/Linear layer
-        """
-        with torch.no_grad():
-            bn_params = libpymo.BNParams()
-            bn_params.gamma = bn.weight.detach().cpu().numpy().reshape(-1)
-            bn_params.beta = bn.bias.detach().cpu().numpy().reshape(-1)
-            bn_params.runningMean = bn.running_mean.detach().cpu().numpy().reshape(-1)
-            sigma = torch.sqrt(bn.running_var + bn.eps)
-            bn_params.runningVar = sigma.detach().cpu().numpy().reshape(-1)
-
-            weight_tensor = libpymo.TensorParams()
-            weight_tensor.data = weight.detach().cpu().numpy().reshape(-1)
-            weight_tensor.shape = np.array(weight.shape)
-
-            bias_tensor = libpymo.TensorParams()
-            bias_tensor.data = bias.detach().cpu().numpy().reshape(-1)
-            bias_tensor.shape = np.array(bias.shape)
-            is_bias_valid = True
-
-            _4d_shape = expand_shape_to_4d(weight_tensor.shape)
-            try:
-                orig_shape = weight_tensor.shape
-                weight_tensor.shape = _4d_shape
-                _bias = libpymo.fold(bn_params, weight_tensor, bias_tensor, is_bias_valid, fold_backward)
-            finally:
-                weight_tensor.shape = orig_shape
-
-            bias.copy_(torch.tensor(_bias, device=bias.device, dtype=bias.dtype)
-                       .reshape_as(bias))
-            weight.copy_(torch.tensor(weight_tensor.data, device=weight.device, dtype=weight.dtype)
-                         .reshape_as(weight))
-
-
-    @staticmethod
-    def _call_py_batch_norm_fold(weight: torch.Tensor,
+    def _call_batch_norm_fold(weight: torch.Tensor,
                                  bias: torch.Tensor,
                                  bn: Union[BatchNorm1d, BatchNorm2d],
                                  fold_backward: bool):
@@ -171,10 +121,8 @@ def _fold_to_weight(cls, conv_linear: LayerType, bn: BatchNormType, fold_backwar
                                dtype=conv_linear.weight.dtype)
             conv_linear.bias = torch.nn.Parameter(bias)

-        if USE_PYTHON_IMPL:
-            cls._call_py_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)
-        else:
-            cls._call_mo_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)
+
+        cls._call_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)

         # Transpose weight back to N, C, H, W for transposed Conv2D, for non-depthwise layers
         if isinstance(conv_linear, torch.nn.ConvTranspose2d) and conv_linear.groups == 1:
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/compression_factory.py b/TrainingExtensions/torch/src/python/aimet_torch/compression_factory.py
index f263324622..0880a0e4c8 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/compression_factory.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/compression_factory.py
@@ -51,12 +51,9 @@
 from aimet_torch.defs import SpatialSvdParameters, WeightSvdParameters, ChannelPruningParameters, ModuleCompRatioPair
 from aimet_torch.layer_selector import ConvFcLayerSelector, ConvNoDepthwiseLayerSelector, ManualLayerSelector
 from aimet_torch.layer_database import LayerDatabase
-from aimet_torch.svd.svd_pruner import SpatialSvdPruner, WeightSvdPruner, PyWeightSvdPruner
+from aimet_torch.svd.svd_pruner import SpatialSvdPruner, PyWeightSvdPruner
 from aimet_torch.channel_pruning.channel_pruner import InputChannelPruner, ChannelPruningCostCalculator

-# Temporary flag to flip underlying implementation. This flag will be removed in the future releases.
-USE_PYTHON_IMPL = True
-

 class CompressionFactory:
     """ Factory to construct various AIMET model compression classes based on a scheme """
@@ -216,7 +213,7 @@ def create_weight_svd_algo(cls, model: torch.nn.Module, eval_callback: EvalFunct
         use_cuda = next(model.parameters()).is_cuda

         # Create a pruner
-        pruner = PyWeightSvdPruner() if USE_PYTHON_IMPL else WeightSvdPruner()
+        pruner = PyWeightSvdPruner()
         cost_calculator = WeightSvdCostCalculator()
         comp_ratio_rounding_algo = RankRounder(params.multiplicity, cost_calculator)
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/cross_layer_equalization.py b/TrainingExtensions/torch/src/python/aimet_torch/cross_layer_equalization.py
index c524709722..dcf992377a 100755
--- a/TrainingExtensions/torch/src/python/aimet_torch/cross_layer_equalization.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/cross_layer_equalization.py
@@ -47,7 +47,6 @@
 import numpy as np
 import torch

-import aimet_common.libpymo as libpymo
 from aimet_common.utils import AimetLogger
 from aimet_common.cross_layer_equalization import ClsLayerType, ClsSetInfo, ClsImpl, HbfImpl
 from aimet_torch import utils
@@ -68,9 +67,6 @@
 cls_supported_layers = (torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Conv1d, torch.nn.ConvTranspose1d)
 cls_supported_activations = (torch.nn.ReLU, torch.nn.PReLU)

-# Temporary flag to flip underlying implementation. This flag will be removed in the future releases.
-USE_PYTHON_IMPL = True - def get_ordered_list_of_conv_modules(model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple]) -> List: """ @@ -359,7 +355,7 @@ def scale_cls_set_with_conv_layers(cls, cls_set: ClsSet) -> np.ndarray: on_gpu = True module.cpu() - cls_impl = PythonClsImpl() if USE_PYTHON_IMPL else MoClsImpl() + cls_impl = PythonClsImpl() scaling_factor = cls_impl.scale_cls_set_with_conv_layers(cls_set) if on_gpu: @@ -386,7 +382,7 @@ def scale_cls_set_with_depthwise_layers(cls, cls_set: ClsSet) -> [np.ndarray, np on_gpu = True module.cpu() - cls_impl = PythonClsImpl() if USE_PYTHON_IMPL else MoClsImpl() + cls_impl = PythonClsImpl() scaling_factors = cls_impl.scale_cls_set_with_depthwise_layers(cls_set) if on_gpu: @@ -480,200 +476,6 @@ def scale_model(model: torch.nn.Module, input_shapes: Union[Tuple, List[Tuple]] return cls_set_info_list -class MoClsImpl(ClsImpl): - """ - This class implements the CLS algorithm using MO version while following the base Implementation interface. - """ - def scale_cls_set_with_depthwise_layers(self, cls_set) -> [np.ndarray, np.ndarray]: - """ - API to invoke equalize layer params for depth wise separable layers(update for weights and bias is in place) - - :param cls_set: Consecutive Conv layers whose weights and biases need to be equalized. - Second Conv layer is a depth-wise conv and third conv layer is point-wise conv - :return: Scaling factors S_12 and S_23 : numpy arrays - """ - # Create structs for holding layer weights and bias parameters - prev_layer_params = libpymo.EqualizationParams() - curr_layer_params = libpymo.EqualizationParams() - next_layer_params = libpymo.EqualizationParams() - - # Prepare and pack data structures for cls set. - self._pack_params_for_depthwise_conv(cls_set, prev_layer_params, curr_layer_params, next_layer_params) - - # Scales weights and bias for consecutive layers and updates data structures in-place. - scaling_params = libpymo.scaleDepthWiseSeparableLayer(prev_layer_params, curr_layer_params, next_layer_params) - - # Update weight and biases for cls set using updated data structures. - self._update_params_for_depthwise_conv(cls_set, prev_layer_params, curr_layer_params, next_layer_params) - - return scaling_params.scalingMatrix12, scaling_params.scalingMatrix23 - - def scale_cls_set_with_conv_layers(self, cls_set) -> np.ndarray: - """ - API to invoke equalize layer params for regular conv layers (update for weights and bias is in place) - - :param cls_set: Consecutive Conv layers Tuple whose weights and biases need to be equalized - :return: Scaling factor S_12 for each conv layer pair: numpy array - """ - # Create structs for holding layer weights and bias parameters - prev_layer_params = libpymo.EqualizationParams() - curr_layer_params = libpymo.EqualizationParams() - - # Prepare and pack data structures for cls set. - self._pack_params_for_conv(cls_set, prev_layer_params, curr_layer_params) - - # Scales weights and bias for consecutive layers and updates data structures in-place. - scaling_factor = libpymo.scaleLayerParams(prev_layer_params, curr_layer_params) - - # Update weight and biases for cls set using updated data structures. - self._update_params_for_conv(cls_set, prev_layer_params, curr_layer_params) - - return scaling_factor - - def _pack_params_for_conv(self, - cls_set, - prev_layer_params: libpymo.EqualizationParams, - curr_layer_params: libpymo.EqualizationParams - ): - """ - Prepare and pack data structure for previous and current layer in given cls set. 
- - :param cls_set: Consecutive Conv layers Tuple whose weights and biases need to be equalized. - :param prev_layer_params: Data structure holding weight and bias for previous layer in cls set. - :param curr_layer_params: Data structure holding weight and bias for current layer in cls set. - """ - self._populate_libpymo_params(cls_set[0], prev_layer_params) - self._populate_libpymo_params(cls_set[1], curr_layer_params) - - if cls_set[0].bias is not None: - prev_layer_params.bias = cls_set[0].bias.detach().numpy() - else: - prev_layer_params.isBiasNone = True - - def _update_params_for_conv(self, - cls_set, - prev_layer_params: libpymo.EqualizationParams, - curr_layer_params: libpymo.EqualizationParams): - """ - Update weight and biases for cls set using updated data structures. - - :param cls_set: Consecutive Conv layers Tuple whose weights and biases need to be equalized. - :param prev_layer_params: Data structure holding weight and bias for previous layer in cls set. - :param curr_layer_params: Data structure holding weight and bias for current layer in cls set. - """ - self._update_module_from_libpymo(cls_set[0], prev_layer_params) - self._update_module_from_libpymo(cls_set[1], curr_layer_params) - - if cls_set[0].bias is not None: - cls_set[0].bias.data = torch.from_numpy(np.reshape(prev_layer_params.bias, - prev_layer_params.weightShape[0])) - cls_set[0].bias.data = cls_set[0].bias.data.type(torch.FloatTensor) - - def _pack_params_for_depthwise_conv(self, - cls_set, - prev_layer_params: libpymo.EqualizationParams, - curr_layer_params: libpymo.EqualizationParams, - next_layer_params: libpymo.EqualizationParams): - """ - Prepare and pack data structure for previous, current and next layer in given cls set. - - :param cls_set: Consecutive Conv layers Tuple whose weights and biases need to be equalized. - :param prev_layer_params: Data structure holding weight and bias for previous layer in cls set. - :param curr_layer_params: Data structure holding weight and bias for current layer in cls set. - :param next_layer_params: Data structure holding weight and bias for next layer in cls set. - """ - # cls_set 0 - self._populate_libpymo_params(cls_set[0], prev_layer_params) - - # cls_set 1 - assert cls_set[1].groups > 1 - curr_layer_params.weight = cls_set[1].weight.detach().numpy().flatten() - curr_layer_params.weightShape = np.array(cls_set[1].weight.shape) - if len(curr_layer_params.weightShape) == 3: - curr_layer_params.weightShape = curr_layer_params.weightShape + [1] - - # cls_set 2 - self._populate_libpymo_params(cls_set[2], next_layer_params) - - if cls_set[0].bias is not None: - prev_layer_params.bias = cls_set[0].bias.detach().numpy() - else: - prev_layer_params.isBiasNone = True - - if cls_set[1].bias is not None: - curr_layer_params.bias = cls_set[1].bias.detach().numpy() - else: - curr_layer_params.isBiasNone = True - - def _update_params_for_depthwise_conv(self, - cls_set, - prev_layer_params: libpymo.EqualizationParams, - curr_layer_params: libpymo.EqualizationParams, - next_layer_params: libpymo.EqualizationParams): - """ - Update weight and biases for cls set using updated data structures. - - :param cls_set: Consecutive Conv layers Tuple whose weights and biases need to be equalized. - :param prev_layer_params: Data structure holding weight and bias for previous layer in cls set. - :param curr_layer_params: Data structure holding weight and bias for current layer in cls set. - :param next_layer_params: Data structure holding weight and bias for next layer in cls set. 
- """ - self._update_module_from_libpymo(cls_set[0], prev_layer_params) - self._update_module_from_libpymo(cls_set[1], curr_layer_params) - self._update_module_from_libpymo(cls_set[2], next_layer_params) - - if cls_set[0].bias is not None: - cls_set[0].bias.data = torch.from_numpy(np.reshape(prev_layer_params.bias, - prev_layer_params.weightShape[0])) - cls_set[0].bias.data = cls_set[0].bias.data.type(torch.FloatTensor) - - if cls_set[1].bias is not None: - cls_set[1].bias.data = torch.from_numpy(np.reshape(curr_layer_params.bias, - curr_layer_params.weightShape[0])) - cls_set[1].bias.data = cls_set[1].bias.data.type(torch.FloatTensor) - - @staticmethod - def _populate_libpymo_params(module: torch.nn.Module, layer_params: libpymo.EqualizationParams): - """ - Populate libpymo object. - - :param module: pytorch module. - :param layer_params: libpymo object. - """ - weight_set = module.weight - - # Transpose weights to C, N, H, W from N, C, H, W since axis are flipped for transposed conv - if isinstance(module, torch.nn.ConvTranspose2d) and module.groups == 1: - weight_set = weight_set.permute(1, 0, 2, 3).contiguous() - if isinstance(module, torch.nn.ConvTranspose1d) and module.groups == 1: - weight_set = weight_set.permute(1, 0, 2).contiguous() - - layer_params.weight = weight_set.detach().numpy().reshape(-1) - layer_params.weightShape = np.array(weight_set.shape) - if len(layer_params.weightShape) == 3: - layer_params.weightShape = layer_params.weightShape + [1] - - @staticmethod - def _update_module_from_libpymo(module: torch.nn.Module, layer_param: libpymo.EqualizationParams): - """ - Update module parameter from the libpymo object. - - :param module: pytorch module. - :param layer_param: libpymo object. - """ - if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): - layer_param.weightShape = layer_param.weightShape[:-1] - module.weight.data = torch.from_numpy(np.reshape(layer_param.weight, - layer_param.weightShape)) - module.weight.data = module.weight.data.type(torch.FloatTensor) - - # Transpose weight back to N, C, H, W for transposed Conv2D/1D - if isinstance(module, torch.nn.ConvTranspose2d) and module.groups == 1: - module.weight.data = module.weight.data.permute(1, 0, 2, 3).contiguous() - if isinstance(module, torch.nn.ConvTranspose1d) and module.groups == 1: - module.weight.data = module.weight.data.permute(1, 0, 2).contiguous() - - class PythonClsImpl(ClsImpl): """ This class implements the CLS algorithm using Python version while following the base Implementation interface. @@ -840,112 +642,10 @@ def bias_fold(cls, cls_set_info_list: List[ClsSetInfo], (cls_pair_info.layer1 not in bn_layers): continue - # Pick an implementation version based on user provided flag. - hbf_impl = PythonHbfImpl() if USE_PYTHON_IMPL else MoHbfImpl() + hbf_impl = PythonHbfImpl() hbf_impl.bias_fold(cls_pair_info, bn_layers) -class MoHbfImpl(HbfImpl): - """ - This class implements the HBF algorithm using MO version while following the base Implementation interface. - """ - def bias_fold(self, cls_pair_info, bn_layers): - """ - Bias fold implementation using Model optimization (c++) version. - - :param cls_pair_info: Layer pairs that were scaled using CLS and related information. - :param bn_layers: Dictionary with Key being Conv/Linear layer and value being corresponding folded BN layer. - """ - # Create data structures for holding layer weights and bias parameters. 
- prev_layer_params = libpymo.LayerParams() - curr_layer_params = libpymo.LayerParams() - prev_layer_bn_params = libpymo.BNParamsHighBiasFold() - - # Prepare and pack data structures for high bias fold. - self._pack_bn_layer_params(cls_pair_info, bn_layers, prev_layer_bn_params) - self._pack_previous_and_current_layer_params(cls_pair_info, prev_layer_params, curr_layer_params) - - # Update bias for previous and current layer and data structures in-place. - libpymo.updateBias(prev_layer_params, curr_layer_params, prev_layer_bn_params) - - # Set updated biases for previous and current layer. - self._update_previous_and_current_layer_bias(cls_pair_info, prev_layer_params, curr_layer_params) - - @staticmethod - def _pack_bn_layer_params(cls_pair_info: ClsSetInfo.ClsSetLayerPairInfo, - bn_layers: Dict[torch.nn.Module, torch.nn.BatchNorm2d], - prev_layer_bn_params: libpymo.BNParamsHighBiasFold): - """ - Helper method to pack batch norm layer parameter for high bias fold. - - :param cls_pair_info: Layer pairs that were scaled using CLS and related information. - :param bn_layers: Dictionary with Key being Conv/Linear layer and value being corresponding folded BN layer. - :param prev_layer_bn_params: Data structure to pack batch norm parameter. - """ - scaling_parameter = cls_pair_info.scale_factor - - # Scaling gamma and beta parameter of batch norm layer - prev_layer_bn_params.gamma = bn_layers[cls_pair_info.layer1].weight.detach().cpu().numpy().reshape(-1) - prev_layer_bn_params.beta = bn_layers[cls_pair_info.layer1].bias.detach().cpu().numpy().reshape(-1) - - if len(scaling_parameter) != len(prev_layer_bn_params.gamma) or \ - len(scaling_parameter) != len(prev_layer_bn_params.beta): - raise ValueError("High Bias absorption is not supported for networks with fold-forward BatchNorms") - prev_layer_bn_params.gamma = np.divide(prev_layer_bn_params.gamma, scaling_parameter) - prev_layer_bn_params.beta = np.divide(prev_layer_bn_params.beta, scaling_parameter) - - @staticmethod - def _pack_previous_and_current_layer_params(cls_pair_info, prev_layer_params, curr_layer_params): - """ - Helper method to pack information of previous and current layer. - - :param cls_pair_info: Layer pairs that were scaled using CLS and related information. - :param prev_layer_params: Data structure to pack previous layer parameters. - :param curr_layer_params: Data structure to pack current layer parameters. - """ - prev_layer_params.activationIsRelu = cls_pair_info.relu_activation_between_layers - prev_layer_params.bias = cls_pair_info.layer1.bias.detach().cpu().numpy() - - weight = cls_pair_info.layer2.weight - - if isinstance(cls_pair_info.layer2, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): - weight = torch.unsqueeze(weight, dim=-1) - - # Transpose weights to C, N, H, W from N, C, H, W since axis are flipped for transposed conv - if isinstance(cls_pair_info.layer2, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)) and \ - cls_pair_info.layer2.groups == 1: - weight = weight.permute(1, 0, 2, 3) - - curr_layer_params.bias = cls_pair_info.layer2.bias.detach().cpu().numpy() - curr_layer_params.weight = weight.detach().cpu().numpy().reshape(-1) - curr_layer_params.weightShape = np.array(weight.shape) - - @staticmethod - def _update_previous_and_current_layer_bias(cls_pair_info: ClsSetInfo.ClsSetLayerPairInfo, - prev_layer_params: libpymo.LayerParams, - curr_layer_params: libpymo.LayerParams): - """ - Update biases for previous and current layer. 
- - :param cls_pair_info: Layer pairs that were scaled using CLS and related information. - :param prev_layer_params: Data structure holding weight and bias for previous layer in cls set. - :param curr_layer_params: Data structure holding weight and bias for current layer in cls set. - """ - prev_layer_bias_shape = cls_pair_info.layer1.weight.shape[0] - if (isinstance(cls_pair_info.layer1, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d))) and \ - (cls_pair_info.layer1.groups == 1): - prev_layer_bias_shape = cls_pair_info.layer1.weight.shape[1] - - with torch.no_grad(): - cls_pair_info.layer1.bias.copy_( - torch.from_numpy(np.reshape(prev_layer_params.bias, prev_layer_bias_shape))).to( - device=cls_pair_info.layer1.bias.device, dtype=cls_pair_info.layer1.bias.dtype) - - cls_pair_info.layer2.bias.copy_( - torch.from_numpy(np.reshape(curr_layer_params.bias, curr_layer_params.weightShape[0]))).to( - device=cls_pair_info.layer2.bias.device, dtype=cls_pair_info.layer2.bias.dtype) - - class PythonHbfImpl(HbfImpl): """ This class implements the HBF algorithm using python version while following the base Implementation interface. diff --git a/TrainingExtensions/torch/test/python/test_bn_fold.py b/TrainingExtensions/torch/test/python/test_bn_fold.py index 6b2f5c1c30..4c7f8532a9 100644 --- a/TrainingExtensions/torch/test/python/test_bn_fold.py +++ b/TrainingExtensions/torch/test/python/test_bn_fold.py @@ -130,29 +130,12 @@ def forward(self, *inputs): x = self.fc(x) return x -@contextmanager -def _use_python_impl(flag: bool): - orig_flag = batch_norm_fold.USE_PYTHON_IMPL - try: - batch_norm_fold.USE_PYTHON_IMPL = flag - yield - finally: - batch_norm_fold.USE_PYTHON_IMPL = orig_flag - - -@pytest.fixture(params=[True, False]) -def use_python_impl(request): - param: bool = request.param - - with _use_python_impl(param): - yield - class TestTrainingExtensionBnFold: @pytest.mark.cuda @pytest.mark.parametrize("device", ['cpu', 'cuda']) - def test_fold_resnet18(self, use_python_impl, device): + def test_fold_resnet18(self, device): torch.manual_seed(10) model = models.resnet18().to(device) _initialize_bn_params(model) @@ -175,29 +158,18 @@ def test_fold_resnet18(self, use_python_impl, device): @pytest.mark.cuda @pytest.mark.parametrize("device", ['cpu', 'cuda']) def test_python_impl(self, device): - try: - flag = batch_norm_fold.USE_PYTHON_IMPL - torch.manual_seed(10) - model = models.resnet18().eval().to(device) - _initialize_bn_params(model) - model_copy = copy.deepcopy(model) - - batch_norm_fold.USE_PYTHON_IMPL = True - layer_list = [(model.layer2[0].conv1, model.layer2[0].bn1)] - fold_given_batch_norms(model, layer_list) - - batch_norm_fold.USE_PYTHON_IMPL = False - layer_list = [(model_copy.layer2[0].conv1, model_copy.layer2[0].bn1)] - fold_given_batch_norms(model_copy, layer_list) - - # Ensure that the weight parameter is updated correctly after bn fold. - assert torch.allclose(model.layer2[0].conv1.weight, model_copy.layer2[0].conv1.weight) - assert not isinstance(model.layer2[0].bn1, torch.nn.BatchNorm2d) - assert not isinstance(model_copy.layer2[0].bn1, torch.nn.BatchNorm2d) - finally: - batch_norm_fold.USE_PYTHON_IMPL = flag - - def test_fold_bn_before_conv_no_bias(self, use_python_impl): + torch.manual_seed(10) + model = models.resnet18().eval().to(device) + _initialize_bn_params(model) + + layer_list = [(model.layer2[0].conv1, model.layer2[0].bn1)] + fold_given_batch_norms(model, layer_list) + + # Ensure that the weight parameter is updated correctly after bn fold. 
+ assert not isinstance(model.layer2[0].bn1, torch.nn.BatchNorm2d) + + + def test_fold_bn_before_conv_no_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -235,7 +207,7 @@ def forward(self, x): assert model.conv2.weight.device == model.conv2.bias.device assert model.conv2.weight.dtype == model.conv2.bias.dtype - def test_fold_bn_before_conv_with_bias(self, use_python_impl): + def test_fold_bn_before_conv_with_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -271,7 +243,7 @@ def forward(self, x): assert not isinstance(model.bn1, torch.nn.BatchNorm2d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-1) - def test_fold_bn_before_conv_with_padding(self, use_python_impl): + def test_fold_bn_before_conv_with_padding(self): class MyModel(torch.nn.Module): def __init__(self): @@ -306,7 +278,7 @@ def forward(self, x): assert isinstance(model.bn1, torch.nn.BatchNorm2d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_fold_bn_before_conv_transpose(self, use_python_impl): + def test_fold_bn_before_conv_transpose(self): class MyModel(torch.nn.Module): def __init__(self): @@ -341,7 +313,7 @@ def forward(self, x): assert isinstance(model.bn1, torch.nn.BatchNorm2d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_filter_conv_bn_pair(self, use_python_impl): + def test_filter_conv_bn_pair(self): invalid_fold_forward = [torch.nn.Conv2d(10, 20, 3, padding=1), torch.nn.Conv2d(10, 10, 2, groups=10), torch.nn.Conv2d(10, 20, 2, groups=2), @@ -378,7 +350,7 @@ def test_filter_conv_bn_pair(self, use_python_impl): is_valid = [_is_valid_bn_fold(layer, True) for layer in valid_fold_backward] assert all(is_valid) - def test_fold_bn_after_conv_no_bias(self, use_python_impl): + def test_fold_bn_after_conv_no_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -415,7 +387,7 @@ def forward(self, x): assert model.conv1.weight.device == model.conv1.bias.device assert model.conv1.weight.dtype == model.conv1.bias.dtype - def test_fold_bn_after_conv_depthwise(self, use_python_impl): + def test_fold_bn_after_conv_depthwise(self): class MyModel(torch.nn.Module): def __init__(self): @@ -447,7 +419,7 @@ def forward(self, x): assert not isinstance(model.bn1, torch.nn.BatchNorm2d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_fold_bn_after_transposed_conv_depthwise(self, use_python_impl): + def test_fold_bn_after_transposed_conv_depthwise(self): class MyModel(torch.nn.Module): def __init__(self): @@ -479,7 +451,7 @@ def forward(self, x): assert not isinstance(model.bn1, torch.nn.BatchNorm2d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_fold_bn_after_conv_with_bias(self, use_python_impl): + def test_fold_bn_after_conv_with_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -513,7 +485,7 @@ def forward(self, x): assert not isinstance(model.bn1, torch.nn.BatchNorm2d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_fold_bn_before_linear_layer_no_bias(self, use_python_impl): + def test_fold_bn_before_linear_layer_no_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -548,7 +520,7 @@ def forward(self, x): assert model.fc1.weight.device == model.fc1.bias.device assert model.fc1.weight.dtype == model.fc1.bias.dtype - def test_fold_bn_before_linear_layer_with_bias(self, use_python_impl): + def test_fold_bn_before_linear_layer_with_bias(self): class MyModel(torch.nn.Module): def 
__init__(self): @@ -580,7 +552,7 @@ def forward(self, x): assert not isinstance(model.bn1, torch.nn.BatchNorm1d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_fold_bn_after_linear_layer_no_bias(self, use_python_impl): + def test_fold_bn_after_linear_layer_no_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -615,7 +587,7 @@ def forward(self, x): assert model.fc1.weight.device == model.fc1.bias.device assert model.fc1.weight.dtype == model.fc1.bias.dtype - def test_fold_bn_after_linear_layer_with_bias(self, use_python_impl): + def test_fold_bn_after_linear_layer_with_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -647,7 +619,7 @@ def forward(self, x): assert not isinstance(model.bn1, torch.nn.BatchNorm1d) assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_find_batch_norms_to_fold(self, use_python_impl): + def test_find_batch_norms_to_fold(self): model = MyModel().eval() _initialize_bn_params(model) @@ -662,7 +634,7 @@ def test_find_batch_norms_to_fold(self, use_python_impl): assert (model.bn2, model.conv3) in bn_conv_pairs assert len(bn_picked) == 2 - def test_bn_fold_auto_mode_transposed_conv2d(self, use_python_impl): + def test_bn_fold_auto_mode_transposed_conv2d(self): torch.manual_seed(10) model = TransposedConvModel().eval() @@ -681,7 +653,7 @@ def test_bn_fold_auto_mode_transposed_conv2d(self, use_python_impl): assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) assert len(folded_pairs) == 2 - def test_find_batch_norms_to_fold_multi_input(self, use_python_impl): + def test_find_batch_norms_to_fold_multi_input(self): model = TwoInputs().eval() _initialize_bn_params(model) inp_shapes = [(1, 3, 32, 32), (1, 3, 20, 20)] @@ -697,7 +669,7 @@ def test_find_batch_norms_to_fold_multi_input(self, use_python_impl): assert (model.conv1, model.bn1) in conv_bn_pairs assert (model.conv2, model.bn2) in conv_bn_pairs - def test_bn_fold_auto_mode(self, use_python_impl): + def test_bn_fold_auto_mode(self): torch.manual_seed(10) model = MyModel().eval() @@ -715,7 +687,7 @@ def test_bn_fold_auto_mode(self, use_python_impl): assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) assert len(folded_pairs) == 2 - def test_fold_auto_mode_with_bn_after_Conv1d_layer(self, use_python_impl): + def test_fold_auto_mode_with_bn_after_Conv1d_layer(self): class MyModel(torch.nn.Module): def __init__(self): @@ -746,7 +718,7 @@ def forward(self, x): assert 1 == len(bn_pairs) assert (model.conv1d, orig_bn) in bn_pairs - def test_bn_conversion(self, use_python_impl): + def test_bn_conversion(self): class MyModel(torch.nn.Module): def __init__(self): @@ -779,7 +751,7 @@ def forward(self, x): assert 0 == len(bn_pairs) assert (model.conv1d, orig_bn) not in bn_pairs - def test_fold_manual_with_bn_after_Conv1d_layer_no_bias(self, use_python_impl): + def test_fold_manual_with_bn_after_Conv1d_layer_no_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -811,7 +783,7 @@ def forward(self, x): assert model.conv1d.weight.dtype == model.conv1d.bias.dtype @pytest.mark.cuda - def test_multi_gpu(self, use_python_impl): + def test_multi_gpu(self): torch.manual_seed(10) model = MyModel() model.eval() @@ -826,7 +798,7 @@ def test_multi_gpu(self, use_python_impl): output_after = model(random_input) assert torch.allclose(output_before, output_after, rtol=1.e-2) - def test_fold_bn_before_Conv1d_with_bias(self, use_python_impl): + def test_fold_bn_before_Conv1d_with_bias(self): class 
MyModel(torch.nn.Module): def __init__(self): @@ -857,7 +829,7 @@ def forward(self, x): assert (model.conv1d, orig_bn) in bn_pairs assert torch.allclose(baseline_output, output_after_fold, rtol=1.e-2) - def test_fold_bn_before_Conv1d_no_bias(self, use_python_impl): + def test_fold_bn_before_Conv1d_no_bias(self): class MyModel(torch.nn.Module): def __init__(self): @@ -897,7 +869,7 @@ def forward(self, x): assert model.conv1d.weight.device == model.conv1d.bias.device assert model.conv1d.weight.dtype == model.conv1d.bias.dtype - def test_bn_fold_conv3d_fold_backward(self, use_python_impl): + def test_bn_fold_conv3d_fold_backward(self): torch.random.manual_seed(10) model = Conv3dModel() @@ -915,7 +887,7 @@ def test_bn_fold_conv3d_fold_backward(self, use_python_impl): bn_modules = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm3d)] assert not bn_modules - def test_bn_fold_conv3d_fold_forward(self, use_python_impl): + def test_bn_fold_conv3d_fold_forward(self): torch.random.manual_seed(10) model = Conv3dModel1() diff --git a/TrainingExtensions/torch/test/python/test_cross_layer_scaling.py b/TrainingExtensions/torch/test/python/test_cross_layer_scaling.py index c6a81d2aa6..b023c826a0 100755 --- a/TrainingExtensions/torch/test/python/test_cross_layer_scaling.py +++ b/TrainingExtensions/torch/test/python/test_cross_layer_scaling.py @@ -160,29 +160,12 @@ def forward(self, x): x = self.conv2(x) return x -@contextmanager -def _use_python_impl(flag: bool): - orig_flag = cle.USE_PYTHON_IMPL - try: - cle.USE_PYTHON_IMPL = flag - yield - finally: - cle.USE_PYTHON_IMPL = orig_flag - - -@pytest.fixture(params=[True, False]) -def use_python_impl(request): - param: bool = request.param - - with _use_python_impl(param): - yield - class TestTrainingExtensionsCrossLayerScaling: @pytest.mark.cuda @pytest.mark.parametrize("device", ['cpu', 'cuda']) - def test_verify_cross_layer_scaling(self, use_python_impl, device): + def test_verify_cross_layer_scaling(self, device): # Get trained MNIST model torch.manual_seed(10) model = MyModel().eval().to(device) @@ -203,7 +186,7 @@ def test_verify_cross_layer_scaling(self, use_python_impl, device): assert (np.allclose(range_conv1_after_scaling, range_conv2_after_scaling)) assert (np.allclose(baseline_output, output_after_scaling, rtol=1.e-2)) - def test_top_level_api(self, use_python_impl): + def test_top_level_api(self): torch.manual_seed(10) model = MyModel().eval() input_shape = (2, 10, 24, 24) @@ -218,7 +201,7 @@ def test_top_level_api(self, use_python_impl): @pytest.mark.cuda @pytest.mark.parametrize("device", ['cpu', 'cuda']) - def test_verify_cross_layer_for_multiple_pairs(self, use_python_impl, device): + def test_verify_cross_layer_for_multiple_pairs(self, device): # Get trained MNIST model model = MyModel().eval().to(device) # Call API @@ -235,7 +218,7 @@ def test_verify_cross_layer_for_multiple_pairs(self, use_python_impl, device): assert not torch.equal(model.conv2.weight, w2) assert not torch.equal(model.conv3.weight, w3) - def test_verify_cross_layer_scaling_depthwise_separable_layer_mobilnet(self, use_python_impl): + def test_verify_cross_layer_scaling_depthwise_separable_layer_mobilnet(self): torch.manual_seed(10) model = MockMobileNetV1().eval() @@ -266,7 +249,7 @@ def test_verify_cross_layer_scaling_depthwise_separable_layer_mobilnet(self, use @pytest.mark.cuda @pytest.mark.parametrize("device", ['cpu', 'cuda']) - def test_verify_cross_layer_scaling_depthwise_separable_layer_multiple_triplets(self, use_python_impl, device): + def 
test_verify_cross_layer_scaling_depthwise_separable_layer_multiple_triplets(self, device): torch.manual_seed(10) model = MockMobileNetV1().eval().to(device) @@ -282,7 +265,7 @@ def test_verify_cross_layer_scaling_depthwise_separable_layer_multiple_triplets( assert not torch.equal(model.model[1][3].weight, w2) assert not torch.equal(model.model[2][3].weight, w3) - def test_find_layer_groups_to_scale_for_network_with_residuals(self, use_python_impl): + def test_find_layer_groups_to_scale_for_network_with_residuals(self): torch.manual_seed(10) model = MockMobileNetV2() model.eval() @@ -533,7 +516,7 @@ def test_find_cls_sets_mobilenetv1(self): for layer_tuple in layer_pairs: print(layer_tuple) - def test_auto_mobilenetv1(self, use_python_impl): + def test_auto_mobilenetv1(self): torch.manual_seed(10) model = MockMobileNetV1() model.eval() @@ -544,7 +527,7 @@ def test_auto_mobilenetv1(self, use_python_impl): scale_factors = CrossLayerScaling.scale_model(model, (1, 3, 224, 224), random_input) assert 8 == len(scale_factors) - def test_auto_cls_custom_model(self, use_python_impl): + def test_auto_cls_custom_model(self): torch.manual_seed(10) model = MyModel() model.eval() @@ -564,7 +547,7 @@ def test_auto_cls_custom_model(self, use_python_impl): assert torch.allclose(output_before_scale, output_after_scale) @pytest.mark.cuda - def test_auto_cls_custom_model_multi_gpu(self, use_python_impl): + def test_auto_cls_custom_model_multi_gpu(self): torch.manual_seed(10) model = MyModel() @@ -582,7 +565,7 @@ def test_auto_cls_custom_model_multi_gpu(self, use_python_impl): output_after_scale = model(random_input) assert torch.allclose(output_before_scale, output_after_scale, rtol=1.e-2) - def test_auto_cle_custom_model(self, use_python_impl): + def test_auto_cle_custom_model(self): torch.manual_seed(10) model = MyModel() @@ -596,7 +579,7 @@ def test_auto_cle_custom_model(self, use_python_impl): assert torch.allclose(output_before_equalize, output_after_equalize) @pytest.mark.cuda - def test_auto_cle_custom_model_multi_gpu(self, use_python_impl): + def test_auto_cle_custom_model_multi_gpu(self): torch.manual_seed(10) model = MyModel() @@ -612,7 +595,7 @@ def test_auto_cle_custom_model_multi_gpu(self, use_python_impl): output_after_equalize = model(random_input) assert torch.allclose(output_before_equalize, output_after_equalize, rtol=1.e-2) - def test_auto_cle_two_inputs_model(self, use_python_impl): + def test_auto_cle_two_inputs_model(self): model = TwoInputsModel().eval() model_copy = copy.deepcopy(model) @@ -629,7 +612,7 @@ def test_auto_cle_two_inputs_model(self, use_python_impl): output_after_equalize = model_copy(*model_input_list) assert torch.allclose(output_before_equalize, output_after_equalize) - def test_auto_transposed_conv2d_model(self, use_python_impl): + def test_auto_transposed_conv2d_model(self): torch.manual_seed(10) model = TransposedConvModel() @@ -643,7 +626,7 @@ def test_auto_transposed_conv2d_model(self, use_python_impl): assert np.allclose(baseline_output, output_after_scaling, rtol=1.e-2) assert 10 == len(scale_factors[0].cls_pair_info_list[0].scale_factor) - def test_auto_depthwise_transposed_conv_model(self, use_python_impl): + def test_auto_depthwise_transposed_conv_model(self): torch.manual_seed(0) model = torch.nn.Sequential( torch.nn.Conv2d(5, 10, 3), @@ -667,7 +650,7 @@ def test_auto_depthwise_transposed_conv_model(self, use_python_impl): assert 2 == len(scale_factors) assert 2 == len(scale_factors[0].cls_pair_info_list) - def test_cle_for_float32_and_int64_input_model(self, 
use_python_impl): + def test_cle_for_float32_and_int64_input_model(self): model = Float32AndInt64InputModel().to(torch.device('cpu')) model.eval() @@ -691,200 +674,49 @@ class TestTrainingExtensionsCrossLayerScalingPythonOnly: @pytest.mark.cuda def test_cle_using_python_impl(self): - """ Compare MO and python implementation for CLE """ - flag = cle.USE_PYTHON_IMPL - try: - torch.manual_seed(10) - random_input = torch.rand(2, 10, 24, 24).cuda() - model = MyModel().eval().cuda() - model_copy = copy.deepcopy(model).eval() - # original outputs - output = model(random_input) - - # equalize using MO - cle.USE_PYTHON_IMPL = False - equalize_model(model, (2, 10, 24, 24), dummy_input=random_input) - output_using_mo = model(random_input) - - # equalize using python - cle.USE_PYTHON_IMPL = True - equalize_model(model_copy, (2, 10, 24, 24), dummy_input=random_input) - output_using_python = model_copy(random_input) - - assert torch.allclose(output_using_mo, output_using_python) - assert torch.allclose(output, output_using_mo) - assert torch.allclose(output, output_using_python) - finally: - cle.USE_PYTHON_IMPL = flag + torch.manual_seed(10) + random_input = torch.rand(2, 10, 24, 24).cuda() + model = MyModel().eval().cuda() + model_copy = copy.deepcopy(model).eval() + # original outputs + output = model(random_input) + + equalize_model(model_copy, (2, 10, 24, 24), dummy_input=random_input) + output_using_python = model_copy(random_input) + + assert torch.allclose(output, output_using_python) + @pytest.mark.cuda def test_scale_cls_set_with_conv_layers_using_python_impl(self): """ Compare scale_cls_set_with_conv_layers API """ - flag = cle.USE_PYTHON_IMPL - try: - torch.manual_seed(10) - model = MyModel().cuda().eval() - model_copy = copy.deepcopy(model).eval() - random_input = torch.rand((2, 10, 24, 24)).cuda() - - # original outputs - output = model(random_input) - - # Invoke MO implementation - cle.USE_PYTHON_IMPL = False - CrossLayerScaling.scale_cls_set_with_conv_layers((model.conv1, model.conv2)) - output_using_mo = model(random_input) - - # Invoke python implementation - cle.USE_PYTHON_IMPL = True - CrossLayerScaling.scale_cls_set_with_conv_layers((model_copy.conv1, model_copy.conv2)) - output_using_python = model_copy(random_input) - - # Verify the outputs. - assert torch.allclose(output_using_mo, output_using_python) - assert torch.allclose(output, output_using_mo) - assert torch.allclose(output, output_using_python) - - # Verify the weights. - assert torch.allclose(model.conv1.weight, model_copy.conv1.weight) - assert torch.allclose(model.conv2.weight, model_copy.conv2.weight) - finally: - cle.USE_PYTHON_IMPL = flag + torch.manual_seed(10) + model = MyModel().cuda().eval() + random_input = torch.rand((2, 10, 24, 24)).cuda() - @pytest.mark.cuda - def test_cls_using_python_impl_mobilenetv1(self): - """ Compare MO and python implementation for CLS """ - flag = cle.USE_PYTHON_IMPL - try: - torch.manual_seed(10) - model = MockMobileNetV1().cuda().eval() - model_copy = copy.deepcopy(model).cuda().eval() - dummy_input = torch.rand((1, 3, 224, 224)).cuda() - - # BN fold - fold_all_batch_norms(model, (1, 3, 224, 224), dummy_input=dummy_input) - fold_all_batch_norms(model_copy, (1, 3, 224, 224), dummy_input=dummy_input) - - # CLS using MO and python. 
- cle.USE_PYTHON_IMPL = False - scale_factors_mo = CrossLayerScaling.scale_model(model, (1, 3, 224, 224), dummy_input) - cle.USE_PYTHON_IMPL = True - scale_factors_python = CrossLayerScaling.scale_model(model_copy, (1, 3, 224, 224), dummy_input) - - assert len(scale_factors_mo) == 8 - assert len(scale_factors_python) == 8 - - # Verify the outputs. - assert torch.allclose(model(dummy_input), model_copy(dummy_input)) - - # Verify the weights - assert torch.allclose(model.model[0][0].weight, model_copy.model[0][0].weight) - finally: - cle.USE_PYTHON_IMPL = flag + # original outputs + output = model(random_input) - @pytest.mark.cuda - @pytest.mark.parametrize('device', ['cpu', 'cuda']) - def test_bias_fold_using_python_impl(self, device): - """ Verify bias fold API using python implementation """ - class Model(torch.nn.Module): - def __init__(self): - super(Model, self).__init__() - self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=2) - self.bn1 = torch.nn.BatchNorm2d(16) - self.conv2 = torch.nn.ConvTranspose2d(16, 32, kernel_size=3) - self.bn2 = torch.nn.BatchNorm2d(32) - self.conv3 = torch.nn.Conv2d(32, 32, kernel_size=3) - for m in self.modules(): - if isinstance(m, torch.nn.BatchNorm2d): - torch.nn.init.normal_(m.weight) - torch.nn.init.constant_(m.bias, 4) - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.bn2(x) - x = self.conv3(x) - return x - - def _verify_bias_fold(model, dummy_input): - - folded_pairs = batch_norm_fold.fold_all_batch_norms(model, (1, 3, 224, 224), dummy_input=dummy_input) - bn_dict = {} - for conv, bn in folded_pairs: - bn_dict[conv] = bn - - # Create a list of consecutive conv layers to be equalized and scale them. - consecutive_layer_list = [(model.conv1, model.conv2), (model.conv2, model.conv3)] - scaling_factor_list = CrossLayerScaling.scale_cls_sets(consecutive_layer_list) - - cls_set_info_list = \ - [ClsSetInfo(ClsSetInfo.ClsSetLayerPairInfo(model.conv1, model.conv2, scaling_factor_list[0], True)), - ClsSetInfo(ClsSetInfo.ClsSetLayerPairInfo(model.conv2, model.conv3, scaling_factor_list[1], True))] - - conv1_bias_before = model.conv1.bias.clone().cpu() - conv2_bias_before = model.conv2.bias.clone().cpu() - conv3_bias_before = model.conv3.bias.clone().cpu() - - # Fold the biases. 
- HighBiasFold.bias_fold(cls_set_info_list, bn_dict) - - conv1_bias = model.conv1.bias.detach().cpu() - conv2_bias = model.conv2.bias.detach().cpu() - conv3_bias = model.conv3.bias.detach().cpu() - - assert not torch.equal(conv1_bias_before, conv1_bias) - assert not torch.equal(conv2_bias_before, conv2_bias) - assert not torch.equal(conv3_bias_before, conv3_bias) - - return conv1_bias, conv2_bias, conv3_bias - - flag = cle.USE_PYTHON_IMPL - try: - torch.manual_seed(10) - model = Model().eval().to(device) - model_copy = copy.deepcopy(model).to(device) - dummy_input = torch.randn(1, 3, 10, 10).to(device) - - # invoke with MO (c++) implementation - cle.USE_PYTHON_IMPL = False - conv1_bias_mo, conv2_bias_mo, conv3_bias_mo = _verify_bias_fold(model, dummy_input) - # invoke with python implementation - cle.USE_PYTHON_IMPL = True - conv1_bias_p, conv2_bias_p, conv3_bias_p = _verify_bias_fold(model_copy, dummy_input) - - assert torch.allclose(conv1_bias_mo, conv1_bias_p) - assert torch.allclose(conv2_bias_mo, conv2_bias_p) - assert torch.allclose(conv3_bias_mo, conv3_bias_p) - assert torch.allclose(model(dummy_input), model_copy(dummy_input), rtol=1.e-2) - finally: - cle.USE_PYTHON_IMPL = flag + CrossLayerScaling.scale_cls_set_with_conv_layers((model.conv1, model.conv2)) + output_using_python = model(random_input) + # Verify the outputs. + assert torch.allclose(output, output_using_python) + @pytest.mark.parametrize("groups", [1, 10]) def test_compare_scale_factors(self, groups): - """ compare scale factors using with MO and python implementation """ - flag = cle.USE_PYTHON_IMPL - try: - torch.manual_seed(10) - model = torch.nn.Sequential( - torch.nn.ConvTranspose2d(10, 10, 3, groups=groups), - torch.nn.Conv2d(10, 10, 3), - ).eval() - - with torch.no_grad(): - model[0].weight *= model[0].weight * 100 - - model_copy = copy.deepcopy(model).eval() - dummy_input = torch.rand((1, 10, 32, 32)) - cle.USE_PYTHON_IMPL = True - py_scale_factors = CrossLayerScaling.scale_model(model_copy, dummy_input=dummy_input) - cle.USE_PYTHON_IMPL = False - mo_scale_factors = CrossLayerScaling.scale_model(model, dummy_input=dummy_input) - for py, mo in zip(py_scale_factors[0].cls_pair_info_list[0].scale_factor, - mo_scale_factors[0].cls_pair_info_list[0].scale_factor): - assert np.isclose(py, mo) - finally: - cle.USE_PYTHON_IMPL = flag + torch.manual_seed(10) + model = torch.nn.Sequential( + torch.nn.ConvTranspose2d(10, 10, 3, groups=groups), + torch.nn.Conv2d(10, 10, 3), + ).eval() + with torch.no_grad(): + model[0].weight *= model[0].weight * 100 + + dummy_input = torch.rand((1, 10, 32, 32)) + py_scale_factors = CrossLayerScaling.scale_model(model, dummy_input=dummy_input) + def _verify_ranges(module_0, module_1): if isinstance(module_0, torch.nn.ConvTranspose2d) and module_0.groups == 1: weight_0 = module_0.weight.detach().permute(1, 0, 2, 3) @@ -896,10 +728,8 @@ def _verify_ranges(module_0, module_1): # Verify that weights are scaled back to similar ranges _verify_ranges(model[0], model[1]) - _verify_ranges(model_copy[0], model_copy[1]) - def test_divide_by_zero(self, use_python_impl): - """ Ensure scale factors are computed using with MO and python implementation """ + def test_divide_by_zero(self): torch.manual_seed(10) model = torch.nn.Sequential( torch.nn.ConvTranspose2d(10, 10, 3, groups=10), @@ -911,61 +741,24 @@ def test_divide_by_zero(self, use_python_impl): CrossLayerScaling.scale_model(model, dummy_input=dummy_input) assert not torch.isnan(model[0].weight).any() - def 
test_bias_fold_for_convtranspose1d(self): - """ Verify bias fold for ConvTranspose1d """ - flag = cle.USE_PYTHON_IMPL - try: - torch.manual_seed(10) - model = torch.nn.Sequential( - torch.nn.Conv1d(10, 10, 3), - torch.nn.BatchNorm1d(10), - torch.nn.ConvTranspose1d(10, 10, 3) - ).eval() - # Initialize BN parameters - torch.nn.init.normal_(model[1].weight) - torch.nn.init.normal_(model[1].bias) - dummy_input = torch.randn(1, 10, 32) - model_copy = copy.deepcopy(model).eval() - cle.USE_PYTHON_IMPL = True - equalize_model(model, dummy_input=dummy_input) - cle.USE_PYTHON_IMPL = False - equalize_model(model_copy, dummy_input=dummy_input) - assert torch.allclose(model[0].bias, model_copy[0].bias, rtol=1.e-1) - assert torch.allclose(model[2].bias, model_copy[2].bias, rtol=1.e-1) - finally: - cle.USE_PYTHON_IMPL = flag - def test_divide_by_zero_with_depthwise(self): - """ Ensure scale factors are computed using with MO and python implementation """ - flag = cle.USE_PYTHON_IMPL - try: - torch.manual_seed(10) - model = torch.nn.Sequential( - torch.nn.Conv2d(10, 10, 3), - torch.nn.ReLU(), - torch.nn.Conv2d(10, 10, 3, groups=10), - torch.nn.ReLU(), - torch.nn.Conv2d(10, 10, 1), - torch.nn.ReLU(), - ).eval() - dummy_input = torch.randn(1, 10, 32, 32) - with torch.no_grad(): - model[2].weight[0, :, :, :] = 0 - - model_copy = copy.deepcopy(model).eval() - cle.USE_PYTHON_IMPL = True - CrossLayerScaling.scale_model(model, dummy_input=dummy_input) - cle.USE_PYTHON_IMPL = False - CrossLayerScaling.scale_model(model_copy, dummy_input=dummy_input) - - assert not torch.isnan(model[0].weight).any() - assert not torch.isnan(model[2].weight).any() - assert not torch.isnan(model[4].weight).any() - assert torch.allclose(model[0].weight, model_copy[0].weight) - assert torch.allclose(model[2].weight, model_copy[2].weight) - assert torch.allclose(model[4].weight, model_copy[4].weight) - - with torch.no_grad(): - assert torch.allclose(model(dummy_input), model_copy(dummy_input), rtol=1.e-2) - finally: - cle.USE_PYTHON_IMPL = flag + torch.manual_seed(10) + model = torch.nn.Sequential( + torch.nn.Conv2d(10, 10, 3), + torch.nn.ReLU(), + torch.nn.Conv2d(10, 10, 3, groups=10), + torch.nn.ReLU(), + torch.nn.Conv2d(10, 10, 1), + torch.nn.ReLU(), + ).eval() + dummy_input = torch.randn(1, 10, 32, 32) + with torch.no_grad(): + model[2].weight[0, :, :, :] = 0 + + model_copy = copy.deepcopy(model).eval() + CrossLayerScaling.scale_model(model, dummy_input=dummy_input) + + assert not torch.isnan(model[0].weight).any() + assert not torch.isnan(model[2].weight).any() + assert not torch.isnan(model[4].weight).any() + \ No newline at end of file diff --git a/TrainingExtensions/torch/test/python/test_weight_svd.py b/TrainingExtensions/torch/test/python/test_weight_svd.py index 9a1d377275..1c3a2e08c6 100644 --- a/TrainingExtensions/torch/test/python/test_weight_svd.py +++ b/TrainingExtensions/torch/test/python/test_weight_svd.py @@ -54,23 +54,6 @@ logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Test) -@contextmanager -def _use_python_impl(flag: bool): - orig_flag = cf_svd.USE_PYTHON_IMPL - try: - cf_svd.USE_PYTHON_IMPL = flag - yield - finally: - cf_svd.USE_PYTHON_IMPL = orig_flag - - -@pytest.fixture(params=[True, False]) -def use_python_impl(request): - param: bool = request.param - - with _use_python_impl(param): - yield - class MnistModel(nn.Module): def __init__(self): @@ -258,7 +241,7 @@ def forward(self, *inputs): class TestWeightSvdPruning: - def test_prune_layer(self, use_python_impl): + def test_prune_layer(self): model 
= mnist_model.Net()
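
Note on the retained Python path: with the libpymo (C++) entry points removed, batch-norm folding goes only through the pure-PyTorch math in aimet_common.batch_norm_fold. The identity applied when a BatchNorm follows a Conv (the fold_backward=True case) can be sketched as below. This is an illustrative sketch only: fold_bn_into_conv is a hypothetical helper, not the AIMET API, and it skips the transposed, grouped, and fold-forward cases the real code handles.

import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> None:
    """Fold a BatchNorm2d that follows `conv` into the conv's weight and bias, in place.

    Standard folding identities:
        scale  = gamma / sqrt(running_var + eps)
        W_fold = W * scale                      (per output channel)
        b_fold = beta + (b - running_mean) * scale
    """
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        conv.weight.mul_(scale.reshape(-1, 1, 1, 1))
        bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        folded_bias = bn.bias + (bias - bn.running_mean) * scale
        if conv.bias is None:
            conv.bias = torch.nn.Parameter(folded_bias)
        else:
            conv.bias.copy_(folded_bias)

# Numerical sanity check of the identity (eval mode, so BN uses its running statistics).
conv = torch.nn.Conv2d(3, 8, 3).eval()
bn = torch.nn.BatchNorm2d(8).eval()
torch.nn.init.normal_(bn.weight)
torch.nn.init.normal_(bn.bias)
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 1.5)
x = torch.randn(1, 3, 16, 16)
reference = bn(conv(x))
fold_bn_into_conv(conv, bn)
assert torch.allclose(reference, conv(x), atol=1e-4)

Folding in the other direction (BN before a Conv/Linear) and the grouped or transposed-conv layouts require the weight reshaping that batch_norm_fold and expand_shape_to_4d perform; the sketch covers only the common Conv-then-BN case.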
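
Similarly, cross-layer scaling now always runs through PythonClsImpl. The per-channel scale factor it equalizes with (the data-free quantization formulation) can be sketched as follows; cls_scale_factors and scale_cls_pair are hypothetical names for illustration, and the real implementation additionally handles depthwise triplets, transposed convs, and zero weight ranges.

import torch

def cls_scale_factors(conv1: torch.nn.Conv2d, conv2: torch.nn.Conv2d) -> torch.Tensor:
    """Per-channel factors s_i = sqrt(r1_i * r2_i) / r2_i for two consecutive conv layers."""
    range1 = conv1.weight.detach().abs().amax(dim=(1, 2, 3))   # per output channel of conv1
    range2 = conv2.weight.detach().abs().amax(dim=(0, 2, 3))   # per input channel of conv2
    range2 = range2.clamp(min=torch.finfo(range2.dtype).tiny)  # guard against divide-by-zero
    return torch.sqrt(range1 * range2) / range2

def scale_cls_pair(conv1: torch.nn.Conv2d, conv2: torch.nn.Conv2d) -> torch.Tensor:
    """Rescale conv1's outputs and conv2's inputs; a ReLU in between preserves equivalence."""
    with torch.no_grad():
        s = cls_scale_factors(conv1, conv2)
        conv1.weight.div_(s.reshape(-1, 1, 1, 1))
        if conv1.bias is not None:
            conv1.bias.div_(s)
        conv2.weight.mul_(s.reshape(1, -1, 1, 1))
    return s

# The scaled pair stays functionally equivalent because the scale factors are positive.
m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Conv2d(8, 4, 3)).eval()
x = torch.randn(1, 3, 16, 16)
reference = m(x)
scale_cls_pair(m[0], m[2])
assert torch.allclose(reference, m(x), atol=1e-4)

High-bias fold (PythonHbfImpl) then absorbs the bias that this rescaling can inflate in the first layer into the following layer, using the beta and gamma of the previously folded batch norms.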