Add support for multi-dim features in YDF to JAX
PiperOrigin-RevId: 631432510
achoum authored and copybara-github committed May 7, 2024
1 parent ce7dcd1 commit aff0e61
Showing 2 changed files with 251 additions and 50 deletions.
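For context, a minimal usage sketch of what this change enables: passing a multi-dimensional numerical feature to the JAX export as a single [batch, dims] array keyed by its original name. The dataset, the feature name "f1", and the jax_model.predict access pattern below are illustrative assumptions, not taken from this commit.

import numpy as np
import jax.numpy as jnp
import ydf  # yggdrasil-decision-forests Python API

# Hypothetical training data: one 3-dimensional numerical feature "f1"
# (a single [examples, 3] array) and a binary label.
train_ds = {
    "f1": np.random.uniform(size=(100, 3)).astype(np.float32),
    "label": np.random.randint(0, 2, size=100),
}
model = ydf.GradientBoostedTreesLearner(label="label").train(train_ds)

# Export to JAX. With multi-dim support, the exported function accepts the
# feature under its original name as a single [batch, 3] array.
jax_model = model.to_jax_function()
predictions = jax_model.predict({"f1": jnp.asarray(train_ds["f1"][:5])})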
122 changes: 93 additions & 29 deletions yggdrasil_decision_forests/port/python/ydf/model/export_jax.py
@@ -200,53 +200,101 @@ class InternalFeatureValues:
boolean: jax.Array


@dataclasses.dataclass(frozen=True)
class InternalFeatureItem:
"""A single feature in InternalFeatureSpec.
Attributes:
name: Name of the feature.
dim: Number of dimensions of the feature.
"""

name: str
dim: int


@dataclasses.dataclass
class InternalFeatureSpec:
"""Spec of the internal feature value representation.
Attributes:
input_features: Input features of the model.
numerical: Name of numerical features in internal order.
categorical: Name of categorical features in internal order.
boolean: Name of boolean features in internal order.
dataspec: Dataspec of the model.
numerical: Name and size of numerical features in internal order.
categorical: Name and size of categorical features in internal order.
boolean: Name and size of boolean features in internal order.
inv_numerical: Column idx to internal idx mapping for numerical features.
inv_categorical: Column idx to internal idx mapping for categorical features.
inv_boolean: Column idx to internal idx mapping for boolean features.
feature_names: Name of all the input features.
"""

input_features: dataclasses.InitVar[Sequence[generic_model.InputFeature]]
dataspec: dataclasses.InitVar[ds_pb.DataSpecification]

numerical: List[str] = dataclasses.field(default_factory=list)
categorical: List[str] = dataclasses.field(default_factory=list)
boolean: List[str] = dataclasses.field(default_factory=list)
numerical: List[InternalFeatureItem] = dataclasses.field(default_factory=list)
categorical: List[InternalFeatureItem] = dataclasses.field(
default_factory=list
)
boolean: List[InternalFeatureItem] = dataclasses.field(default_factory=list)

inv_numerical: Dict[int, int] = dataclasses.field(default_factory=dict)
inv_categorical: Dict[int, int] = dataclasses.field(default_factory=dict)
inv_boolean: Dict[int, int] = dataclasses.field(default_factory=dict)

feature_names: Set[str] = dataclasses.field(default_factory=set)

def __post_init__(self, input_features: Sequence[generic_model.InputFeature]):
for input_feature in input_features:
self.feature_names.add(input_feature.name)
if input_feature.semantic == dataspec_lib.Semantic.NUMERICAL:
self.inv_numerical[input_feature.column_idx] = len(self.numerical)
self.numerical.append(input_feature.name)

elif input_feature.semantic == dataspec_lib.Semantic.CATEGORICAL:
self.inv_categorical[input_feature.column_idx] = len(self.categorical)
self.categorical.append(input_feature.name)

elif input_feature.semantic == dataspec_lib.Semantic.BOOLEAN:
self.inv_boolean[input_feature.column_idx] = len(self.boolean)
self.boolean.append(input_feature.name)
def __post_init__(
self,
input_features: Sequence[generic_model.InputFeature],
dataspec: ds_pb.DataSpecification,
):

def add_feature(
name: str, begin_column_idx: int, size: int, semantic: ds_pb.ColumnType
):
"""Adds a new feature."""
self.feature_names.add(name)
if semantic == ds_pb.ColumnType.NUMERICAL:
target_inv = self.inv_numerical
target_feature = self.numerical
elif semantic == ds_pb.ColumnType.CATEGORICAL:
target_inv = self.inv_categorical
target_feature = self.categorical
elif semantic == ds_pb.ColumnType.BOOLEAN:
target_inv = self.inv_boolean
target_feature = self.boolean
else:
raise ValueError(
f"The semantic of feature {input_feature} is not supported by the"
" YDF to Jax exporter"
f"The semantic of feature {name} is not supported by the YDF to Jax"
" exporter"
)
for dim_idx in range(size):
target_inv[begin_column_idx + dim_idx] = len(target_inv)
target_feature.append(InternalFeatureItem(name, size))

# Multi-dim features
for unstacked in dataspec.unstackeds:
if unstacked.size == 0:
raise RuntimeError("Empty unstacked")
add_feature(
unstacked.original_name,
unstacked.begin_column_idx,
unstacked.size,
dataspec.columns[unstacked.begin_column_idx].type,
)

# Single dim features
for input_feature in input_features:
if dataspec.columns[input_feature.column_idx].is_unstacked:
# Already processed
continue
add_feature(
input_feature.name,
input_feature.column_idx,
1,
input_feature.semantic.to_proto_type(),
)
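To illustrate the bookkeeping above: an unstacked feature spans `size` consecutive columns starting at `begin_column_idx`, each column gets its own internal index, and the feature itself is recorded once with its full dimension so convert_features can look it up by name. A self-contained sketch with hypothetical feature names and column indices:

# Hypothetical: "f1" is a 3-dim numerical feature unstacked into columns 0-2,
# "f2" is a scalar numerical feature in column 3.
inv_numerical = {}   # column idx -> internal idx
numerical = []       # (name, dim), one entry per feature
for name, begin_column_idx, size in [("f1", 0, 3), ("f2", 3, 1)]:
    for dim_idx in range(size):
        inv_numerical[begin_column_idx + dim_idx] = len(inv_numerical)
    numerical.append((name, size))

assert inv_numerical == {0: 0, 1: 1, 2: 2, 3: 3}
assert numerical == [("f1", 3), ("f2", 1)]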

def convert_features(
self, feature_values: Dict[str, jax.Array]
@@ -267,15 +315,31 @@ def convert_features(

if set(feature_values) != self.feature_names:
raise ValueError(
f"Expecting values with keys {set(self.feature_names)!r}. Got"
f" {set(feature_values.keys())!r}"
"Expecting dictionary of values with keys"
f" {set(self.feature_names)!r}. Got {set(feature_values.keys())!r}"
)

def stack(features, dtype):
def normalize_feature(feature_value, feature: InternalFeatureItem):
if len(feature_value.shape) == 1:
feature_value = jnp.expand_dims(feature_value, axis=1)
elif len(feature_value.shape) != 2:
raise ValueError("Featire value must be 1- or 2-dimensional")
if feature_value.shape[1] != feature.dim:
raise ValueError(
f"Expecting dimension {feature.dim} for feature {feature.name!r}."
f" Got {feature_value.shape[1]!r}"
)
return feature_value

def stack(features: List[InternalFeatureItem], dtype):
if not features:
return jnp.zeros(shape=[batch_size, 0], dtype=dtype)
return jnp.stack(
[feature_values[feature] for feature in features],

return jnp.concatenate(
[
normalize_feature(feature_values[feature.name], feature)
for feature in features
],
dtype=dtype,
axis=1,
)
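A standalone sketch of the packing that convert_features now performs: scalar feature values are expanded to [batch, 1] and then concatenated with multi-dim values along axis 1. Array names and shapes here are illustrative.

import jax.numpy as jnp

batch_size = 4
f1 = jnp.ones([batch_size, 3])   # multi-dim feature with 3 dimensions
f2 = jnp.zeros([batch_size])     # scalar feature provided as a 1-dim array

# Expand the scalar feature to [batch, 1], then concatenate everything into a
# single [batch, 4] matrix, mirroring jnp.concatenate(..., dtype=..., axis=1).
packed = jnp.concatenate(
    [f1, jnp.expand_dims(f2, axis=1)], dtype=jnp.float32, axis=1
)
assert packed.shape == (batch_size, 4)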
@@ -388,7 +452,7 @@ class InternalForest:
model: dataclasses.InitVar[generic_model.GenericModel]
feature_spec: InternalFeatureSpec = dataclasses.field(init=False)
feature_encoding: Optional[FeatureEncoding] = dataclasses.field(init=False)
dataspec: Any = dataclasses.field(repr=False, init=False)
dataspec: ds_pb.DataSpecification = dataclasses.field(repr=False, init=False)
leaf_outputs: ArrayFloat = dataclasses.field(
default_factory=lambda: array.array("f", [])
)
Expand Down Expand Up @@ -446,7 +510,7 @@ def __post_init__(self, model: generic_model.GenericModel):
input_features = model.input_features()
self.dataspec = model.data_spec()
self.feature_encoding = FeatureEncoding.build(input_features, self.dataspec)
self.feature_spec = InternalFeatureSpec(input_features)
self.feature_spec = InternalFeatureSpec(input_features, self.dataspec)

if isinstance(
model, gradient_boosted_trees_model.GradientBoostedTreesModel
