Skip to content

Commit

Permalink
Add operator_overloading
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 2, 2024
1 parent 664691c commit 6a0f6ae
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ ignore_missing_imports = true
no_implicit_optional = true
check_untyped_defs = true
warn_unused_ignores = true
disallow_untyped_defs = true

[tool.pytest.ini_options]
# This will be pytest's future default.
Expand Down
10 changes: 5 additions & 5 deletions src/spox/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,23 @@ def _promote_target(obj: Union[Var, np.generic, int, float]) -> Optional[Var]:

return tuple(var for var in map(_promote_target, args))

def add(self, a, b) -> Var:
def add(self, a, b) -> Var: # type: ignore
a, b = self._promote(a, b)
return self.op.add(a, b)

def sub(self, a, b) -> Var:
def sub(self, a, b) -> Var: # type: ignore
a, b = self._promote(a, b)
return self.op.sub(a, b)

def mul(self, a, b) -> Var:
def mul(self, a, b) -> Var: # type: ignore
a, b = self._promote(a, b)
return self.op.mul(a, b)

def truediv(self, a, b) -> Var:
def truediv(self, a, b) -> Var: # type: ignore
a, b = self._promote(a, b, to_floating=True)
return self.op.div(a, b)

def floordiv(self, a, b) -> Var:
def floordiv(self, a, b) -> Var: # type: ignore
a, b = self._promote(a, b)
c = self.op.div(a, b)
if isinstance(c.type, Tensor) and not issubclass(c.type._elem_type, np.integer):
Expand Down
34 changes: 17 additions & 17 deletions src/spox/_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,61 +140,61 @@ def __copy__(self) -> "Var":
# during the build process
return self

def __deepcopy__(self, _) -> "Var":
def __deepcopy__(self, _: Any) -> "Var":
raise ValueError("'Var' objects cannot be deepcopied.")

def __add__(self, other) -> "Var":
def __add__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.add(self, other)

def __sub__(self, other) -> "Var":
def __sub__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.sub(self, other)

def __mul__(self, other) -> "Var":
def __mul__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.mul(self, other)

def __truediv__(self, other) -> "Var":
def __truediv__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.truediv(self, other)

def __floordiv__(self, other) -> "Var":
def __floordiv__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.floordiv(self, other)

def __neg__(self) -> "Var":
return Var._operator_dispatcher.neg(self)

def __and__(self, other) -> "Var":
def __and__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.and_(self, other)

def __or__(self, other) -> "Var":
def __or__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.or_(self, other)

def __xor__(self, other) -> "Var":
def __xor__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.xor(self, other)

def __invert__(self) -> "Var":
return Var._operator_dispatcher.not_(self)

def __radd__(self, other) -> "Var":
def __radd__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.add(other, self)

def __rsub__(self, other) -> "Var":
def __rsub__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.sub(other, self)

def __rmul__(self, other) -> "Var":
def __rmul__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.mul(other, self)

def __rtruediv__(self, other) -> "Var":
def __rtruediv__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.truediv(other, self)

def __rfloordiv__(self, other) -> "Var":
def __rfloordiv__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.floordiv(other, self)

def __rand__(self, other) -> "Var":
def __rand__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.and_(other, self)

def __ror__(self, other) -> "Var":
def __ror__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.or_(other, self)

def __rxor__(self, other) -> "Var":
def __rxor__(self, other) -> "Var": # type: ignore
return Var._operator_dispatcher.xor(other, self)


Expand Down

0 comments on commit 6a0f6ae

Please sign in to comment.