MMP Misc enhancements (#3692)
* Make activation a required input for set_precision calls
* Always use input quantizer at idx=0 for Concat Op
* Error out if a non-data-movement functional op is encountered

Signed-off-by: yathindra kota <[email protected]>
quic-ykota authored Dec 20, 2024
1 parent 896b015 commit ea061ea
Showing 4 changed files with 236 additions and 44 deletions.
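
To make the first change concrete, here is a minimal usage sketch of the now-required activation argument. It assumes the MixedPrecisionConfigurator front end over QuantizationSimModel and string dtypes such as 'int16' for SupportedDType; the toy model and constructor arguments are illustrative rather than taken from this diff.

# Hedged sketch: the toy model, dtype strings, and constructor arguments are
# illustrative; only the set_precision behavior comes from this commit.
import torch
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.mixed_precision import MixedPrecisionConfigurator

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
sim = QuantizationSimModel(model, dummy_input=torch.randn(1, 3, 32, 32))

mp = MixedPrecisionConfigurator(sim)
# `activation` no longer defaults to None: omitting it is now a TypeError at
# call time instead of silently creating a request without an activation dtype.
mp.set_precision(torch.nn.Conv2d, activation='int16', param={'weight': 'int8'})
mp.apply()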
27 changes: 18 additions & 9 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/cg_utils.py
@@ -37,12 +37,14 @@
# =============================================================================
"""Utilities to traverse model graph"""

from typing import Dict, Optional, Generator, Tuple
from typing import Dict, Optional, Generator, Tuple, Union
from dataclasses import dataclass
import functools

import torch

from aimet_common.connected_graph.connectedgraph_utils import CG_SPLIT
from aimet_torch.meta.connectedgraph import ConnectedGraph
from aimet_torch.meta.operation import Op as CG_Op
from aimet_torch.v2.nn import BaseQuantizationMixin
from aimet_torch.v2.quantsim import QuantizationSimModel
@@ -156,23 +158,30 @@ def get_cg_op_from_module(self, module):
""" Helper functions to lookup CG_Op corresponding to the given module """
return self.module_to_cg_op_mapping[module]

def get_parent_module_at_input_idx(self, module, input_idx) -> torch.nn.Module:
def get_valid_parent_module_at_input_idx(self, module, input_idx) -> Union[torch.nn.Module, None]:
"""
Traverses upstream to determine the parent module provided input idx
Traverses upstream to determine the parent module providing the given input idx.
This method errors out if a functional op that is not a data movement op is encountered.
:param module: torch.nn.Module contained within the QuantSim object
:param input_idx: input idx to determine the parent module
:return: parent torch.nn.Module providing input idx
"""
cg_op = self.get_cg_op_from_module(module)
parent_cg_op = cg_op.inputs[input_idx].producer
parent_module = self.get_module_from_cg_op(parent_cg_op)

while parent_module is None and parent_cg_op is not None:
parent_cg_op = parent_cg_op.inputs[0].producer
parent_module = self.get_module_from_cg_op(parent_cg_op)

return parent_module
while parent_cg_op:
if parent_cg_op.get_module():
return parent_cg_op.get_module()

if parent_cg_op.type in ConnectedGraph.math_invariant_types or parent_cg_op.type == CG_SPLIT:
# Split op or "functional data movement" op is encountered. Query its parent.
parent_cg_op = parent_cg_op.inputs[0].producer
else:
raise RuntimeError(f"Parent of {cg_op.dotted_name} is a functional op which is not a data movement op. "
f"CG name of the op: {parent_cg_op.dotted_name}. Consider removing this functional "
"op to proceed.")
return None

def get_child_module_at_output(self, module):
"""
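
As a standalone illustration of the traversal added in get_valid_parent_module_at_input_idx above: starting from the producer of the given input, the walk returns the first ancestor backed by a real module, hops over split and data-movement ops, and raises on any other functional. The Op class and type names below are hypothetical stand-ins; only the control flow mirrors the new method.

# Hypothetical stand-in types; only the control flow mirrors the new method.
DATA_MOVEMENT_TYPES = {'Reshape', 'Permute', 'Split'}  # illustrative subset

class Op:
    def __init__(self, op_type, module=None, producer=None):
        self.type = op_type        # CG op type string
        self.module = module       # torch.nn.Module, or None for functionals
        self.producer = producer   # op feeding input 0; None at a model input

def find_valid_parent(op):
    parent = op.producer
    while parent:
        if parent.module is not None:
            return parent.module       # first ancestor backed by a real module
        if parent.type in DATA_MOVEMENT_TYPES:
            parent = parent.producer   # hop over split / data movement ops
        else:
            raise RuntimeError(f"{parent.type} is a functional that is not a data movement op")
    return None                        # walked past a model input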
@@ -80,18 +80,18 @@ def _store_user_request(self, request_type: RequestType, module: Union[torch.nn.

@overload
def set_precision(self, module: torch.nn.Module,
activation: Union[List[SupportedDType], SupportedDType, None] = None,
activation: Union[List[SupportedDType], SupportedDType],
param: Optional[Dict[str, SupportedDType]] = None):
...

@overload
def set_precision(self, module_type: Type[torch.nn.Module],
activation: Union[List[SupportedDType], SupportedDType, None] = None,
activation: Union[List[SupportedDType], SupportedDType],
param: Optional[Dict[str, SupportedDType]] = None):
...

def set_precision(self, arg: Union[torch.nn.Module, Type[torch.nn.Module]],
activation: Union[List[SupportedDType], SupportedDType, None] = None,
activation: Union[List[SupportedDType], SupportedDType],
param: Optional[Dict[str, SupportedDType]] = None):
"""
:param arg: Module can be of type torch.nn.Module or the type of the module.
@@ -133,31 +133,31 @@ def set_precision(self, arg: Union[torch.nn.Module, Type[torch.nn.Module]],
else:
raise TypeError("arg is neither a torch.nn.Module nor of Type[torch.nn.Module]")

def set_model_input_precision(self, activations: Union[List[SupportedDType], Tuple[SupportedDType], SupportedDType, None]):
def set_model_input_precision(self, activation: Union[List[Optional[SupportedDType]], Tuple[Optional[SupportedDType], ...], SupportedDType]):
"""
Activation precision which needs to be set to the model inputs
:param activations: Activation dtypes for inputs of the model
:param activation: Activation dtypes for inputs of the model
"""
broadcasted_activations = broadcast_tuples(activations, self.mp_handler.cg_traverser.model_inputs)
for activation, model_input in zip(flatten_list(broadcasted_activations),
broadcasted_activations = broadcast_tuples(activation, self.mp_handler.cg_traverser.model_inputs)
for act, model_input in zip(flatten_list(broadcasted_activations),
flatten_list(self.mp_handler.cg_traverser.model_inputs)):
if activation is not None:
if activation not in get_args(SupportedDType):
if act:
if act not in get_args(SupportedDType):
raise ValueError("Supported inputs for activation are ", get_args(SupportedDType))
self._store_user_request(RequestType.set_model_input_precision, model_input, activation)
self._store_user_request(RequestType.set_model_input_precision, model_input, act)

def set_model_output_precision(self, activations: Union[List[SupportedDType], Tuple[SupportedDType], SupportedDType, None]):
def set_model_output_precision(self, activation: Union[List[Optional[SupportedDType]], Tuple[Optional[SupportedDType], ...], SupportedDType]):
"""
Activation precision which needs to be set to the model outputs
:param activations: Activation dtypes for outputs of the model
:param activation: Activation dtypes for outputs of the model
"""
broadcasted_activations = broadcast_tuples(activations, self.mp_handler.cg_traverser.model_outputs)
for activation, model_output in zip(flatten_list(broadcasted_activations),
broadcasted_activations = broadcast_tuples(activation, self.mp_handler.cg_traverser.model_outputs)
for act, model_output in zip(flatten_list(broadcasted_activations),
flatten_list(self.mp_handler.cg_traverser.model_outputs)):
if activation is not None:
if activation not in get_args(SupportedDType):
if act:
if act not in get_args(SupportedDType):
raise ValueError("Supported inputs for activation are ", get_args(SupportedDType))
self._store_user_request(RequestType.set_model_output_precision, model_output, activation)
self._store_user_request(RequestType.set_model_output_precision, model_output, act)

@overload
def apply(self, log_file: str = './mmp_log.txt', strict: bool = True):
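
A short sketch of the broadcast behavior of set_model_input_precision and set_model_output_precision above, reusing the mp configurator from the earlier sketch and assuming a model with two inputs; a None entry leaves that input's precision unchanged.

# Assumes `mp` from the earlier sketch and a model with two inputs.
mp.set_model_input_precision('int16')          # one dtype broadcast to all inputs
mp.set_model_input_precision(('int16', None))  # per-input; None skips that input
mp.set_model_output_precision('int8')          # outputs are handled symmetrically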
@@ -41,8 +41,11 @@
import copy
from typing import Dict, List, Tuple, Optional, Union, IO

import torch.nn

from aimet_common.defs import QuantizationDataType, QuantScheme
from aimet_common.utils import AimetLogger
from aimet_torch.v2.nn.modules.custom import QuantizedConcat
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.nn import BaseQuantizationMixin
@@ -394,7 +397,7 @@ def _propagate_request_upstream_helper(module):
if module.input_quantizers[in_idx] is not None:
continue

parent_module = self.cg_traverser.get_parent_module_at_input_idx(module, in_idx)
parent_module = self.cg_traverser.get_valid_parent_module_at_input_idx(module, in_idx)
if parent_module is None:
logger.warning(f"Warning: unable to propagate request at {module} upward. "
"Parent module could not be found.")
@@ -429,6 +432,23 @@ def _propagate_request_upstream_helper(module):
_propagate_request_upstream_helper(module)
return mp_requests

def _get_child_module_and_idx(self, module: torch.nn.Module):
"""
Helper to get the child modules and their input indices, consistent with QuantSim's interpretation
:param module: module whose child modules and input indices should be returned
"""
child_module_idxs = self.cg_traverser.get_child_module_at_output(module)
# Even if a concat op has more than one input, QuantSim adds only one input quantizer.
# This helper therefore always returns idx=0 for those modules
updated_child_module_idxs = []
for child_module, input_idx in child_module_idxs:
if isinstance(child_module, QuantizedConcat):
input_idx = 0
updated_child_module_idxs.append((child_module, input_idx))
return updated_child_module_idxs

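The concat special case in isolation: QuantSim attaches a single shared input quantizer to a QuantizedConcat, so whatever input index the graph reports for a concat child collapses to 0. Below is a minimal sketch of that normalization; normalize_child_idxs and children are hypothetical names standing in for the helper and the (module, idx) pairs returned by the traverser.

from aimet_torch.v2.nn.modules.custom import QuantizedConcat

def normalize_child_idxs(children):
    # QuantSim gives a concat one shared input quantizer, so any reported
    # input index on a QuantizedConcat child collapses to 0.
    return [(child, 0) if isinstance(child, QuantizedConcat) else (child, idx)
            for child, idx in children]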

def _resolve_request_outputs(self, mp_requests, log_file: IO):
"""
Determine if output candidates from request at the provided module should be applied or discarded
@@ -442,7 +462,7 @@ def _resolve_request_outputs_helper(module):
return

# If the output request at this module came from a downstream consumer, return without changing candidate
child_modules_and_idxs = self.cg_traverser.get_child_module_at_output(module)
child_modules_and_idxs = self._get_child_module_and_idx(module)
for child_module, input_idx in child_modules_and_idxs:
child_request = mp_requests.get(child_module)
if child_request and child_request.input_candidates and \
@@ -516,9 +536,9 @@ def _apply_new_request_for_module(module, request) -> bool:
# module does not have a request. Create a new one based on the request input
self._update_request_at_module(mp_requests,
module,
request.input_candidates[0] if request.output_candidates else None,
request.input_candidates[0] if request.output_candidates and len(request.output_candidates) > 0 else None,
copy.deepcopy(request.param_candidate) if len(module.param_quantizers.keys()) else None,
request.output_candidates[0] if request.output_candidates else None,
request.output_candidates[0] if request.output_candidates and len(request.output_candidates) > 0 else None,
strict=strict)
mp_requests[module].id = request.id

@@ -547,7 +567,7 @@ def _apply_new_request_for_module(module, request) -> bool:
if mp_request.input_candidates:
# resolve contention at the inputs using input candidates
for in_idx, input_candidate in enumerate(mp_request.input_candidates):
parent_module = self.cg_traverser.get_parent_module_at_input_idx(current_module, in_idx)
parent_module = self.cg_traverser.get_valid_parent_module_at_input_idx(current_module, in_idx)

# if input candidate is not present in the request (say, the request is from set_model_output_precision) or
# if input quantizer is present for the layer (along with input candidate), then no need to resolve any contention
@@ -556,7 +576,7 @@

# if parent has output quantizer, propagate this request to all other children
if any(parent_module.output_quantizers):
child_modules_and_idxs = self.cg_traverser.get_child_module_at_output(parent_module)
child_modules_and_idxs = self._get_child_module_and_idx(parent_module)
for child_module, _ in child_modules_and_idxs:
new_request = MpRequest(id=mp_request.id,
input_candidates=[input_candidate] * len(child_module.input_quantizers),
@@ -576,7 +596,7 @@
if mp_request.output_candidates:
# resolve at output using output candidate, if the module has output quantizer, then no need to resolve at output
if not any(current_module.output_quantizers):
child_modules_and_idxs = self.cg_traverser.get_child_module_at_output(current_module)
child_modules_and_idxs = self._get_child_module_and_idx(current_module)
for child_module, _ in child_modules_and_idxs:
new_request = MpRequest(id=mp_request.id,
input_candidates=mp_request.output_candidates * len(child_module.input_quantizers),