Skip to content

Commit

Permalink
Unify logic around VarInfos -> Var
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 9, 2024
1 parent ac8807e commit e5c311f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 42 deletions.
56 changes: 17 additions & 39 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import warnings
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import Field, dataclass
from typing import Optional, Union, cast, get_type_hints
from typing import Optional, Union, get_type_hints

from ._attributes import Attr
from ._exceptions import InferenceWarning
from ._type_system import Optional as tOptional
from ._value_prop import PropDict, PropValue
from ._var import Var, _VarInfo

Expand Down Expand Up @@ -152,50 +153,17 @@ def fully_typed(self) -> bool:
for var in self.get_var_infos().values()
)


@dataclass
class BaseInputs(BaseVarInfos):
def vars(self, prop_values: PropDict) -> BaseVars:
vars_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {}

for field in dataclasses.fields(self):
field_type = self._get_field_type(field)
field_value = getattr(self, field.name)

if field_type == VarFieldKind.SINGLE:
prop_value = cast(PropValue, prop_values.get(field.name, None))
vars_dict[field.name] = Var(field_value, prop_value)

elif field_type == VarFieldKind.OPTIONAL:
prop_value = cast(PropValue, prop_values.get(field.name, None))
vars_dict[field.name] = Var(field_value, prop_value)

elif field_type == VarFieldKind.VARIADIC:
vars = []

for i, var_info in enumerate(field_value):
var_value = prop_values.get(f"{field.name}_{i}", None)
assert isinstance(var_value, PropValue)
vars.append(Var(var_info, var_value))

vars_dict[field.name] = vars

return BaseVars(vars_dict)


@dataclass
class BaseOutputs(BaseVarInfos):
def _propagate_vars(self, prop_values: Optional[PropDict] = None) -> BaseVars:
if prop_values is None:
prop_values = {}

def into_vars(self, prop_values: PropDict) -> BaseVars:
def _create_var(key: str, var_info: _VarInfo) -> Var:
ret = Var(var_info, None)

if var_info.type is None or key not in prop_values:
return ret

prop = PropValue(var_info.type, prop_values.get(key))
if not isinstance(var_info.type, tOptional) and prop_values[key] is None:
return ret

prop = PropValue(var_info.type, prop_values[key])
if prop.check():
ret._value = prop
else:
Expand All @@ -219,3 +187,13 @@ def _create_var(key: str, var_info: _VarInfo) -> Var:
]

return BaseVars(ret_dict)


@dataclass
class BaseInputs(BaseVarInfos):
pass


@dataclass
class BaseOutputs(BaseVarInfos):
pass
4 changes: 2 additions & 2 deletions src/spox/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:

self.func_inputs = self.Inputs(**self.func_args)
self.func_outputs = self.constructor(
self.func_attrs, self.func_inputs.vars(input_prop_values)
self.func_attrs, self.func_inputs.into_vars(input_prop_values)
)
self.func_graph = _graph.results(
**self.func_outputs._propagate_vars(input_prop_values).flatten_vars()
**self.func_outputs.into_vars(input_prop_values).flatten_vars()
).with_arguments(*func_args_var.values())

return {
Expand Down
2 changes: 1 addition & 1 deletion src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def get_output_vars(
self.validate_types()

out_values = self.propagate_values(input_prop_values)
return self.outputs._propagate_vars(out_values)
return self.outputs.into_vars(out_values)

def validate_types(self) -> None:
"""Validation of types, ran at the end of Node creation."""
Expand Down

0 comments on commit e5c311f

Please sign in to comment.