Skip to content

Commit

Permalink
Overwrite __eq__ dunder
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 12, 2023
1 parent ab5023b commit e148da1
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def __getitem__(self, idx) -> "CasadiLike":
"""Overrides get item operator"""
return CasadiLike(self.array[idx])

def __eq__(self, other: Union["CasadiLike", npt.ArrayLike]) -> bool:
"""Overrides == operator"""
if type(self) is not type(other):
return self.array == other
return self.array == other.array

@property
def T(self) -> "CasadiLike":
"""
Expand Down
6 changes: 6 additions & 0 deletions src/adam/jax/jax_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def __neg__(self) -> "JaxLike":
"""Overrides - operator"""
return JaxLike(-self.array)

def __eq__(self, other: Union["JaxLike", npt.ArrayLike]) -> bool:
"""Overrides == operator"""
if type(self) is not type(other):
return self.array.squeeze() == other.squeeze()
return self.array.squeeze() == other.array.squeeze()


class JaxLikeFactory(ArrayLikeFactory):
@staticmethod
Expand Down
7 changes: 7 additions & 0 deletions src/adam/numpy/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ def __neg__(self):
"""Overrides - operator"""
return NumpyLike(-self.array)

def __eq__(self, other: Union["NumpyLike", npt.ArrayLike]) -> bool:
"""Overrides == operator"""
if type(self) is type(other):
return self.array == other.array
else:
return self.array == other


class NumpyLikeFactory(ArrayLikeFactory):
@staticmethod
Expand Down
7 changes: 7 additions & 0 deletions src/adam/pytorch/torch_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def __neg__(self) -> "TorchLike":
"""Overrides - operator"""
return TorchLike(-self.array)

def __eq__(self, other: Union["TorchLike", ntp.ArrayLike]) -> bool:
"""Overrides == operator"""
if type(self) is type(other):
return self.array == other.array
else:
return self.array == other


class TorchLikeFactory(ArrayLikeFactory):
@staticmethod
Expand Down

0 comments on commit e148da1

Please sign in to comment.