Skip to content

Commit

Permalink
Integer conversion utility functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625674319
  • Loading branch information
achoum authored and copybara-github committed Apr 17, 2024
1 parent e69c495 commit e38179c
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 0 deletions.
18 changes: 18 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pybind_library(
py_library(
name = "generic_model",
srcs = [
"export_jax.py",
"export_tf.py",
"generic_model.py",
],
Expand Down Expand Up @@ -276,3 +277,20 @@ py_test(
"@ydf_cc//yggdrasil_decision_forests/model:hyperparameter_py_proto",
],
)

py_test(
name = "jax_model_test",
srcs = ["jax_model_test.py"],
data = [
"//test_data",
"@ydf_cc//yggdrasil_decision_forests/test_data",
],
python_version = "PY3",
deps = [
":generic_model",
# absl/testing:absltest dep,
# absl/testing:parameterized dep,
# jax dep,
# numpy dep,
],
)
63 changes: 63 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/export_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to export JAX models."""

from typing import Any, Sequence

# pytype: disable=import-error
# pylint: disable=g-import-not-at-top
try:
import jax.numpy as jnp
import jax
except ImportError as exc:
raise ImportError(
"JAX is needed for this operation. Install JAX following"
" https://jax.readthedocs.io/en/latest/installation.html and try again."
) from exc
# pylint: enable=g-import-not-at-top
# pytype: enable=import-error


def compact_dtype(values: Sequence[int]) -> Any:
"""Selects the most compact dtype to represent a list of signed integers.
Only supports: int{8, 16, 32}.
Note: Jax operations between unsigned and signed integers can be expensive.
Args:
values: List of integer values.
Returns:
Dtype compatible with all the values.
"""

if not values:
raise ValueError("No values provided")

min_value = min(values)
max_value = max(values)

for candidate in [jnp.int8, jnp.int16, jnp.int32]:
info = jnp.iinfo(candidate)
if min_value >= info.min and max_value <= info.max:
return candidate
raise ValueError("No supported compact dtype")


def to_compact_jax_array(values: Sequence[int]) -> jax.Array:
"""Converts a list of integers to a compact Jax array."""

return jnp.asarray(values, dtype=compact_dtype(values))
50 changes: 50 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np
from ydf.model import export_jax as to_jax


class JaxModelTest(parameterized.TestCase):

@parameterized.parameters(
((0,), jnp.int8),
((0, 1, -1), jnp.int8),
((0, 1, 0x7F, -0x80), jnp.int8),
((0, 1, 0x7F + 1), jnp.int16),
((0, 1, -0x80 - 1), jnp.int16),
((0, 1, 0x7FFF), jnp.int16),
((0, 1, -0x8000), jnp.int16),
((0, 1, 0x7FFF + 1), jnp.int32),
((0, 1, -0x8000 - 1), jnp.int32),
((0, 1, 0x7FFFFFFF), jnp.int32),
((0, 1, -0x80000000), jnp.int32),
)
def test_compact_dtype(self, values, expected_dtype):
self.assertEqual(to_jax.compact_dtype(values), expected_dtype)

jax_array = to_jax.to_compact_jax_array(values)
self.assertEqual(jax_array.dtype.type, expected_dtype)
np.testing.assert_array_equal(jax_array, jnp.array(values, expected_dtype))

def test_compact_dtype_non_supported(self):
with self.assertRaisesRegex(ValueError, "No supported compact dtype"):
to_jax.compact_dtype((0x80000000,))


if __name__ == "__main__":
absltest.main()

0 comments on commit e38179c

Please sign in to comment.