Remove USE_PYTHON_IMPL flag (#3669)
* Removing USE_PYTHON_IMPL flag

Signed-off-by: Priyanka Dangi <[email protected]>
quic-pdangi authored Dec 19, 2024
1 parent: 5494006 · commit: 3295d37
Showing 6 changed files with 115 additions and 722 deletions.
@@ -38,12 +38,10 @@

from abc import ABC, abstractmethod
from typing import List, Tuple, Union, Dict, Iterable, Set, Any
import numpy as np
import torch
import torch.nn
from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d

import aimet_common.libpymo as libpymo
from aimet_common.batch_norm_fold import batch_norm_fold, expand_shape_to_4d
from aimet_common.bias_correction import ConvBnPatternHandler, CONV_OP_TYPES, LINEAR_OP_TYPES, BN_OP_TYPES
from aimet_common.graph_pattern_matcher import PatternType
@@ -69,61 +67,13 @@
BatchNormType = Union[BatchNorm1d, BatchNorm2d]
_supported_batchnorms = BatchNormType.__args__

# Temporary flag to flip underlying implementation. This flag will be removed in the future releases.
USE_PYTHON_IMPL = True

class _BatchNormFoldingNotSupported(RuntimeError):
    pass

class BatchNormFoldBase(ABC):
    """Handles batch norm folding logic"""

    @staticmethod
    def _call_mo_batch_norm_fold(weight: torch.Tensor,
                                 bias: torch.Tensor,
                                 bn: BatchNormType,
                                 fold_backward: bool):
        """
        Calls C++ batch norm folding API.
        :param weight: Weight or scale tensor to fold BN into.
        :param bias: Bias tensor to fold BN into.
        :param bn: Batch Norm layer
        :param fold_backward: True if BatchNorm comes after Conv/Linear layer
        """
        with torch.no_grad():
            bn_params = libpymo.BNParams()
            bn_params.gamma = bn.weight.detach().cpu().numpy().reshape(-1)
            bn_params.beta = bn.bias.detach().cpu().numpy().reshape(-1)
            bn_params.runningMean = bn.running_mean.detach().cpu().numpy().reshape(-1)
            sigma = torch.sqrt(bn.running_var + bn.eps)
            bn_params.runningVar = sigma.detach().cpu().numpy().reshape(-1)

            weight_tensor = libpymo.TensorParams()
            weight_tensor.data = weight.detach().cpu().numpy().reshape(-1)
            weight_tensor.shape = np.array(weight.shape)

            bias_tensor = libpymo.TensorParams()
            bias_tensor.data = bias.detach().cpu().numpy().reshape(-1)
            bias_tensor.shape = np.array(bias.shape)
            is_bias_valid = True

            _4d_shape = expand_shape_to_4d(weight_tensor.shape)
            try:
                orig_shape = weight_tensor.shape
                weight_tensor.shape = _4d_shape
                _bias = libpymo.fold(bn_params, weight_tensor, bias_tensor, is_bias_valid, fold_backward)
            finally:
                weight_tensor.shape = orig_shape

            bias.copy_(torch.tensor(_bias, device=bias.device, dtype=bias.dtype)
                       .reshape_as(bias))
            weight.copy_(torch.tensor(weight_tensor.data, device=weight.device, dtype=weight.dtype)
                         .reshape_as(weight))


    @staticmethod
    def _call_py_batch_norm_fold(weight: torch.Tensor,
    def _call_batch_norm_fold(weight: torch.Tensor,
                              bias: torch.Tensor,
                              bn: Union[BatchNorm1d, BatchNorm2d],
                              fold_backward: bool):
@@ -171,10 +121,8 @@ def _fold_to_weight(cls, conv_linear: LayerType, bn: BatchNormType, fold_backwar
                               dtype=conv_linear.weight.dtype)
            conv_linear.bias = torch.nn.Parameter(bias)

        if USE_PYTHON_IMPL:
            cls._call_py_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)
        else:
            cls._call_mo_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)

        cls._call_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)

        # Transpose weight back to N, C, H, W for transposed Conv2D, for non-depthwise layers
        if isinstance(conv_linear, torch.nn.ConvTranspose2d) and conv_linear.groups == 1:
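With the flag gone, _fold_to_weight always routes through the pure-Python _call_batch_norm_fold instead of the libpymo path removed above. For reference, batch-norm folding absorbs the BatchNorm's affine transform into the preceding Conv/Linear parameters. Below is a minimal standalone sketch of that fold for the fold_backward case (BatchNorm after a Conv2d); the helper name fold_bn_into_conv and the numerical check are illustrative assumptions, not AIMET's API.

import torch
import torch.nn as nn

@torch.no_grad()
def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> None:
    # Illustrative helper (not AIMET's API): absorb y = gamma * (conv(x) - mu) / sigma + beta
    # into the conv itself: W' = W * (gamma / sigma) per output channel,
    # b' = (b - mu) * (gamma / sigma) + beta, with sigma = sqrt(running_var + eps).
    sigma = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / sigma                        # one scale factor per output channel
    if conv.bias is None:
        conv.bias = nn.Parameter(torch.zeros_like(bn.running_mean))
    conv.weight.mul_(scale.reshape(-1, 1, 1, 1))     # scale each output-channel filter
    conv.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

# Quick numerical check against the unfolded Conv2d + BatchNorm2d pair.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(8).eval()                        # eval() so the running statistics are used
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)
x = torch.randn(2, 3, 16, 16)
reference = bn(conv(x))
fold_bn_into_conv(conv, bn)
assert torch.allclose(conv(x), reference, atol=1e-4)

The scaling runs along dimension 0 of the weight, which is presumably why the _fold_to_weight hunk above transposes non-depthwise ConvTranspose2d weights back afterwards: their output channels live on dimension 1 rather than dimension 0.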
@@ -51,12 +51,9 @@
from aimet_torch.defs import SpatialSvdParameters, WeightSvdParameters, ChannelPruningParameters, ModuleCompRatioPair
from aimet_torch.layer_selector import ConvFcLayerSelector, ConvNoDepthwiseLayerSelector, ManualLayerSelector
from aimet_torch.layer_database import LayerDatabase
from aimet_torch.svd.svd_pruner import SpatialSvdPruner, WeightSvdPruner, PyWeightSvdPruner
from aimet_torch.svd.svd_pruner import SpatialSvdPruner, PyWeightSvdPruner
from aimet_torch.channel_pruning.channel_pruner import InputChannelPruner, ChannelPruningCostCalculator

# Temporary flag to flip underlying implementation. This flag will be removed in the future releases.
USE_PYTHON_IMPL = True

class CompressionFactory:
""" Factory to construct various AIMET model compression classes based on a scheme """

@@ -216,7 +213,7 @@ def create_weight_svd_algo(cls, model: torch.nn.Module, eval_callback: EvalFunct
        use_cuda = next(model.parameters()).is_cuda

        # Create a pruner
        pruner = PyWeightSvdPruner() if USE_PYTHON_IMPL else WeightSvdPruner()
        pruner = PyWeightSvdPruner()
        cost_calculator = WeightSvdCostCalculator()
        comp_ratio_rounding_algo = RankRounder(params.multiplicity, cost_calculator)

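In this second diff, CompressionFactory now always constructs the Python weight-SVD pruner (PyWeightSvdPruner). Weight SVD compresses a layer by replacing its weight with a truncated rank-k factorization, turning one large layer into two thinner ones whose ranks AIMET picks via the cost calculator and RankRounder seen above. The sketch below shows the underlying idea for a Linear layer; the helper svd_split_linear and the hard-coded rank are illustrative assumptions, not the PyWeightSvdPruner API.

import torch
import torch.nn as nn

@torch.no_grad()
def svd_split_linear(layer: nn.Linear, rank: int) -> nn.Sequential:
    # Illustrative only: W ~= (U_k S_k) @ Vh_k, so layer(x) ~= second(first(x)).
    U, S, Vh = torch.linalg.svd(layer.weight, full_matrices=False)
    first = nn.Linear(layer.in_features, rank, bias=False)
    second = nn.Linear(rank, layer.out_features, bias=layer.bias is not None)
    first.weight.copy_(Vh[:rank])                    # (rank, in_features)
    second.weight.copy_(U[:, :rank] * S[:rank])      # (out_features, rank)
    if layer.bias is not None:
        second.bias.copy_(layer.bias)
    return nn.Sequential(first, second)

layer = nn.Linear(512, 256)
approx = svd_split_linear(layer, rank=64)
x = torch.randn(4, 512)
err = (layer(x) - approx(x)).abs().max().item()
print(f"max abs error at rank 64: {err:.4f}")        # small only if the weight is close to low-rank

At rank k the two layers hold k * (in_features + out_features) parameters instead of in_features * out_features, which is where the compression ratio comes from.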
