Rename "encode" -> "encoder".
PiperOrigin-RevId: 631434141
achoum authored and copybara-github committed May 7, 2024
1 parent aff0e61 commit 349bfab
Showing 4 changed files with 42 additions and 45 deletions.
5 changes: 3 additions & 2 deletions yggdrasil_decision_forests/port/python/ydf/model/benchmark.py
@@ -318,8 +318,8 @@ def run_jax_engine(
     with jax.default_device(device):

       jax_model = model.to_jax_function()
-      if jax_model.encode is not None:
-        input_values = jax_model.encode(dataset_without_labels)
+      if jax_model.encoder is not None:
+        input_values = jax_model.encoder(dataset_without_labels)
       else:
         input_values = np_dict_to_jax_dict(dataset_without_labels)

@@ -513,6 +513,7 @@ def run_preconfigured(
   Args:
     profiler: If true, enables the profiler. See RunConfiguration.profiler.
     show_logs: Prints logs about the benchmark during its execution.
+    models: Set of model names to run. If None, runs all the available models.

   Returns:
36 changes: 15 additions & 21 deletions yggdrasil_decision_forests/port/python/ydf/model/export_jax.py
@@ -92,9 +92,8 @@ def to_compact_jax_array(values: Sequence[int]) -> jax.Array:
   return jnp.asarray(values, dtype=compact_dtype(values))


-# TODO: Rename to "FeatureEncoder".
 @dataclasses.dataclass
-class FeatureEncoding:
+class FeatureEncoder:
   """Utility to prepare feature values before being fed into the Jax model.

   Does the following:
@@ -115,8 +114,8 @@ def build(
       cls,
       input_features: Sequence[generic_model.InputFeature],
       dataspec: ds_pb.DataSpecification,
-  ) -> Optional["FeatureEncoding"]:
-    """Creates a FeatureEncoding object.
+  ) -> Optional["FeatureEncoder"]:
+    """Creates a FeatureEncoder object.

     If the input feature does not require feature encoding, returns None.
@@ -125,7 +124,7 @@ def build(
       dataspec: Dataspec of the model.

     Returns:
-      A FeatureEncoding or None.
+      A FeatureEncoder or None.
     """

     categorical = {}
@@ -145,7 +144,7 @@ def build(
       }
     if not categorical:
       return None
-    return FeatureEncoding(categorical=categorical)
+    return FeatureEncoder(categorical=categorical)

   def __call__(self, feature_values: Dict[str, Any]) -> Dict[str, jax.Array]:
     """Alias for "encode"."""
@@ -175,15 +174,15 @@ class JaxModel:
     predict: Jitted JAX function that computes the model predictions. The
       signature is `predict(feature_values)` if `params` is None, and
       `predict(feature_values, params)` if `params` is set.
-    encode: Optional object to encode features before the JAX model. Is None if
+    encoder: Optional object to encode features before the JAX model. Is None if
       the model does not need special feature encoding. For instance, used to
       encode categorical string values.
     params: Learnable parameters of the model. If set, "params" should be passed
       as an argument to the "predict" function.
   """

   predict: Union[Callable[[Any], Any], Callable[[Any, Dict[str, Any]], Any]]
-  encode: Optional[FeatureEncoding]  # TODO: Rename to "encoder".
+  encoder: Optional[FeatureEncoder]
   params: Optional[Dict[str, Any]]


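In context, a hedged sketch of how a caller consumes the dataclass after the rename (it mirrors the benchmark hunk above; `model` and the feature values are placeholders):

```python
import jax.numpy as jnp

jax_model = model.to_jax_function()  # `model` is a trained YDF model

features = {"f1": [0.5, 1.0], "c1": ["red", "blue"]}
if jax_model.encoder is not None:
  # Categorical strings -> integer ids; everything becomes a jax.Array.
  input_values = jax_model.encoder(features)
else:
  input_values = {k: jnp.asarray(v) for k, v in features.items()}

predictions = jax_model.predict(input_values)
```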
@@ -422,7 +421,7 @@ class InternalForest:
   Attributes:
     model: Input decision forest model.
     feature_spec: Internal feature indexing.
-    feature_encoding: How to encode features before feeding them to the model.
+    feature_encoder: How to encode features before feeding them to the model.
     dataspec: Dataspec.
     leaf_outputs: Prediction values for each leaf node.
     split_features: Internal idx of the feature being tested for each non-leaf
@@ -451,7 +450,7 @@ class InternalForest:

   model: dataclasses.InitVar[generic_model.GenericModel]
   feature_spec: InternalFeatureSpec = dataclasses.field(init=False)
-  feature_encoding: Optional[FeatureEncoding] = dataclasses.field(init=False)
+  feature_encoder: Optional[FeatureEncoder] = dataclasses.field(init=False)
   dataspec: ds_pb.DataSpecification = dataclasses.field(repr=False, init=False)
   leaf_outputs: ArrayFloat = dataclasses.field(
       default_factory=lambda: array.array("f", [])
@@ -509,7 +508,7 @@ def __post_init__(self, model: generic_model.GenericModel):

     input_features = model.input_features()
     self.dataspec = model.data_spec()
-    self.feature_encoding = FeatureEncoding.build(input_features, self.dataspec)
+    self.feature_encoder = FeatureEncoder.build(input_features, self.dataspec)
     self.feature_spec = InternalFeatureSpec(input_features, self.dataspec)

     if isinstance(
@@ -720,17 +719,12 @@ def to_jax_function(
   Args:
     model: A YDF model.
-    jit: If true, compiles the function with @jax.jit.
-    apply_activation: Should the activation function, if any, be applied on the
-      model output.
-    leaves_as_params: If true, exports the leaf values as learnable parameters.
-      In this case, `params` is set in the returned value, and it should be
-      passed to `predict(feature_values, params)`.
+    jit: See "to_jax_function" in generic_model.py.
+    apply_activation: See "to_jax_function" in generic_model.py.
+    leaves_as_params: See "to_jax_function" in generic_model.py.

   Returns:
-    A Jax function and optionally a FeatureEncoding object to encode
-    features. If the model does not need any special feature
-    encoding, the second returned value is None.
+    See "to_jax_function" in generic_model.py.
   """

   # TODO: Add support for Random Forest models.
@@ -794,7 +788,7 @@ def to_jax_function(
     predict = jax.jit(predict)
   return JaxModel(
       predict=predict,
-      encode=forest.feature_encoding,
+      encoder=forest.feature_encoder,
       params=params if params else None,
   )

@@ -843,15 +843,17 @@ def to_jax_function(  # pytype: disable=name-error
       })

     # Convert model to a JAX function.
-    jax_model, feature_encoding = model.to_jax_function()
+    jax_model = model.to_jax_function()

     # Make predictions with the TF module.
-    jax_predictions = jax_model({
+    jax_predictions = jax_model.predict({
         "f1": jnp.array([0, 0.5, 1]),
         "f2": jnp.array([1, 0, 0.5]),
     })
     ```

+    TODO: Document the encoder and jax params.

     Args:
       jit: If true, compiles the function with @jax.jit.
       apply_activation: Should the activation function, if any, be applied on
@@ -861,9 +863,9 @@ def to_jax_function(  # pytype: disable=name-error
         should be passed to `predict(feature_values, params)`.

     Returns:
-      A Jax function and optionnaly a FeatureEncoding object to encode
-      features. If the model does not need any special feature
-      encoding, the second returned value is None.
+      A dataclass containing the JAX prediction function (`predict`) and
+      optionally the model parameters (`params`) and feature encoder
+      (`encoder`).
     """

     return _get_export_jax().to_jax_function(
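The docstring above leaves a TODO to document `encoder` and `params`; here is a hedged sketch of the usage pattern implied by the `leaves_as_params` description (names and data are illustrative, and `model` stands for a trained YDF model):

```python
import jax.numpy as jnp

jax_model = model.to_jax_function(leaves_as_params=True)

# With leaves_as_params=True, the leaf values come back as learnable
# parameters and must be forwarded to predict().
features = {"f1": jnp.asarray([0.0, 0.5]), "f2": jnp.asarray([1.0, 0.5])}
if jax_model.encoder is not None:
  features = jax_model.encoder(features)
predictions = jax_model.predict(features, jax_model.params)
```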
34 changes: 17 additions & 17 deletions yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py
@@ -121,8 +121,8 @@ def test_compact_dtype_non_supported(self):
     with self.assertRaisesRegex(ValueError, "No supported compact dtype"):
       to_jax.compact_dtype((0x80000000,))

-  def test_feature_encoding_basic(self):
-    feature_encoding = to_jax.FeatureEncoding.build(
+  def test_feature_encoder_basic(self):
+    feature_encoder = to_jax.FeatureEncoder.build(
         [
             generic_model.InputFeature(
                 "f1", dataspec_lib.Semantic.NUMERICAL, 0
@@ -173,32 +173,32 @@ def test_feature_encoding_basic(self):
             ),
         ),
     )
-    self.assertIsNotNone(feature_encoding)
+    self.assertIsNotNone(feature_encoder)
     self.assertDictEqual(
-        feature_encoding.categorical, {"f2": {"<OOD>": 0, "A": 1, "B": 2}}
+        feature_encoder.categorical, {"f2": {"<OOD>": 0, "A": 1, "B": 2}}
     )

-  def test_feature_encoding_on_model(self):
+  def test_feature_encoder_on_model(self):
     columns = ["f1", "i1", "c1", "b1", "cs1", "label_class_binary"]
     model = specialized_learners.RandomForestLearner(
         label="label_class_binary",
         num_trees=2,
         features=[("cs1", dataspec_lib.Semantic.CATEGORICAL_SET)],
         include_all_columns=True,
     ).train(create_dataset(columns))
-    feature_encoding = to_jax.FeatureEncoding.build(
+    feature_encoder = to_jax.FeatureEncoder.build(
         model.input_features(), model.data_spec()
     )
-    self.assertIsNotNone(feature_encoding)
+    self.assertIsNotNone(feature_encoder)
     self.assertDictEqual(
-        feature_encoding.categorical,
+        feature_encoder.categorical,
         {
             "c1": {"<OOD>": 0, "x": 1, "y": 2, "z": 3},
             "cs1": {"<OOD>": 0, "a": 1, "b": 2, "c": 3},
         },
     )

-    encoded_features = feature_encoding.encode(
+    encoded_features = feature_encoder.encode(
         {"f1": [1, 2, 3], "c1": ["x", "y", "other"]}
     )
     np.testing.assert_array_equal(
@@ -208,15 +208,15 @@ def test_feature_encoding_on_model(self):
         encoded_features["c1"], jnp.asarray([1, 2, 0])
     )

-  def test_feature_encoding_is_none(self):
+  def test_feature_encoder_is_none(self):
     columns = ["f1", "i1", "label_class_binary"]
     model = specialized_learners.RandomForestLearner(
         label="label_class_binary", num_trees=2
     ).train(create_dataset(columns))
-    feature_encoding = to_jax.FeatureEncoding.build(
+    feature_encoder = to_jax.FeatureEncoder.build(
         model.input_features(), model.data_spec()
     )
-    self.assertIsNone(feature_encoding)
+    self.assertIsNone(feature_encoder)


class InternalFeatureSpecTest(parameterized.TestCase):
@@ -588,7 +588,7 @@ def test_internal_forest_on_manual(self):
     )

     self.assertEqual(internal_forest.num_trees(), 2)
-    self.assertIsNotNone(internal_forest.feature_encoding)
+    self.assertIsNotNone(internal_forest.feature_encoder)

     self.assertEqual(
         internal_forest.leaf_outputs,
@@ -658,7 +658,7 @@ def test_internal_forest_on_model(self):

     internal_forest = to_jax.InternalForest(model)
     self.assertEqual(internal_forest.num_trees(), 10)
-    self.assertIsNotNone(internal_forest.feature_encoding)
+    self.assertIsNotNone(internal_forest.feature_encoder)


class ToJaxTest(parameterized.TestCase):
@@ -747,12 +747,12 @@ def test_to_jax_function(

     # Convert model to tf function
     jax_model = to_jax.to_jax_function(model)
-    assert (jax_model.encode is not None) == has_encoding
+    assert (jax_model.encoder is not None) == has_encoding

     # Generate Jax predictions
     del test_ds[label]
-    if jax_model.encode is not None:
-      input_values = jax_model.encode(test_ds)
+    if jax_model.encoder is not None:
+      input_values = jax_model.encoder(test_ds)
     else:
       input_values = {
           k: jnp.asarray(v) for k, v in test_ds.items() if k != label
