Skip to content

Commit

Permalink
Enable the experimental_debug_stripper in TF2 when saving the `transform_fn`
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 608950695
  • Loading branch information
tf-transform-team authored and tfx-copybara committed Feb 21, 2024
1 parent fb7688c commit f1d1605
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 57 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Explicitly use Keras 2 or `tf_keras` if Keras 3 is installed.
* Added python 3.11 support.
* Depends on `tensorflow>=2.15.0,<3`.
* Enable passing `tf.saved_model.SaveOptions` to model saving functionality.

## Breaking Changes

Expand Down
2 changes: 2 additions & 0 deletions tensorflow_transform/beam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import uuid

import apache_beam as beam
import tensorflow as tf
from tensorflow_transform import common_types
from tensorflow_transform import nodes
from tfx_bsl.telemetry import util
Expand Down Expand Up @@ -165,6 +166,7 @@ class ExtraArgs:
cache_pcoll_dict: Optional[Dict[str, beam.PCollection]]
preprocessing_fn: Any
analyzers_fingerprint: Mapping[str, Any]
save_options: tf.saved_model.SaveOptions

def __init__(self, extra_args):
self._extra_args = extra_args
Expand Down
19 changes: 17 additions & 2 deletions tensorflow_transform/beam/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class Context:
force_tf_compat_v1: (Optional) If True, TFT's public APIs
(e.g. AnalyzeDataset) will use Tensorflow in compat.v1 mode irrespective
of installed version of Tensorflow. Defaults to `False`.
save_options: (Optional) If set, the tf.saved_model.SaveOptions to save
the transform_fn with. Only applies for TF2.
Note that the temp dir should be accessible to worker jobs, e.g. if running
with the Cloud Dataflow runner, the temp dir should be on GCS and should have
Expand All @@ -56,6 +58,7 @@ class _State:
passthrough_keys: Optional[Iterable[str]] = None
use_deep_copy_optimization: Optional[bool] = None
force_tf_compat_v1: Optional[bool] = None
save_options: Optional[tf.saved_model.SaveOptions] = None

@classmethod
def make_empty(cls):
Expand All @@ -80,7 +83,8 @@ def __init__(self,
desired_batch_size: Optional[int] = None,
passthrough_keys: Optional[Iterable[str]] = None,
use_deep_copy_optimization: Optional[bool] = None,
force_tf_compat_v1: Optional[bool] = None):
force_tf_compat_v1: Optional[bool] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None):
state = getattr(self._thread_local, 'state', None)
if not state:
self._thread_local.state = self._StateStack()
Expand All @@ -92,6 +96,7 @@ def __init__(self,
self._passthrough_keys = passthrough_keys
self._use_deep_copy_optimization = use_deep_copy_optimization
self._force_tf_compat_v1 = force_tf_compat_v1
self._save_options = save_options

def __enter__(self):
# Previous State's properties are inherited if not explicitly specified.
Expand All @@ -110,7 +115,8 @@ def __enter__(self):
last_frame.use_deep_copy_optimization,
force_tf_compat_v1=self._force_tf_compat_v1
if self._force_tf_compat_v1 is not None else
last_frame.force_tf_compat_v1))
last_frame.force_tf_compat_v1,
save_options=self._save_options or last_frame.save_options))

def __exit__(self, *exn_info):
self._thread_local.state.frames.pop()
Expand Down Expand Up @@ -175,3 +181,12 @@ def get_use_tf_compat_v1(cls) -> bool:
"""Computes use_tf_compat_v1 from TF environment and force_tf_compat_v1."""
force_tf_compat_v1 = cls._get_force_tf_compat_v1()
return tf2_utils.use_tf_compat_v1(force_tf_compat_v1)

@classmethod
def get_save_options(cls) -> Optional[tf.saved_model.SaveOptions]:
  """Returns the `save_options` of the topmost state frame, if any.

  When a value is present it is logged at INFO level before being returned.

  Returns:
    The `save_options` stored on the topmost state frame, or None if it is
    not set.
  """
  save_options = cls._get_topmost_state_frame().save_options
  # Guard clause: nothing to log or return when no options are set.
  if save_options is None:
    return None
  tf.compat.v1.logging.info('Using save_options: %s', save_options)
  return save_options
13 changes: 9 additions & 4 deletions tensorflow_transform/beam/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def expand(self, inputs):
def _create_v2_saved_model(tensor_replacement_map, base_temp_dir,
preprocessing_fn, input_signature,
baseline_analyzers_fingerprint,
output_keys_to_name_map):
output_keys_to_name_map, save_options):
"""Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.
The SavedModel written contains a method called `transform_fn` that
Expand All @@ -669,6 +669,7 @@ def _create_v2_saved_model(tensor_replacement_map, base_temp_dir,
paths that define its fingerprint.
output_keys_to_name_map: A map from output dictionary keys to the names of
the tensors that they represent.
save_options: The tf.saved_model.SaveOptions to save the model with.
Returns:
Path to which SavedModel was written.
Expand All @@ -678,7 +679,8 @@ def _create_v2_saved_model(tensor_replacement_map, base_temp_dir,
input_signature, base_temp_dir,
baseline_analyzers_fingerprint,
tensor_replacement_map,
output_keys_to_name_map)
output_keys_to_name_map,
save_options)
return saved_model_dir


Expand All @@ -695,6 +697,7 @@ def __init__(self, operation, extra_args):
self._input_signature = extra_args.input_specs
self._output_signature = operation.output_signature
self._analyzers_fingerprint = extra_args.analyzers_fingerprint
self._save_options = extra_args.save_options

def _maybe_get_output_tensor_names_dict(self):
# output_signature will contain CompositeTensors only if this is the final
Expand All @@ -719,7 +722,7 @@ def expand(self, inputs):
| 'CreateSavedModel' >> beam.Map(
_create_v2_saved_model, self._base_temp_dir, self._preprocessing_fn,
self._input_signature, self._analyzers_fingerprint,
self._maybe_get_output_tensor_names_dict())
self._maybe_get_output_tensor_names_dict(), self._save_options)
| 'Count' >>
beam_common.IncrementCounter(_CREATE_SAVED_MODEL_COUNTER_NAME))

Expand Down Expand Up @@ -988,6 +991,7 @@ def __init__(self, preprocessing_fn, pipeline=None):
"""
self._preprocessing_fn = preprocessing_fn
self.pipeline = pipeline
self._save_options = Context.get_save_options()
self._use_tf_compat_v1 = Context.get_use_tf_compat_v1()
if self._use_tf_compat_v1:
_warn_about_tf_compat_v1()
Expand Down Expand Up @@ -1155,7 +1159,8 @@ def expand(self, dataset):
use_tf_compat_v1=self._use_tf_compat_v1,
cache_pcoll_dict=dataset_cache_dict,
preprocessing_fn=self._preprocessing_fn,
analyzers_fingerprint=analyzers_fingerprint)
analyzers_fingerprint=analyzers_fingerprint,
save_options=self._save_options)

(transform_fn_future, cache_value_nodes,
detached_sideeffect_leafs) = analysis_graph_builder.build(
Expand Down
115 changes: 76 additions & 39 deletions tensorflow_transform/impl_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,103 +434,138 @@ def trace_preprocessing_function(preprocessing_fn,
0. the graph representing the traced `preprocessing_fn`
1. the graph's structured inputs
2. the graph's structured outputs
"""
if use_tf_compat_v1:
return _trace_preprocessing_fn_v1(preprocessing_fn, input_specs)
else:
return _trace_preprocessing_fn_v2(preprocessing_fn, input_specs,
base_temp_dir)
return _trace_preprocessing_fn_v2(
preprocessing_fn, input_specs, base_temp_dir
)


def _trace_and_write_transform_fn(
saved_model_dir: str,
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
Mapping[str, common_types.TensorType]],
input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
preprocessing_fn: Callable[
[Mapping[str, common_types.TensorType]],
Mapping[str, common_types.TensorType],
],
input_signature: Mapping[str, tf.TypeSpec],
base_temp_dir: Optional[str],
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
output_keys_to_name_map: Optional[Dict[str,
str]]) -> function.ConcreteFunction:
output_keys_to_name_map: Optional[Dict[str, str]],
save_options: Optional[tf.saved_model.SaveOptions],
) -> function.ConcreteFunction:
"""Trace `preprocessing_fn` and serialize to a SavedModel."""
tf_graph_context = graph_context.TFGraphContext(
module_to_export=tf.Module(),
temp_dir=base_temp_dir,
evaluated_replacements=tensor_replacement_map)
evaluated_replacements=tensor_replacement_map,
)
transform_fn = get_traced_transform_fn(
preprocessing_fn,
input_signature,
tf_graph_context,
output_keys_to_name_map=output_keys_to_name_map)
output_keys_to_name_map=output_keys_to_name_map,
)
return saved_transform_io_v2.write_v2_saved_model(
tf_graph_context.module_to_export, transform_fn, 'transform_fn',
saved_model_dir)
tf_graph_context.module_to_export,
transform_fn,
'transform_fn',
saved_model_dir,
save_options,
)


def _trace_and_get_metadata(
concrete_transform_fn: function.ConcreteFunction,
structured_inputs: Mapping[str, common_types.TensorType],
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
Mapping[str, common_types.TensorType]],
preprocessing_fn: Callable[
[Mapping[str, common_types.TensorType]],
Mapping[str, common_types.TensorType],
],
base_temp_dir: Optional[str],
tensor_replacement_map: Optional[Dict[str, tf.Tensor]]
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
) -> dataset_metadata.DatasetMetadata:
"""Compute and return metadata for the outputs of `concrete_transform_fn`."""
tf_graph_context = graph_context.TFGraphContext(
module_to_export=tf.Module(),
temp_dir=base_temp_dir,
evaluated_replacements=tensor_replacement_map)
evaluated_replacements=tensor_replacement_map,
)
concrete_metadata_fn = schema_inference.get_traced_metadata_fn(
preprocessing_fn,
structured_inputs,
tf_graph_context,
evaluate_schema_overrides=True)
evaluate_schema_overrides=True,
)
return dataset_metadata.DatasetMetadata(
schema=schema_inference.infer_feature_schema_v2(
concrete_transform_fn.structured_outputs,
concrete_metadata_fn,
evaluate_schema_overrides=True))
evaluate_schema_overrides=True,
)
)


def _validate_analyzers_fingerprint(
baseline_analyzers_fingerprint: Mapping[str,
graph_tools.AnalyzersFingerprint],
graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]):
baseline_analyzers_fingerprint: Mapping[
str, graph_tools.AnalyzersFingerprint
],
graph: tf.Graph,
structured_inputs: Mapping[str, common_types.TensorType],
):
"""Validates analyzers fingerprint in `graph` is same as baseline."""
analyzers_fingerprint = graph_tools.get_analyzers_fingerprint(
graph, structured_inputs)
graph, structured_inputs
)
error_msg = (
'The order of analyzers in your `preprocessing_fn` appears to be '
'non-deterministic. This can be fixed either by changing your '
'`preprocessing_fn` such that tf.Transform analyzers are encountered '
'in a deterministic order or by passing a unique name to each '
'analyzer API call.')
'analyzer API call.'
)
for analyzer in analyzers_fingerprint:
if analyzer not in baseline_analyzers_fingerprint:
prefix_msg = (f'Analyzer node ({analyzer}) not found in '
f'{baseline_analyzers_fingerprint.keys()}. ')
prefix_msg = (
f'Analyzer node ({analyzer}) not found in '
f'{baseline_analyzers_fingerprint.keys()}. '
)
raise RuntimeError(prefix_msg + error_msg)
if (baseline_analyzers_fingerprint[analyzer].source_keys !=
analyzers_fingerprint[analyzer].source_keys):
if (
baseline_analyzers_fingerprint[analyzer].source_keys
!= analyzers_fingerprint[analyzer].source_keys
):
raise RuntimeError(error_msg)

if (baseline_analyzers_fingerprint[analyzer].unique_path_hash !=
analyzers_fingerprint[analyzer].unique_path_hash):
if (
baseline_analyzers_fingerprint[analyzer].unique_path_hash
!= analyzers_fingerprint[analyzer].unique_path_hash
):
logging.warning(
'Analyzer (%s) node\'s cache key varies on repeated tracing.'
"Analyzer (%s) node's cache key varies on repeated tracing."
' This warning is safe to ignore if you either specify `name` for all'
' analyzers or if the order in which they are invoked is'
' deterministic. If not, please file a bug with details.', analyzer)
' deterministic. If not, please file a bug with details.',
analyzer,
)


def trace_and_write_v2_saved_model(
saved_model_dir: str,
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
Mapping[str, common_types.TensorType]],
input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
baseline_analyzers_fingerprint: Mapping[str,
graph_tools.AnalyzersFingerprint],
preprocessing_fn: Callable[
[Mapping[str, common_types.TensorType]],
Mapping[str, common_types.TensorType],
],
input_signature: Mapping[str, tf.TypeSpec],
base_temp_dir: Optional[str],
baseline_analyzers_fingerprint: Mapping[
str, graph_tools.AnalyzersFingerprint
],
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
output_keys_to_name_map: Optional[Dict[str, str]]):
output_keys_to_name_map: Optional[Dict[str, str]],
save_options: Optional[tf.saved_model.SaveOptions],
):
"""Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.
The SavedModel written contains a method called `transform_fn` that
Expand All @@ -549,6 +584,7 @@ def trace_and_write_v2_saved_model(
evaluated replacement tensors.
output_keys_to_name_map: A map from output dictionary keys to the names of
the tensors that they represent.
save_options: The options to use when saving the saved_model.
Returns:
A tuple containing a pair of `tf.ConcreteFunction`s:
Expand All @@ -562,7 +598,7 @@ def trace_and_write_v2_saved_model(
"""
concrete_transform_fn = _trace_and_write_transform_fn(
saved_model_dir, preprocessing_fn, input_signature, base_temp_dir,
tensor_replacement_map, output_keys_to_name_map)
tensor_replacement_map, output_keys_to_name_map, save_options)
structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
concrete_transform_fn.graph)
_validate_analyzers_fingerprint(baseline_analyzers_fingerprint,
Expand Down Expand Up @@ -632,7 +668,8 @@ def analyze_in_place(preprocessing_fn, force_tf_compat_v1, feature_specs,
input_signature=type_specs,
base_temp_dir=None,
tensor_replacement_map=None,
output_keys_to_name_map=None)
output_keys_to_name_map=None,
save_options=None)
_assert_no_analyzers_in_graph(concrete_transform_fn.graph)
structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
concrete_transform_fn.graph)
Expand Down
12 changes: 9 additions & 3 deletions tensorflow_transform/saved/saved_transform_io_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Utility functions to save and load from SavedModels in TF 2.x."""

from typing import Any, Dict, Iterable, Mapping, Tuple, Union
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union

import tensorflow as tf
from tensorflow_transform import annotators
Expand Down Expand Up @@ -534,9 +534,15 @@ def write_v2_saved_model(
tf_function: tf.types.experimental.GenericFunction,
name: str,
saved_model_dir: str,
save_options: Optional[tf.saved_model.SaveOptions] = None,
) -> function.ConcreteFunction:
"""Writes `tf_function` under attr `name` of `module` to `saved_model_dir`."""
concrete_fn = trace_and_update_module(
module, tf_function, name, strip_control_dependencies=False)
tf.saved_model.save(module, saved_model_dir)
module, tf_function, name, strip_control_dependencies=False
)
tf.saved_model.save(
module,
saved_model_dir,
options=save_options,
)
return concrete_fn
Loading

0 comments on commit f1d1605

Please sign in to comment.