-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integer conversion utility functions.
PiperOrigin-RevId: 625674319
- Loading branch information
1 parent
e69c495
commit e38179c
Showing
3 changed files
with
131 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
yggdrasil_decision_forests/port/python/ydf/model/export_jax.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright 2022 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Utilities to export JAX models.""" | ||
|
||
from typing import Any, Sequence | ||
|
||
# pytype: disable=import-error | ||
# pylint: disable=g-import-not-at-top | ||
try: | ||
import jax.numpy as jnp | ||
import jax | ||
except ImportError as exc: | ||
raise ImportError( | ||
"JAX is needed for this operation. Install JAX following" | ||
" https://jax.readthedocs.io/en/latest/installation.html and try again." | ||
) from exc | ||
# pylint: enable=g-import-not-at-top | ||
# pytype: enable=import-error | ||
|
||
|
||
def compact_dtype(values: Sequence[int]) -> Any: | ||
"""Selects the most compact dtype to represent a list of signed integers. | ||
Only supports: int{8, 16, 32}. | ||
Note: Jax operations between unsigned and signed integers can be expensive. | ||
Args: | ||
values: List of integer values. | ||
Returns: | ||
Dtype compatible with all the values. | ||
""" | ||
|
||
if not values: | ||
raise ValueError("No values provided") | ||
|
||
min_value = min(values) | ||
max_value = max(values) | ||
|
||
for candidate in [jnp.int8, jnp.int16, jnp.int32]: | ||
info = jnp.iinfo(candidate) | ||
if min_value >= info.min and max_value <= info.max: | ||
return candidate | ||
raise ValueError("No supported compact dtype") | ||
|
||
|
||
def to_compact_jax_array(values: Sequence[int]) -> jax.Array: | ||
"""Converts a list of integers to a compact Jax array.""" | ||
|
||
return jnp.asarray(values, dtype=compact_dtype(values)) |
50 changes: 50 additions & 0 deletions
50
yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright 2022 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from ydf.model import export_jax as to_jax | ||
|
||
|
||
class JaxModelTest(parameterized.TestCase): | ||
|
||
@parameterized.parameters( | ||
((0,), jnp.int8), | ||
((0, 1, -1), jnp.int8), | ||
((0, 1, 0x7F, -0x80), jnp.int8), | ||
((0, 1, 0x7F + 1), jnp.int16), | ||
((0, 1, -0x80 - 1), jnp.int16), | ||
((0, 1, 0x7FFF), jnp.int16), | ||
((0, 1, -0x8000), jnp.int16), | ||
((0, 1, 0x7FFF + 1), jnp.int32), | ||
((0, 1, -0x8000 - 1), jnp.int32), | ||
((0, 1, 0x7FFFFFFF), jnp.int32), | ||
((0, 1, -0x80000000), jnp.int32), | ||
) | ||
def test_compact_dtype(self, values, expected_dtype): | ||
self.assertEqual(to_jax.compact_dtype(values), expected_dtype) | ||
|
||
jax_array = to_jax.to_compact_jax_array(values) | ||
self.assertEqual(jax_array.dtype.type, expected_dtype) | ||
np.testing.assert_array_equal(jax_array, jnp.array(values, expected_dtype)) | ||
|
||
def test_compact_dtype_non_supported(self): | ||
with self.assertRaisesRegex(ValueError, "No supported compact dtype"): | ||
to_jax.compact_dtype((0x80000000,)) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |