Skip to content

Commit

Permalink
Appease the linter
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Oct 18, 2023
1 parent 73c3224 commit 79beebb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
10 changes: 4 additions & 6 deletions src/spox/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,13 @@ def __repr__(self):

def __post_init__(self):
if any(not isinstance(var, Var) for var in self._results.values()):
raise TypeError(
f"Graph results must be Vars, not {set(type(obj) for obj in self._results.values()) - {Var}}."
)
seen_types = {type(obj) for obj in self._results.values()}
raise TypeError(f"Graph results must be Vars, not {seen_types - {Var}}.")
if self._arguments is not None and any(
not isinstance(var, Var) for var in self._arguments
):
raise TypeError(
f"Graph results must be Vars, not {set(type(obj) for obj in self._arguments) - {Var}}."
)
seen_types = {type(obj) for obj in self._arguments}
raise TypeError(f"Build outputs must be Vars, not {seen_types - {Var} }.")

def with_name(self, name: str) -> "Graph":
"""Return a Graph with its name set to ``name``."""
Expand Down
10 changes: 4 additions & 6 deletions src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,11 @@ def build(inputs: Dict[str, Var], outputs: Dict[str, Var]) -> onnx.ModelProto:
>>> model = build({'a': a, 'b': b, 'c': c}, {'r': q})
"""
if not all(isinstance(var, Var) for var in inputs.values()):
raise TypeError(
f"Build inputs must be Vars, not {set(type(obj) for obj in inputs.values()) - {Var} }."
)
seen_types = {type(obj) for obj in inputs.values()}
raise TypeError(f"Build inputs must be Vars, not {seen_types - {Var} }.")
if not all(isinstance(var, Var) for var in outputs.values()):
raise TypeError(
f"Build outputs must be Vars, not {set(type(obj) for obj in outputs.values()) - {Var} }."
)
seen_types = {type(obj) for obj in outputs.values()}
raise TypeError(f"Build outputs must be Vars, not {seen_types - {Var} }.")
if not all(isinstance(var._op, Argument) for var in inputs.values()):
raise TypeError(
"Build inputs must be `Var`s constructed using the `spox.argument` function. "
Expand Down

0 comments on commit 79beebb

Please sign in to comment.