Skip to content

Commit

Permalink
Adapt to tf-nightly (#1634)
Browse files Browse the repository at this point in the history
* fixed tests in keras_layers_test

* fixed image classifier test

* fix BERT tokenizer

* adapt preprocessing layers in multi-branch architectures

* depend on tf-nightly

* fix test coverage

Co-authored-by: Haifeng Jin <[email protected]>
  • Loading branch information
haifeng-jin and haifeng-jin authored Oct 18, 2021
1 parent a50dbc1 commit a5ba53a
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 24 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.6
- name: Set up Python 3.7
uses: actions/setup-python@v1
with:
python-version: 3.6
python-version: 3.7
- name: Get pip cache dir
id: pip-cache
run: |
Expand Down Expand Up @@ -50,10 +50,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.6
- name: Set up Python 3.7
uses: actions/setup-python@v1
with:
python-version: 3.6
python-version: 3.7
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
Expand All @@ -73,7 +73,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.6
python-version: 3.7
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel twine
Expand Down
31 changes: 19 additions & 12 deletions autokeras/engine/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import os

Expand Down Expand Up @@ -112,26 +113,32 @@ def adapt(model, dataset):
# TODO: Use Keras Tuner for preprocessing layers adapt.
x = dataset.map(lambda x, y: x)

def get_output_layer(tensor):
def get_output_layers(tensor):
output_layers = []
tensor = nest.flatten(tensor)[0]
for layer in model.layers:
if isinstance(layer, tf.keras.layers.InputLayer):
continue
input_node = nest.flatten(layer.input)[0]
if input_node is tensor:
if not isinstance(layer, preprocessing.PreprocessingLayer):
break
return layer
return None
if isinstance(layer, preprocessing.PreprocessingLayer):
output_layers.append(layer)
return output_layers

dq = collections.deque()

for index, input_node in enumerate(nest.flatten(model.input)):
temp_x = x.map(lambda *args: nest.flatten(args)[index])
layer = get_output_layer(input_node)
while layer is not None:
if isinstance(layer, preprocessing.PreprocessingLayer):
layer.adapt(temp_x)
temp_x = temp_x.map(layer)
layer = get_output_layer(layer.output)
in_x = x.map(lambda *args: nest.flatten(args)[index])
for layer in get_output_layers(input_node):
dq.append((layer, in_x))

while len(dq):
layer, in_x = dq.popleft()
layer.adapt(in_x)
out_x = in_x.map(layer)
for next_layer in get_output_layers(layer.output):
dq.append((next_layer, out_x))

return model

def search(
Expand Down
16 changes: 13 additions & 3 deletions autokeras/keras_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def get_config(self):
def call(self, inputs):
return data_utils.cast_to_float32(inputs)

def adapt(self, data):
return


@tf.keras.utils.register_keras_serializable()
class ExpandLastDim(preprocessing.PreprocessingLayer):
Expand All @@ -49,6 +52,9 @@ def get_config(self):
def call(self, inputs):
return tf.expand_dims(inputs, axis=-1)

def adapt(self, data):
return


@tf.keras.utils.register_keras_serializable()
class MultiCategoryEncoding(preprocessing.PreprocessingLayer):
Expand All @@ -75,9 +81,7 @@ def __init__(self, encoding: List[str], **kwargs):
# Set a temporary vocabulary to prevent the error of no
# vocabulary when calling the layer to build the model. The
# vocabulary would be reset by adapting the layer later.
self.encoding_layers.append(
preprocessing.StringLookup(vocabulary=["NONE"])
)
self.encoding_layers.append(preprocessing.StringLookup())
elif encoding == ONE_HOT:
self.encoding_layers.append(None)

Expand Down Expand Up @@ -190,6 +194,9 @@ def bert_encode(self, input_tensor):

return input_word_ids

def adapt(self, data):
return # pragma: no cover


# TODO: Remove after KerasNLP is ready.
@tf.keras.utils.register_keras_serializable()
Expand Down Expand Up @@ -685,6 +692,9 @@ def call(self, inputs):

return mask # pragma: no cover

def get_config(self):
return super().get_config()


@tf.keras.utils.register_keras_serializable()
class Transformer(tf.keras.layers.Layer):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ filterwarnings =
ignore::PendingDeprecationWarning
ignore::FutureWarning
ignore::numpy.VisibleDeprecationWarning
ignore::tensorflow.python.keras.utils.generic_utils.CustomMaskWarning
ignore::keras.utils.generic_utils.CustomMaskWarning

addopts=-v
--durations=10
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
install_requires=[
"packaging",
"keras-tuner>=1.0.2",
"tensorflow<=2.5.0,>=2.3.0",
"tf-nightly==2.8.0.dev20211016",
"scikit-learn",
"pandas",
],
Expand All @@ -41,7 +41,6 @@
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering :: Mathematics",
Expand Down
4 changes: 3 additions & 1 deletion tests/unit_tests/keras_layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ def test_init_multi_one_hot_encode():


def test_call_multi_with_single_column_return_right_shape():
x_train = np.array([["a"], ["b"], ["a"]])
layer = layer_module.MultiCategoryEncoding(encoding=[layer_module.INT])
layer.adapt(tf.data.Dataset.from_tensor_slices(x_train).batch(32))

assert layer(np.array([["a"], ["b"], ["a"]])).shape == (3, 1)
assert layer(x_train).shape == (3, 1)


def get_text_data():
Expand Down

0 comments on commit a5ba53a

Please sign in to comment.