Skip to content

Commit

Permalink
Fix Seq mse output tensor bug (#3688)
Browse files Browse the repository at this point in the history
* Fixed input name and output name in case of sim model

Signed-off-by: Harsh Peswani <[email protected]>
  • Loading branch information
quic-alanmaha authored Dec 20, 2024
1 parent 367295d commit a583285
Showing 1 changed file with 49 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self,
:param model: float model
:param sim: QuantizationSimModel object
:param data_loader: Data loader
:param data_loader: Torch Dataloader
:param params: Sequential MSE parameters
"""

Expand Down Expand Up @@ -151,6 +151,50 @@ def __init__(self,
self.static_tensor_name_to_proto)
self.quantizers_to_be_disabled = self._get_quantizers_to_be_disabled() # check this

def _update_value_info_for_output(self, node):
"""
Updates the value info for output of a node in sim model.
Value info for QcQuantizeOp is not present in _sim_extractor
:param node: onnx node
"""

input_name = node.input[0]
output_name = node.output[0]
if input_name in self._sim_extractor.vimap and output_name not in self._sim_extractor.vimap:
value_info_for_output = copy.deepcopy(self._sim_extractor.vimap[input_name])
value_info_for_output.name = node.output[0]
self._sim_extractor.vimap[node.output[0]] = value_info_for_output

def _update_value_info_for_input(self, node):
"""
Updates the value info for input of a node in sim model.
Value info for QcQuantizeOp is not present in _sim_extractor
:param node: onnx node
"""

input_name = node.input[0]
output_name = node.output[0]
if output_name in self._sim_extractor.vimap and input_name not in self._sim_extractor.vimap:
value_info_for_input = copy.deepcopy(self._sim_extractor.vimap[output_name])
value_info_for_input.name = node.input[0]
self._sim_extractor.vimap[node.input[0]] = value_info_for_input

def _update_value_info_for_graph_output(self):
"""
Updates the value info for input of a node in sim model.
Value info for QcQuantizeOp is not present in _sim_extractor
:param node: onnx node
"""

for value_info in self.model.model.graph.output:
self._float_extractor.vimap[value_info.name] = value_info

for value_info in self.sim.model.model.graph.output:
self._sim_extractor.vimap[value_info.name] = value_info

def _update_value_info(self):
"""
Updates the value info for sim model.
Expand All @@ -159,11 +203,10 @@ def _update_value_info(self):

for node in self.sim.model.nodes():
if node.op_type == "QcQuantizeOp":
input_name = node.input[0]
if input_name in self._sim_extractor.vimap:
value_info_for_output = copy.deepcopy(self._sim_extractor.vimap[input_name])
value_info_for_output.name = node.output[0]
self._sim_extractor.vimap[node.output[0]] = value_info_for_output
self._update_value_info_for_output(node)
self._update_value_info_for_input(node)

self._update_value_info_for_graph_output()

def _fill_static_tensor_name_to_proto(self):
"""
Expand Down

0 comments on commit a583285

Please sign in to comment.