diff --git a/src/spox/_fields.py b/src/spox/_fields.py index 0b3c96d..ab6f114 100644 --- a/src/spox/_fields.py +++ b/src/spox/_fields.py @@ -8,10 +8,11 @@ import warnings from collections.abc import Iterable, Iterator, Sequence from dataclasses import Field, dataclass -from typing import Optional, Union, cast, get_type_hints +from typing import Optional, Union, get_type_hints from ._attributes import Attr from ._exceptions import InferenceWarning +from ._type_system import Optional as tOptional from ._value_prop import PropDict, PropValue from ._var import Var, _VarInfo @@ -152,50 +153,17 @@ def fully_typed(self) -> bool: for var in self.get_var_infos().values() ) - -@dataclass -class BaseInputs(BaseVarInfos): - def vars(self, prop_values: PropDict) -> BaseVars: - vars_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {} - - for field in dataclasses.fields(self): - field_type = self._get_field_type(field) - field_value = getattr(self, field.name) - - if field_type == VarFieldKind.SINGLE: - prop_value = cast(PropValue, prop_values.get(field.name, None)) - vars_dict[field.name] = Var(field_value, prop_value) - - elif field_type == VarFieldKind.OPTIONAL: - prop_value = cast(PropValue, prop_values.get(field.name, None)) - vars_dict[field.name] = Var(field_value, prop_value) - - elif field_type == VarFieldKind.VARIADIC: - vars = [] - - for i, var_info in enumerate(field_value): - var_value = prop_values.get(f"{field.name}_{i}", None) - assert isinstance(var_value, PropValue) - vars.append(Var(var_info, var_value)) - - vars_dict[field.name] = vars - - return BaseVars(vars_dict) - - -@dataclass -class BaseOutputs(BaseVarInfos): - def _propagate_vars(self, prop_values: Optional[PropDict] = None) -> BaseVars: - if prop_values is None: - prop_values = {} - + def into_vars(self, prop_values: PropDict) -> BaseVars: def _create_var(key: str, var_info: _VarInfo) -> Var: ret = Var(var_info, None) if var_info.type is None or key not in prop_values: return ret - prop = PropValue(var_info.type, prop_values.get(key)) + if not isinstance(var_info.type, tOptional) and prop_values[key] is None: + return ret + + prop = PropValue(var_info.type, prop_values[key]) if prop.check(): ret._value = prop else: @@ -219,3 +187,13 @@ def _create_var(key: str, var_info: _VarInfo) -> Var: ] return BaseVars(ret_dict) + + +@dataclass +class BaseInputs(BaseVarInfos): + pass + + +@dataclass +class BaseOutputs(BaseVarInfos): + pass diff --git a/src/spox/_function.py b/src/spox/_function.py index d92a496..96a3127 100644 --- a/src/spox/_function.py +++ b/src/spox/_function.py @@ -86,10 +86,10 @@ def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]: self.func_inputs = self.Inputs(**self.func_args) self.func_outputs = self.constructor( - self.func_attrs, self.func_inputs.vars(input_prop_values) + self.func_attrs, self.func_inputs.into_vars(input_prop_values) ) self.func_graph = _graph.results( - **self.func_outputs._propagate_vars(input_prop_values).flatten_vars() + **self.func_outputs.into_vars(input_prop_values).flatten_vars() ).with_arguments(*func_args_var.values()) return { diff --git a/src/spox/_node.py b/src/spox/_node.py index 35df2f6..8e2a022 100644 --- a/src/spox/_node.py +++ b/src/spox/_node.py @@ -267,7 +267,7 @@ def get_output_vars( self.validate_types() out_values = self.propagate_values(input_prop_values) - return self.outputs._propagate_vars(out_values) + return self.outputs.into_vars(out_values) def validate_types(self) -> None: """Validation of types, ran at the end of Node creation."""