From 9d78149975ede38eb1edfbef616b87ef41c2267a Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Wed, 22 Nov 2023 13:55:16 +0800 Subject: [PATCH] Fix pytorch warning when apply logical_xor on different shape tensor. --- tat/_utility.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tat/_utility.py b/tat/_utility.py index ac99945f1..05491dcc6 100644 --- a/tat/_utility.py +++ b/tat/_utility.py @@ -21,9 +21,9 @@ def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor: def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor: if tensor_1.dtype is torch.bool: - return torch.logical_xor(tensor_1, tensor_2) + return tensor_1 ^ tensor_2 else: - return torch.add(tensor_1, tensor_2) + return tensor_1 + tensor_2 def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor: