Add optimize rounding Function (#2494)
* Add optimize rounding Function

Signed-off-by: Harshita Mangal <[email protected]>

* Add call to adaround optimizer from adaround weights

Signed-off-by: Harshita Mangal <[email protected]>
quic-mangal authored and quic-bharathr committed Sep 13, 2024
1 parent 89d9a7e commit 2b6b0f8
Showing 10 changed files with 470 additions and 71 deletions.
@@ -100,8 +100,8 @@ def sample_and_place_all_acts_on_cpu(self, dataset) -> Tuple:
model_inputs = next(iterator)
inp_data, out_data = self.sample_acts(create_input_dict(self._org_model.model, model_inputs))

all_inp_data.append(inp_data)
all_out_data.append(out_data)
all_inp_data.append(inp_data[0])
all_out_data.append(out_data[0])

if batch_index == len(dataset) - 1:
break
@@ -38,93 +38,273 @@

""" Adaround optimizer """

from typing import Tuple, Dict
from typing import Union, Tuple, Dict
import numpy as np
import onnx
from onnx import onnx_pb, numpy_helper
import torch
import torch.nn.functional as functional

from onnx import numpy_helper
from torch.utils.data import Dataset

# Import AIMET specific modules
from aimet_common.utils import AimetLogger
from aimet_onnx.adaround.activation_sampler import ActivationSampler
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.adaround.utils import ModuleInfo, read_attributes_for_op
from aimet_onnx.utils import create_input_dict
from aimet_torch.adaround.adaround_loss import AdaroundLoss, AdaroundHyperParameters
from aimet_torch.adaround.adaround_tensor_quantizer import AdaroundTensorQuantizer
from aimet_torch.adaround.adaround_optimizer import AdaroundOptimizer as TorchAdaroundOptimizer

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
BATCH_SIZE = 32
EMPIRICAL_THRESHOLD = 3 / 4
DATA_SIZE_IN_BITS = 32
ACTIVATION_MAP = {'Relu': torch.nn.ReLU(), 'PRelu': torch.nn.PReLU(), 'Tanh': torch.nn.Tanh(),
'Clip': torch.nn.ReLU6(), 'Sigmoid': torch.nn.Sigmoid(), 'Softmax': torch.nn.Softmax()}
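
For illustration only, outside this commit: ACTIVATION_MAP lets the optimizer look up the PyTorch activation that follows a layer by its ONNX op type and apply it to both the original and the AdaRounded outputs before computing the loss. A minimal sketch of that lookup, using only standard PyTorch:

import torch
import torch.nn.functional as functional

# Toy lookup mirroring ACTIVATION_MAP: ONNX op type string -> torch module.
activation_map = {'Relu': torch.nn.ReLU(), 'Sigmoid': torch.nn.Sigmoid()}

orig_out = torch.tensor([-1.0, 0.5, 2.0])
quant_out = torch.tensor([-0.9, 0.4, 2.1])

act = activation_map['Relu']             # the layer is followed by a Relu
recon = functional.mse_loss(act(quant_out), act(orig_out))
print(float(recon))                      # MSE computed after the activation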


class AdaroundOptimizer:
"""
Optimizes the weight rounding of quantized wrapper module
"""
@classmethod
def _compute_recons_metrics(cls, quant_module: ModuleInfo, act_func: torch.nn.Module, inp_data: torch.Tensor,
out_data: torch.Tensor, param_to_adaround_tensor_quantizer: Dict) -> Tuple[float, float]:
def adaround_module(cls, module: ModuleInfo, quantized_input_name: str,
orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel,
act_func: Union[str, None], cached_dataset: Dataset,
opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict,
use_cuda: bool, device: int = 0):
"""
Adaround module
:param module: Original module's information
:param quantized_input_name: Name of input to the quantized layer/ layer to be adarounded
:param orig_model: The original, un quantized, model
:param quant_model: QuantSim model
:param act_func: Activation function
:param cached_dataset: Cached dataset
yielded from the data loader
:param opt_params: Optimization parameters
:param param_to_adaround_tensor_quantizer: Param name to adaround tensor quantizer dictionary
:param use_cuda: If we should use cuda
:param device: CUDA device ID
"""
# pylint: disable=too-many-arguments

# Optimize weight rounding
cls._optimize_rounding(module, quantized_input_name, orig_model, quant_model, act_func, cached_dataset,
opt_params, param_to_adaround_tensor_quantizer, use_cuda, device)

# After optimization, set the optimized layer's rounding mode to "Hard rounding"
param_to_adaround_tensor_quantizer[module.params['weight'].name].use_soft_rounding = False
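
As a rough sketch, outside this commit and with hypothetical names: adaround_module first runs the soft-rounding optimization and only afterwards freezes the layer to hard rounding, so later reconstruction checks see the final 0/1 decisions. The handoff reduces to something like:

import torch

class _ToyQuantizer:
    """Hypothetical stand-in for AdaroundTensorQuantizer: only tracks alpha and the rounding mode."""
    def __init__(self, shape):
        self.alpha = torch.zeros(shape, requires_grad=True)
        self.use_soft_rounding = True

def toy_adaround_module(quantizer, optimize_fn):
    optimize_fn(quantizer)                 # phase 1: train alpha with soft rounding enabled
    quantizer.use_soft_rounding = False    # phase 2: freeze the layer to hard rounding

quantizer = _ToyQuantizer((4, 4))
toy_adaround_module(quantizer, lambda q: None)   # no-op optimizer, just to show the flow
assert quantizer.use_soft_rounding is False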

@classmethod
def _optimize_rounding(cls, module: ModuleInfo, quantized_input_name,
orig_model: onnx_pb.ModelProto, quant_model: QuantizationSimModel,
act_func: Union[None, str], cached_dataset: Dataset,
opt_params: AdaroundHyperParameters, param_to_adaround_tensor_quantizer: Dict,
use_cuda: bool, device: int = 0):
"""
Optimizes the weight rounding of quantized wrapper module
:param module: Original module
:param quantized_input_name: Name of input to the quantized layer/ layer to be adarounded
:param orig_model: The original, un quantized, model
:param quant_model: QuantSim model
:param act_func: Activation function
:param cached_dataset: Cached dataset
:param opt_params: Optimization parameters
:param param_to_adaround_tensor_quantizer: Param name to adaround tensor quantizer dictionary
"""
# pylint: disable=too-many-locals, too-many-arguments
adaround_quantizer = param_to_adaround_tensor_quantizer[module.params['weight'].name]
torch_device = 'cpu'
if use_cuda:
torch_device = 'cuda:' + str(device)
weights = torch.from_numpy(numpy_helper.to_array(module.params['weight'].tensor)).to(torch_device)
enable_grad(weights)

# pylint: disable=protected-access
adaround_quantizer._broadcast_offset_delta(weights)
adaround_quantizer._initialize_alpha(weights, adaround_quantizer.broadcasted_delta)

assert adaround_quantizer.use_soft_rounding, 'optimization should use soft rounding only.'
assert adaround_quantizer.alpha is not None, 'alpha parameter should be initialized.'

# Create and set up Adam optimizer with parameter 'alpha' to be optimized
optimizer = torch.optim.Adam([adaround_quantizer.alpha])

# Check if we can cache intermediate activation data.
model_inputs = cached_dataset[0]
act_sampler = ActivationSampler(module.outputs[0], quantized_input_name, orig_model, quant_model,
use_cuda, device)
inp_data, out_data = act_sampler.sample_acts(create_input_dict(orig_model.model, model_inputs))
inp_data_torch, out_data_torch = torch.from_numpy(inp_data[0]), torch.from_numpy(out_data[0])
use_cache_acts_data = TorchAdaroundOptimizer._can_cache_acts_data(len(cached_dataset), inp_data_torch.shape,
out_data_torch.shape)

if use_cache_acts_data and AdaroundOptimizer.enable_caching_acts_data():
logger.debug("Caching intermediate activations data for optimization.")
all_inp_data, all_orig_out_data = act_sampler.sample_and_place_all_acts_on_cpu(cached_dataset)
all_inp_data, all_orig_out_data = torch.from_numpy(all_inp_data[0]), \
torch.from_numpy(all_orig_out_data[0])
# Try to put all cached activations data on GPU for faster optimization if possible.
if use_cuda:
all_inp_data, all_orig_out_data = TorchAdaroundOptimizer._place_cached_acts_data(all_inp_data, all_orig_out_data,
torch_device)

for iteration in range(opt_params.num_iterations):
if use_cache_acts_data and AdaroundOptimizer.enable_caching_acts_data():
indices = torch.randperm(all_inp_data.size(0))[:BATCH_SIZE]
inp_data = all_inp_data[indices].to(torch_device)
orig_out_data = all_orig_out_data[indices].to(torch_device)
else:
model_inputs = cached_dataset[np.random.randint(len(cached_dataset))]
inp_data, orig_out_data = act_sampler.sample_acts(create_input_dict(orig_model.model, model_inputs))
inp_data, orig_out_data = torch.from_numpy(inp_data[0]).to(torch_device), torch.from_numpy(orig_out_data[0]).to(torch_device)


# Clear alpha's gradients before optimization step
optimizer.zero_grad()

# Get the module's output activations using AdaRounded weights
quant_out_data = cls._compute_output_with_adarounded_weights(weights, module, inp_data, adaround_quantizer)

# If followed by an activation function
if act_func is not None:
orig_out_data = ACTIVATION_MAP[act_func](orig_out_data)
quant_out_data = ACTIVATION_MAP[act_func](quant_out_data)

# Calculate total loss
recon_loss = AdaroundLoss.compute_recon_loss(quant_out_data, orig_out_data)
round_loss = AdaroundLoss.compute_round_loss(adaround_quantizer.alpha, opt_params, iteration)
total_loss = recon_loss + round_loss

# Back propagate and Update the parameter 'alpha'
total_loss.backward()
optimizer.step()

if iteration == 0 or iteration % 100 == 0:
logger.debug("After iterations=%d, Total loss=%5f, Recons. loss=%5f, Rounding loss=%5f",
iteration, float(total_loss), float(recon_loss), float(round_loss))

adaround_quantizer.use_soft_rounding = True
adarounded_weights = adaround_quantizer.adaround_weights(weights)
weights = adarounded_weights.detach().cpu().numpy().tobytes()
weight_name = module.params['weight'].name
update_sim_weight(quant_model, weights, weight_name)
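
For reference, a self-contained sketch of the optimization idea implemented above, reduced to plain PyTorch with simplifications (fixed regularizer weight, zero-initialized alpha, toy per-tensor step size, no warm-start schedule); the real loop goes through AdaroundLoss, AdaroundTensorQuantizer and the ONNX activation sampler:

import torch
import torch.nn.functional as functional

GAMMA, ZETA = -0.1, 1.1                       # stretch constants from the AdaRound formulation

def rectified_sigmoid(alpha):
    return torch.clamp(torch.sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, 0, 1)

def soft_quantize(weight, delta, alpha):
    # floor(w / delta) plus a learned offset in [0, 1] replaces round-to-nearest
    return (torch.floor(weight / delta) + rectified_sigmoid(alpha)) * delta

torch.manual_seed(0)
weight = torch.randn(16, 8)
delta = weight.abs().max() / 127              # toy per-tensor quantization step
inp = torch.randn(64, 8)
fp_out = inp @ weight.t()                     # float output the rounding should reconstruct

alpha = torch.zeros_like(weight, requires_grad=True)
optimizer = torch.optim.Adam([alpha])
for iteration in range(1000):
    optimizer.zero_grad()
    quant_out = inp @ soft_quantize(weight, delta, alpha).t()
    recon_loss = functional.mse_loss(quant_out, fp_out)
    h = rectified_sigmoid(alpha)
    round_loss = (1 - (2 * h - 1).abs().pow(3)).sum()   # pushes h towards exactly 0 or 1
    (recon_loss + 0.01 * round_loss).backward()
    optimizer.step()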

@classmethod
def _compute_recons_metrics(cls, quant_module: ModuleInfo, act_func: Union[None, str], inp_data: torch.Tensor,
out_data: torch.Tensor, param_to_adaround_tensor_quantizer: Dict,
use_cuda: bool, device: int = 0) -> Tuple[float, float]:
"""
Compute Mean square error of output activations using soft rounding which maps alpha parameter
between zero and one and hard rounding which maps to exact zero and one
:param quant_module: Quantized wrapper module
:param act_func: Activation function
:param inp_data: Input data to quantized wrapper module
:param out_data: Output data from module
:param param_to_adaround_tensor_quantizer: Dict
:param use_cuda: Bool, true if we use GPU
:param device: Cuda device
:return: Reconstruction error using hard rounding and soft rounding
"""
adaround_quantizer = param_to_adaround_tensor_quantizer[quant_module.params['weight'].name]

torch_device = 'cpu'
if use_cuda:
torch_device = 'cuda:' + str(device)
weights = torch.from_numpy(numpy_helper.to_array(quant_module.params['weight'].tensor)).to(torch_device)
inp_data = inp_data.to(torch_device)
# Enable hard rounding and get quantized wrapper module's output
adaround_quantizer.use_soft_rounding = False
out_data_hard = cls._compute_output_with_adarounded_weights(quant_module, inp_data, adaround_quantizer)
out_data_hard = cls._compute_output_with_adarounded_weights(weights, quant_module, inp_data, adaround_quantizer)

# Enable soft rounding and get quantized wrapper module's output
adaround_quantizer.use_soft_rounding = True
out_data_soft = cls._compute_output_with_adarounded_weights(quant_module, inp_data, adaround_quantizer)
out_data_soft = cls._compute_output_with_adarounded_weights(weights, quant_module, inp_data, adaround_quantizer)

# If followed by an activation function
if act_func is not None:
out_data = act_func(out_data)
out_data_soft = act_func(out_data_soft)
out_data_hard = act_func(out_data_hard)
out_data = ACTIVATION_MAP[act_func](out_data)
out_data_soft = ACTIVATION_MAP[act_func](out_data_soft)
out_data_hard = ACTIVATION_MAP[act_func](out_data_hard)

recons_err_soft = functional.mse_loss(out_data_soft, out_data)
recons_err_hard = functional.mse_loss(out_data_hard, out_data)

return float(recons_err_hard), float(recons_err_soft)
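
A minimal sketch, outside this commit and using a hypothetical helper rather than the AIMET API, of the hard-versus-soft comparison this metric performs on a toy linear layer:

import torch
import torch.nn.functional as functional

def recons_metrics(weight, delta, alpha, inp, fp_out):
    """Return (hard, soft) reconstruction MSE for a toy linear layer."""
    h = torch.clamp(torch.sigmoid(alpha) * 1.2 - 0.1, 0, 1)            # soft rounding offset in [0, 1]
    soft_w = (torch.floor(weight / delta) + h) * delta
    hard_w = (torch.floor(weight / delta) + (h >= 0.5).float()) * delta
    err_soft = functional.mse_loss(inp @ soft_w.t(), fp_out)
    err_hard = functional.mse_loss(inp @ hard_w.t(), fp_out)
    return float(err_hard), float(err_soft)

torch.manual_seed(0)
w = torch.randn(16, 8)
x = torch.randn(32, 8)
hard_err, soft_err = recons_metrics(w, w.abs().max() / 127, torch.zeros_like(w), x, x @ w.t())
print(hard_err, soft_err)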

@staticmethod
def _compute_output_with_adarounded_weights(quant_module, inp_data: torch.Tensor,
def _compute_output_with_adarounded_weights(weights: torch.Tensor, quant_module, inp_data: torch.Tensor,
adaround_quantizer: AdaroundTensorQuantizer):
"""
Compute output of AdaroundSupportedModules with adarounded weights
:param weights: Torch tensor weights to be adarounded
:param quant_module: Quantized wrapper module
:param inp_data: The input data to be used for computing the output
:param adaround_quantizer: Adaround tensor quantizer
:return: output of the module computed with AdaRounded weights
"""
# pylint: disable=protected-access
# Compute adarounded weights
weights = torch.from_numpy(numpy_helper.to_array(quant_module.params['weight'].tensor))
device = 'cpu'
if inp_data.is_cuda:
device = inp_data.device

adarounded_weights = adaround_quantizer.adaround_weights(weights)

if quant_module.type == 'Conv':
attributes = read_attributes_for_op(quant_module)
bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor))
bias = None
if 'bias' in quant_module.params:
bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor)).to(device)
out_data = functional.conv2d(inp_data, adarounded_weights, bias=bias, stride=attributes['strides'],
dilation=attributes['dilations'], padding=attributes['pads'][0],
groups=attributes['group'])
elif quant_module.type == 'ConvTranspose':
attributes = read_attributes_for_op(quant_module)
bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor))
bias = None
if 'bias' in quant_module.params:
bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor)).to(device)
out_data = functional.conv_transpose2d(inp_data, adarounded_weights, bias=bias, stride=attributes['strides'],
dilation=attributes['dilations'], padding=attributes['pads'][0],
groups=attributes['group'])
elif quant_module.type in ['Gemm', 'MatMul']:
bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor))
bias = torch.from_numpy(numpy_helper.to_array(quant_module.params['bias'].tensor)).to(device)
out_data = functional.linear(inp_data, adarounded_weights, bias=bias)

else:
raise ValueError('AdaRound is not supported for the module type: ', quant_module.type)

return out_data
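
For illustration with toy shapes, outside this commit: the dispatch above replays an ONNX node in PyTorch functional form, pulling stride, padding, dilation and group from the node attributes; note that only pads[0] is used, which assumes symmetric padding. A Conv example:

import torch
import torch.nn.functional as functional

inp = torch.randn(1, 3, 8, 8)                     # NCHW input
weight = torch.randn(4, 3, 3, 3)                  # the adarounded Conv weight would go here
attrs = {'strides': [1, 1], 'pads': [1, 1, 1, 1], 'dilations': [1, 1], 'group': 1}

out = functional.conv2d(inp, weight, bias=None,
                        stride=attrs['strides'], padding=attrs['pads'][0],
                        dilation=attrs['dilations'], groups=attrs['group'])
print(out.shape)                                  # torch.Size([1, 4, 8, 8])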

@staticmethod
def enable_caching_acts_data() -> bool:
"""
Function to enable/disable caching intermediate activation data. By default, it returns True.
"""
return True


def enable_grad(tensor: torch.Tensor):
"""
Enables gradient
:param tensor: Tensor for which we should enable grad
"""
if tensor.is_leaf:
tensor.requires_grad = True
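
A small illustration, outside this commit, of why the is_leaf guard matters: requires_grad can only be toggled in place on leaf tensors.

import torch

w = torch.zeros(3)            # leaf tensor, created by the user
w.requires_grad = True        # allowed: leaf tensors can opt in to autograd

y = w * 2                     # non-leaf: the result of an operation
print(y.is_leaf)              # False; assigning y.requires_grad here would raise a RuntimeError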

def update_sim_weight(quant_model: onnx.ModelProto, weights: bytes, weight_name: str):
"""
Updates the weight with the given name in the QuantSim model
:param quant_model: Quantized model
:param weights: Raw weight bytes to write into the model's initializer
:param weight_name: Name of the weight to be updated
"""
found = False
for tensor in quant_model.model.graph.initializer:
if tensor.name == weight_name:
tensor.raw_data = weights
found = True
break
assert found, f'Could not find {weight_name} in QuantSim model'