Skip to content

Commit

Permalink
[Fix] replace torch.norm with torch.linalg.norm during onnx exportati…
Browse files Browse the repository at this point in the history
…on (#2847)
  • Loading branch information
Ben-Louis authored Dec 20, 2023
1 parent 4e8fef8 commit 5da5ba6
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion mmpose/models/utils/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,14 @@ def forward(self, x):
torch.Tensor: The tensor after applying scale norm.
"""

norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
if torch.onnx.is_in_onnx_export() and \
digit_version(TORCH_VERSION) >= digit_version('1.12'):

norm = torch.linalg.norm(x, dim=-1, keepdim=True)

else:
norm = torch.norm(x, dim=-1, keepdim=True)
norm = norm * self.scale
return x / norm.clamp(min=self.eps) * self.g


Expand Down

0 comments on commit 5da5ba6

Please sign in to comment.