From c82f647352c5143e5dde49ac1dcfdb96a4e4d867 Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Fri, 19 Apr 2024 02:52:00 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 626305807 --- .../port/python/ydf/model/BUILD | 4 + .../port/python/ydf/model/export_jax.py | 77 ++++++++++- .../port/python/ydf/model/jax_model_test.py | 129 ++++++++++++++++++ 3 files changed, 209 insertions(+), 1 deletion(-) diff --git a/yggdrasil_decision_forests/port/python/ydf/model/BUILD b/yggdrasil_decision_forests/port/python/ydf/model/BUILD index 29d678cc..d89692f5 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/BUILD +++ b/yggdrasil_decision_forests/port/python/ydf/model/BUILD @@ -292,5 +292,9 @@ py_test( # absl/testing:parameterized dep, # jax dep, # numpy dep, + "@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_py_proto", + "//ydf/dataset:dataspec", + "//ydf/learner:generic_learner", + "//ydf/learner:specialized_learners", ], ) diff --git a/yggdrasil_decision_forests/port/python/ydf/model/export_jax.py b/yggdrasil_decision_forests/port/python/ydf/model/export_jax.py index ad165315..f5c71334 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/export_jax.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/export_jax.py @@ -14,7 +14,12 @@ """Utilities to export JAX models.""" -from typing import Any, Sequence +import dataclasses +from typing import Any, Sequence, Dict, Optional + +from yggdrasil_decision_forests.dataset import data_spec_pb2 as ds_pb +from ydf.dataset import dataspec as dataspec_lib +from ydf.model import generic_model # pytype: disable=import-error # pylint: disable=g-import-not-at-top @@ -61,3 +66,73 @@ def to_compact_jax_array(values: Sequence[int]) -> jax.Array: """Converts a list of integers to a compact Jax array.""" return jnp.asarray(values, dtype=compact_dtype(values)) + + +@dataclasses.dataclass +class FeatureEncoding: + """Utility to prepare feature values before being fed into the Jax model. + + Does the following: + - Encodes categorical strings into categorical integers. + + Attributes: + categorical: Mapping between categorical-string feature to the dictionary of + categorical-string value to categorical-integer value. + categorical_out_of_vocab_item: Integer value representing an out of + vocabulary item. + """ + + categorical: Dict[str, Dict[str, int]] + categorical_out_of_vocab_item: int = 0 + + @classmethod + def build( + cls, + input_features: Sequence[generic_model.InputFeature], + dataspec: ds_pb.DataSpecification, + ) -> Optional["FeatureEncoding"]: + """Creates a FeatureEncoding object. + + If the input feature does not require feature encoding, returns None. + + Args: + input_features: All the input features of a model. + dataspec: Dataspec of the model. + + Returns: + A FeatureEncoding or None. + """ + + categorical = {} + for input_feature in input_features: + column_spec = dataspec.columns[input_feature.column_idx] + if ( + input_feature.semantic + in [ + dataspec_lib.Semantic.CATEGORICAL, + dataspec_lib.Semantic.CATEGORICAL_SET, + ] + and not column_spec.categorical.is_already_integerized + ): + categorical[input_feature.name] = { + key: item.index + for key, item in column_spec.categorical.items.items() + } + if not categorical: + return None + return FeatureEncoding(categorical=categorical) + + def encode(self, feature_values: Dict[str, Any]) -> Dict[str, jax.Array]: + """Encodes feature values for a model.""" + + def encode_item(key: str, value: Any) -> jax.Array: + categorical_map = self.categorical.get(key) + if categorical_map is not None: + # Categorical string encoding. + value = [ + categorical_map.get(x, self.categorical_out_of_vocab_item) + for x in value + ] + return jax.numpy.asarray(value) + + return {k: encode_item(k, v) for k, v in feature_values.items()} diff --git a/yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py b/yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py index 1572e638..2af534c3 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py @@ -12,15 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List from absl.testing import absltest from absl.testing import parameterized +import jax import jax.numpy as jnp import numpy as np +from yggdrasil_decision_forests.dataset import data_spec_pb2 as ds_pb +from ydf.dataset import dataspec as dataspec_lib +from ydf.learner import specialized_learners from ydf.model import export_jax as to_jax +from ydf.model import generic_model class JaxModelTest(parameterized.TestCase): + def create_dataset(self, columns: List[str]) -> Dict[str, Any]: + """Creates a dataset with random values.""" + data = { + # Single-dim features + "f1": np.random.random(size=100), + "f2": np.random.random(size=100), + "i1": np.random.randint(100, size=100), + "i2": np.random.randint(100, size=100), + "c1": np.random.choice(["x", "y", "z"], size=100, p=[0.6, 0.3, 0.1]), + "b1": np.random.randint(2, size=100).astype(np.bool_), + "b2": np.random.randint(2, size=100).astype(np.bool_), + # Cat-set features + "cs1": [[], ["a", "b", "c"], ["b", "c"], ["a"]] * 25, + # Multi-dim features + "multi_f1": np.random.random(size=(100, 5)), + "multi_f2": np.random.random(size=(100, 5)), + "multi_i1": np.random.randint(100, size=(100, 5)), + "multi_c1": np.random.choice(["x", "y", "z"], size=(100, 5)), + "multi_b1": np.random.randint(2, size=(100, 5)).astype(np.bool_), + # Labels + "label_class_binary": np.random.choice(["l1", "l2"], size=100), + "label_class_multi": np.random.choice(["l1", "l2", "l3"], size=100), + "label_regress": np.random.random(size=100), + } + return {k: data[k] for k in columns} + @parameterized.parameters( ((0,), jnp.int8), ((0, 1, -1), jnp.int8), @@ -45,6 +77,103 @@ def test_compact_dtype_non_supported(self): with self.assertRaisesRegex(ValueError, "No supported compact dtype"): to_jax.compact_dtype((0x80000000,)) + def test_feature_encoding_basic(self): + feature_encoding = to_jax.FeatureEncoding.build( + [ + generic_model.InputFeature( + "f1", dataspec_lib.Semantic.NUMERICAL, 0 + ), + generic_model.InputFeature( + "f2", dataspec_lib.Semantic.CATEGORICAL, 1 + ), + generic_model.InputFeature( + "f3", dataspec_lib.Semantic.CATEGORICAL, 2 + ), + ], + ds_pb.DataSpecification( + created_num_rows=3, + columns=( + ds_pb.Column( + name="f1", + type=ds_pb.ColumnType.NUMERICAL, + ), + ds_pb.Column( + name="f2", + type=ds_pb.ColumnType.CATEGORICAL, + categorical=ds_pb.CategoricalSpec( + items={ + "": ds_pb.CategoricalSpec.VocabValue(index=0), + "A": ds_pb.CategoricalSpec.VocabValue(index=1), + "B": ds_pb.CategoricalSpec.VocabValue(index=2), + }, + ), + ), + ds_pb.Column( + name="f3", + type=ds_pb.ColumnType.CATEGORICAL, + categorical=ds_pb.CategoricalSpec( + is_already_integerized=True, + ), + ), + ds_pb.Column( + name="f4", + type=ds_pb.ColumnType.CATEGORICAL, + categorical=ds_pb.CategoricalSpec( + items={ + "": ds_pb.CategoricalSpec.VocabValue(index=0), + "X": ds_pb.CategoricalSpec.VocabValue(index=1), + "Y": ds_pb.CategoricalSpec.VocabValue(index=2), + }, + ), + ), + ), + ), + ) + self.assertIsNotNone(feature_encoding) + self.assertDictEqual( + feature_encoding.categorical, {"f2": {"": 0, "A": 1, "B": 2}} + ) + + def test_feature_encoding_on_model(self): + columns = ["f1", "i1", "c1", "b1", "cs1", "label_class_binary"] + model = specialized_learners.RandomForestLearner( + label="label_class_binary", + num_trees=2, + features=[("cs1", dataspec_lib.Semantic.CATEGORICAL_SET)], + include_all_columns=True, + ).train(self.create_dataset(columns)) + feature_encoding = to_jax.FeatureEncoding.build( + model.input_features(), model.data_spec() + ) + self.assertIsNotNone(feature_encoding) + self.assertDictEqual( + feature_encoding.categorical, + { + "c1": {"": 0, "x": 1, "y": 2, "z": 3}, + "cs1": {"": 0, "a": 1, "b": 2, "c": 3}, + }, + ) + + encoded_features = feature_encoding.encode( + {"f1": [1, 2, 3], "c1": ["x", "y", "other"]} + ) + np.testing.assert_array_equal( + encoded_features["f1"], jax.numpy.asarray([1, 2, 3]) + ) + np.testing.assert_array_equal( + encoded_features["c1"], jax.numpy.asarray([1, 2, 0]) + ) + + def test_feature_encoding_is_none(self): + columns = ["f1", "i1", "label_class_binary"] + model = specialized_learners.RandomForestLearner( + label="label_class_binary", num_trees=2 + ).train(self.create_dataset(columns)) + feature_encoding = to_jax.FeatureEncoding.build( + model.input_features(), model.data_spec() + ) + self.assertIsNone(feature_encoding) + if __name__ == "__main__": absltest.main()