From e148da1df5bdd302a4bed98add0a8b2d63403320 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 5 Dec 2023 11:32:15 +0100 Subject: [PATCH] Overwrite `__eq__` dunder --- src/adam/casadi/casadi_like.py | 6 ++++++ src/adam/jax/jax_like.py | 6 ++++++ src/adam/numpy/numpy_like.py | 7 +++++++ src/adam/pytorch/torch_like.py | 7 +++++++ 4 files changed, 26 insertions(+) diff --git a/src/adam/casadi/casadi_like.py b/src/adam/casadi/casadi_like.py index 9ee0ce0..8949e38 100644 --- a/src/adam/casadi/casadi_like.py +++ b/src/adam/casadi/casadi_like.py @@ -93,6 +93,12 @@ def __getitem__(self, idx) -> "CasadiLike": """Overrides get item operator""" return CasadiLike(self.array[idx]) + def __eq__(self, other: Union["CasadiLike", npt.ArrayLike]) -> bool: + """Overrides == operator""" + if type(self) is not type(other): + return self.array == other + return self.array == other.array + @property def T(self) -> "CasadiLike": """ diff --git a/src/adam/jax/jax_like.py b/src/adam/jax/jax_like.py index 5f898b7..06aff4b 100644 --- a/src/adam/jax/jax_like.py +++ b/src/adam/jax/jax_like.py @@ -107,6 +107,12 @@ def __neg__(self) -> "JaxLike": """Overrides - operator""" return JaxLike(-self.array) + def __eq__(self, other: Union["JaxLike", npt.ArrayLike]) -> bool: + """Overrides == operator""" + if type(self) is not type(other): + return self.array.squeeze() == other.squeeze() + return self.array.squeeze() == other.array.squeeze() + class JaxLikeFactory(ArrayLikeFactory): @staticmethod diff --git a/src/adam/numpy/numpy_like.py b/src/adam/numpy/numpy_like.py index 107624f..c0f26a3 100644 --- a/src/adam/numpy/numpy_like.py +++ b/src/adam/numpy/numpy_like.py @@ -106,6 +106,13 @@ def __neg__(self): """Overrides - operator""" return NumpyLike(-self.array) + def __eq__(self, other: Union["NumpyLike", npt.ArrayLike]) -> bool: + """Overrides == operator""" + if type(self) is type(other): + return self.array == other.array + else: + return self.array == other + class NumpyLikeFactory(ArrayLikeFactory): @staticmethod diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index 3241d1c..4753d45 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -112,6 +112,13 @@ def __neg__(self) -> "TorchLike": """Overrides - operator""" return TorchLike(-self.array) + def __eq__(self, other: Union["TorchLike", ntp.ArrayLike]) -> bool: + """Overrides == operator""" + if type(self) is type(other): + return self.array == other.array + else: + return self.array == other + class TorchLikeFactory(ArrayLikeFactory): @staticmethod