diff --git a/autokeras/__init__.py b/autokeras/__init__.py
index 40c32bed8..cd299e4f6 100644
--- a/autokeras/__init__.py
+++ b/autokeras/__init__.py
@@ -1,5 +1,4 @@
 from autokeras.auto_model import AutoModel
-from autokeras.const import Constant
 from autokeras.hypermodel.base import Block
 from autokeras.hypermodel.base import Head
 from autokeras.hypermodel.base import HyperBlock
diff --git a/autokeras/auto_model.py b/autokeras/auto_model.py
index f6dff2830..16ec31c3e 100644
--- a/autokeras/auto_model.py
+++ b/autokeras/auto_model.py
@@ -1,3 +1,4 @@
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.util import nest
 
@@ -80,6 +81,9 @@ def __init__(self,
         self.inputs = nest.flatten(inputs)
         self.outputs = nest.flatten(outputs)
         self.seed = seed
+        if seed:
+            np.random.seed(seed)
+            tf.random.set_seed(seed)
         # TODO: Support passing a tuner instance.
         if isinstance(tuner, str):
             tuner = tuner_module.get_tuner_class(tuner)
diff --git a/autokeras/const.py b/autokeras/const.py
index 7a696a4b2..ff324888e 100644
--- a/autokeras/const.py
+++ b/autokeras/const.py
@@ -1,4 +1,35 @@
-
-class Constant(object):
-    # Text
-    VOCABULARY_SIZE = 20000
+INITIAL_HPS = {
+    'image_classifier': [{
+        'image_block_1/block_type': 'vanilla',
+        'image_block_1/normalize': True,
+        'image_block_1/augment': False,
+        'image_block_1_vanilla/kernel_size': 3,
+        'image_block_1_vanilla/num_blocks': 1,
+        'image_block_1_vanilla/separable': False,
+        'image_block_1_vanilla/dropout_rate': 0.25,
+        'image_block_1_vanilla/filters_0_1': 32,
+        'image_block_1_vanilla/filters_0_2': 64,
+        'spatial_reduction_1/reduction_type': 'flatten',
+        'dense_block_1/num_layers': 1,
+        'dense_block_1/use_batchnorm': False,
+        'dense_block_1/dropout_rate': 0,
+        'dense_block_1/units_0': 128,
+        'classification_head_1/dropout_rate': 0.5,
+        'optimizer': 'adam'
+    }, {
+        'image_block_1/block_type': 'resnet',
+        'image_block_1/normalize': True,
+        'image_block_1/augment': True,
+        'image_block_1_resnet/version': 'v2',
+        'image_block_1_resnet/pooling': 'avg',
+        'image_block_1_resnet/conv3_depth': 4,
+        'image_block_1_resnet/conv4_depth': 6,
+        'dense_block_1/num_layers': 2,
+        'dense_block_1/use_batchnorm': False,
+        'dense_block_1/dropout_rate': 0,
+        'dense_block_1/units_0': 32,
+        'dense_block_1/units_1': 32,
+        'classification_head_1/dropout_rate': 0,
+        'optimizer': 'adam'
+    }],
+}
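
These two dictionaries are replayed verbatim as the first two trials of the greedy search (see GreedyOracle._populate_space in the new autokeras/oracle.py below). A minimal sketch, not part of the patch, of what one entry looks like as trial values; it assumes only KerasTuner's HyperParameters.values dict, which is the exact shape the oracle returns:

    import kerastuner

    from autokeras import const

    # Replay the first seed configuration onto a fresh HyperParameters object.
    hp = kerastuner.HyperParameters()
    hp.values = dict(const.INITIAL_HPS['image_classifier'][0])
    assert hp.values['image_block_1/block_type'] == 'vanilla'
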
diff --git a/autokeras/hypermodel/block.py b/autokeras/hypermodel/block.py
index 5ffc9e497..69f01a92e 100644
--- a/autokeras/hypermodel/block.py
+++ b/autokeras/hypermodel/block.py
@@ -52,10 +52,11 @@ def build(self, hp, inputs=None):
         num_layers = self.num_layers or hp.Choice('num_layers', [1, 2, 3], default=2)
         use_batchnorm = self.use_batchnorm
         if use_batchnorm is None:
-            use_batchnorm = hp.Choice('use_batchnorm', [True, False], default=False)
-        dropout_rate = self.dropout_rate or hp.Choice('dropout_rate',
-                                                      [0.0, 0.25, 0.5],
-                                                      default=0)
+            use_batchnorm = hp.Boolean('use_batchnorm', default=False)
+        if self.dropout_rate is not None:
+            dropout_rate = self.dropout_rate
+        else:
+            dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0)
 
         for i in range(num_layers):
             units = hp.Choice(
@@ -66,7 +67,8 @@ def build(self, hp, inputs=None):
             if use_batchnorm:
                 output_node = tf.keras.layers.BatchNormalization()(output_node)
             output_node = tf.keras.layers.ReLU()(output_node)
-            output_node = tf.keras.layers.Dropout(dropout_rate)(output_node)
+            if dropout_rate > 0:
+                output_node = tf.keras.layers.Dropout(dropout_rate)(output_node)
         return output_node
 
 
@@ -121,7 +123,7 @@ def build(self, hp, inputs=None):
         bidirectional = self.bidirectional
         if bidirectional is None:
-            bidirectional = hp.Choice('bidirectional', [True, False], default=True)
+            bidirectional = hp.Boolean('bidirectional', default=True)
         layer_type = self.layer_type or hp.Choice('layer_type',
                                                   ['gru', 'lstm'],
                                                   default='lstm')
 
@@ -157,24 +159,30 @@ class ConvBlock(base.Block):
             tuned automatically.
         separable: Boolean. Whether to use separable conv layers.
             If left unspecified, it will be tuned automatically.
+        dropout_rate: Float. Between 0 and 1. The dropout rate after the
+            convolutional layers. If left unspecified, it will be tuned
+            automatically.
     """
 
     def __init__(self,
                  kernel_size=None,
                  num_blocks=None,
                  separable=None,
+                 dropout_rate=None,
                  **kwargs):
         super().__init__(**kwargs)
         self.kernel_size = kernel_size
         self.num_blocks = num_blocks
         self.separable = separable
+        self.dropout_rate = dropout_rate
 
     def get_config(self):
         config = super().get_config()
         config.update({
             'kernel_size': self.kernel_size,
             'num_blocks': self.num_blocks,
-            'separable': self.separable})
+            'separable': self.separable,
+            'dropout_rate': self.dropout_rate})
         return config
 
     def build(self, hp, inputs=None):
@@ -191,7 +199,7 @@ def build(self, hp, inputs=None):
                                 default=2)
         separable = self.separable
         if separable is None:
-            separable = hp.Choice('separable', [True, False], default=False)
+            separable = hp.Boolean('separable', default=False)
 
         if separable:
             conv = utils.get_sep_conv(input_node.shape)
@@ -199,6 +207,11 @@ def build(self, hp, inputs=None):
             conv = utils.get_conv(input_node.shape)
         pool = utils.get_max_pooling(input_node.shape)
 
+        if self.dropout_rate is not None:
+            dropout_rate = self.dropout_rate
+        else:
+            dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0)
+
         for i in range(num_blocks):
             output_node = conv(
                 hp.Choice('filters_{i}_1'.format(i=i),
@@ -217,6 +230,8 @@ def build(self, hp, inputs=None):
             output_node = pool(
                 kernel_size - 1,
                 padding=self._get_padding(kernel_size - 1, output_node))(output_node)
+            if dropout_rate > 0:
+                output_node = tf.keras.layers.Dropout(dropout_rate)(output_node)
         return output_node
 
     @staticmethod
@@ -546,9 +561,10 @@ def build(self, hp, inputs=None):
             input_length=input_node.shape[1],
             trainable=True)
         output_node = layer(input_node)
-        dropout_rate = self.dropout_rate or hp.Choice('dropout_rate',
-                                                      [0.0, 0.25, 0.5],
-                                                      default=0.25)
+        if self.dropout_rate is not None:
+            dropout_rate = self.dropout_rate
+        else:
+            dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0.25)
         if dropout_rate > 0:
             output_node = tf.keras.layers.Dropout(dropout_rate)(output_node)
         return output_node
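
The switch from `self.dropout_rate or hp.Choice(...)` to an explicit `is not None` check means a user-supplied rate of 0.0 is now respected instead of being treated as unspecified (0.0 is falsy), and the new `if dropout_rate > 0` guards skip the Dropout layer entirely in that case. A usage sketch in the style of the deleted examples/cifar10.py, assuming ConvBlock and DenseBlock are exported at the package root:

    import autokeras as ak

    input_node = ak.ImageInput()
    output_node = ak.ConvBlock(dropout_rate=0.25)(input_node)   # pinned, not tuned
    output_node = ak.SpatialReduction()(output_node)
    output_node = ak.DenseBlock(dropout_rate=0.0)(output_node)  # dropout disabled
    output_node = ak.ClassificationHead()(output_node)
    clf = ak.AutoModel(input_node, output_node, max_trials=3)
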
diff --git a/autokeras/hypermodel/head.py b/autokeras/hypermodel/head.py
index 973839a04..3dab34c69 100644
--- a/autokeras/hypermodel/head.py
+++ b/autokeras/hypermodel/head.py
@@ -123,13 +123,17 @@ def build(self, hp, inputs=None):
         input_node = inputs[0]
         output_node = input_node
 
+        # Reduce the tensor to a vector.
         if len(output_node.shape) > 2:
-            dropout_rate = self.dropout_rate or hp.Choice('dropout_rate',
-                                                          [0.0, 0.25, 0.5],
-                                                          default=0)
-            if dropout_rate > 0:
-                output_node = tf.keras.layers.Dropout(dropout_rate)(output_node)
             output_node = block_module.SpatialReduction().build(hp, output_node)
+
+        if self.dropout_rate is not None:
+            dropout_rate = self.dropout_rate
+        else:
+            dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0)
+
+        if dropout_rate > 0:
+            output_node = tf.keras.layers.Dropout(dropout_rate)(output_node)
         output_node = tf.keras.layers.Dense(self.output_shape[-1])(output_node)
         if self.loss == 'binary_crossentropy':
             output_node = Sigmoid(name=self.name)(output_node)
diff --git a/autokeras/hypermodel/hyperblock.py b/autokeras/hypermodel/hyperblock.py
index 4ae86d478..eabc746c6 100644
--- a/autokeras/hypermodel/hyperblock.py
+++ b/autokeras/hypermodel/hyperblock.py
@@ -25,20 +25,17 @@ def __init__(self,
                  block_type=None,
                  normalize=None,
                  augment=None,
-                 seed=None,
                  **kwargs):
         super().__init__(**kwargs)
         self.block_type = block_type
         self.normalize = normalize
         self.augment = augment
-        self.seed = seed
 
     def get_config(self):
         config = super().get_config()
         config.update({'block_type': self.block_type,
                        'normalize': self.normalize,
-                       'augment': self.augment,
-                       'seed': self.seed})
+                       'augment': self.augment})
         return config
 
     def build(self, hp, inputs=None):
@@ -51,10 +48,10 @@ def build(self, hp, inputs=None):
 
         normalize = self.normalize
         if normalize is None:
-            normalize = hp.Choice('normalize', [True, False], default=True)
+            normalize = hp.Boolean('normalize', default=True)
         augment = self.augment
         if augment is None:
-            augment = hp.Choice('augment', [True, False], default=False)
+            augment = hp.Boolean('augment', default=False)
         if normalize:
             output_node = preprocessor_module.Normalization()(output_node)
         if augment:
@@ -77,8 +74,9 @@ class TextBlock(base.HyperBlock):
         vectorizer: String. 'sequence' or 'ngram'. If it is 'sequence',
             TextToIntSequence will be used. If it is 'ngram', TextToNgramVector
             will be used. If unspecified, it will be tuned automatically.
-        pretraining: Boolean. Whether to use pretraining weights in the N-gram
-            vectorizer. If unspecified, it will be tuned automatically.
+        pretraining: String. 'random' (use random weights instead of any pretrained
+            model), 'glove', 'fasttext' or 'word2vec'. Uses a pretrained word
+            embedding. If left unspecified, it will be tuned automatically.
     """
 
     def __init__(self, vectorizer=None, pretraining=None, **kwargs):
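
The recurring `hp.Choice(name, [True, False])` to `hp.Boolean(name)` changes in this patch are effectively behavior-preserving; Boolean is KerasTuner's dedicated two-valued hyperparameter. A quick sketch:

    import kerastuner

    hp = kerastuner.HyperParameters()
    # Replaces hp.Choice('normalize', [True, False], default=True).
    normalize = hp.Boolean('normalize', default=True)
    augment = hp.Boolean('augment', default=False)
    # Registering returns the defaults until an oracle overrides them.
    assert normalize is True and augment is False
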
+ """ + + def __init__(self, + max_len=None, + num_words=20000, + **kwargs): super().__init__(**kwargs) self.max_len = max_len self.max_len_in_data = 0 + self.num_words = num_words self.tokenizer = tf.keras.preprocessing.text.Tokenizer( - num_words=const.Constant.VOCABULARY_SIZE) + num_words=num_words) self.max_len_to_use = None self.max_features = None @@ -127,7 +138,10 @@ def output_shape(self): def get_config(self): config = super().get_config() - config.update({'max_len': self.max_len}) + config.update({ + 'max_len': self.max_len, + 'num_words': self.num_words, + }) return config def get_state(self): diff --git a/autokeras/meta_model.py b/autokeras/meta_model.py index 41a3a7693..9d03e3db4 100644 --- a/autokeras/meta_model.py +++ b/autokeras/meta_model.py @@ -28,7 +28,7 @@ def assemble(inputs, outputs, dataset, seed=None): if isinstance(input_node, node.TextInput): assemblers.append(TextAssembler()) if isinstance(input_node, node.ImageInput): - assemblers.append(ImageAssembler(seed=seed)) + assemblers.append(ImageAssembler()) if isinstance(input_node, node.StructuredDataInput): assemblers.append(StructuredDataAssembler(seed=seed)) if isinstance(input_node, node.TimeSeriesInput): @@ -125,9 +125,8 @@ def assemble(self, input_node): class ImageAssembler(Assembler): """Assembles the ImageBlock based on training dataset.""" - def __init__(self, seed=None, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.seed = seed self._shape = None self._num_samples = 0 @@ -136,7 +135,7 @@ def update(self, x): self._num_samples += 1 def assemble(self, input_node): - block = hyperblock.ImageBlock(seed=self.seed) + block = hyperblock.ImageBlock() if max(self._shape[0], self._shape[1]) < 32: if self._num_samples < 10000: self.hps.append(hp_module.Choice( diff --git a/autokeras/oracle.py b/autokeras/oracle.py new file mode 100644 index 000000000..0443c197b --- /dev/null +++ b/autokeras/oracle.py @@ -0,0 +1,155 @@ +import random + +import kerastuner +import numpy as np + +from autokeras.hypermodel import base + + +class GreedyOracle(kerastuner.Oracle): + """An oracle combining random search and greedy algorithm. + + It groups the HyperParameters into several categories, namely, HyperGraph, + Preprocessor, Architecture, and Optimization. The oracle tunes each group + separately using random search. In each trial, it use a greedy strategy to + generate new values for one of the categories of HyperParameters and use the best + trial so far for the rest of the HyperParameters values. + + # Arguments + initial_hps: A list of dictionaries in the form of + {HyperParameter name (String): HyperParameter value}. + Each dictionary is one set of HyperParameters, which are used as the + initial trials for the search. Defaults to None. + seed: Int. Random seed. + """ + + HYPER = 'HYPER' + PREPROCESS = 'PREPROCESS' + OPT = 'OPT' + ARCH = 'ARCH' + STAGES = [HYPER, PREPROCESS, OPT, ARCH] + + @staticmethod + def next_stage(stage): + stages = GreedyOracle.STAGES + return stages[(stages.index(stage) + 1) % len(stages)] + + def __init__(self, + initial_hps=None, + seed=None, + **kwargs): + super().__init__(**kwargs) + self.initial_hps = initial_hps or [] + self._tried_initial_hps = [False] * len(self.initial_hps) + self.hyper_graph = None + # Sets of HyperParameter names. + self._hp_names = { + GreedyOracle.HYPER: set(), + GreedyOracle.PREPROCESS: set(), + GreedyOracle.OPT: set(), + GreedyOracle.ARCH: set(), + } + # The quota used to tune each category of hps. 
diff --git a/autokeras/oracle.py b/autokeras/oracle.py
new file mode 100644
index 000000000..0443c197b
--- /dev/null
+++ b/autokeras/oracle.py
@@ -0,0 +1,155 @@
+import random
+
+import kerastuner
+import numpy as np
+
+from autokeras.hypermodel import base
+
+
+class GreedyOracle(kerastuner.Oracle):
+    """An oracle combining random search and a greedy algorithm.
+
+    It groups the HyperParameters into several categories, namely, HyperGraph,
+    Preprocessor, Architecture, and Optimization. The oracle tunes each group
+    separately using random search. In each trial, it uses a greedy strategy to
+    generate new values for one category of HyperParameters and uses the best
+    trial so far for the rest of the HyperParameters values.
+
+    # Arguments
+        initial_hps: A list of dictionaries in the form of
+            {HyperParameter name (String): HyperParameter value}.
+            Each dictionary is one set of HyperParameters, which are used as the
+            initial trials for the search. Defaults to None.
+        seed: Int. Random seed.
+    """
+
+    HYPER = 'HYPER'
+    PREPROCESS = 'PREPROCESS'
+    OPT = 'OPT'
+    ARCH = 'ARCH'
+    STAGES = [HYPER, PREPROCESS, OPT, ARCH]
+
+    @staticmethod
+    def next_stage(stage):
+        stages = GreedyOracle.STAGES
+        return stages[(stages.index(stage) + 1) % len(stages)]
+
+    def __init__(self,
+                 initial_hps=None,
+                 seed=None,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.initial_hps = initial_hps or []
+        self._tried_initial_hps = [False] * len(self.initial_hps)
+        self.hyper_graph = None
+        # Sets of HyperParameter names, one set per category.
+        self._hp_names = {
+            GreedyOracle.HYPER: set(),
+            GreedyOracle.PREPROCESS: set(),
+            GreedyOracle.OPT: set(),
+            GreedyOracle.ARCH: set(),
+        }
+        self.seed = seed or random.randint(1, 1e4)
+        # Incremented at every call to `populate_space`.
+        self._seed_state = self.seed
+        self._tried_so_far = set()
+        self._max_collisions = 5
+
+    def set_state(self, state):
+        super().set_state(state)
+
+    def get_state(self):
+        state = super().get_state()
+        state.update({
+        })
+        return state
+
+    def update_space(self, hyperparameters):
+        # Get the block names.
+        preprocess_graph, keras_graph = self.hyper_graph.build_graphs(
+            hyperparameters)
+
+        # Add the new HyperParameters to the different categories.
+        ref_names = {hp.name for hp in self.hyperparameters.space}
+        for hp in hyperparameters.space:
+            if hp.name not in ref_names:
+                hp_type = None
+                if any([hp.name.startswith(block.name)
+                        for block in self.hyper_graph.blocks
+                        if isinstance(block, base.HyperBlock)]):
+                    hp_type = GreedyOracle.HYPER
+                elif any([hp.name.startswith(block.name)
+                          for block in preprocess_graph.blocks]):
+                    hp_type = GreedyOracle.PREPROCESS
+                elif any([hp.name.startswith(block.name)
+                          for block in keras_graph.blocks]):
+                    hp_type = GreedyOracle.ARCH
+                else:
+                    hp_type = GreedyOracle.OPT
+                self._hp_names[hp_type].add(hp.name)
+
+        super().update_space(hyperparameters)
+
+    def _generate_stage(self):
+        probabilities = np.array([pow(len(value), 2)
+                                  for value in self._hp_names.values()])
+        sum_p = np.sum(probabilities)
+        if sum_p == 0:
+            probabilities = np.array([1] * len(probabilities))
+            sum_p = np.sum(probabilities)
+        probabilities = probabilities / sum_p
+        return np.random.choice(list(self._hp_names.keys()), p=probabilities)
+
+    def _next_initial_hps(self):
+        for index, hps in enumerate(self.initial_hps):
+            if not self._tried_initial_hps[index]:
+                self._tried_initial_hps[index] = True
+                return hps
+
+    def _populate_space(self, trial_id):
+        if not all(self._tried_initial_hps):
+            return {'status': kerastuner.engine.trial.TrialStatus.RUNNING,
+                    'values': self._next_initial_hps()}
+
+        stage = self._generate_stage()
+        for _ in range(len(GreedyOracle.STAGES)):
+            values = self._generate_stage_values(stage)
+            # Reached max collisions.
+            if values is None:
+                # Try the next stage.
+                stage = GreedyOracle.next_stage(stage)
+                continue
+            # Values found.
+            return {'status': kerastuner.engine.trial.TrialStatus.RUNNING,
+                    'values': values}
+        # All stages reached max collisions.
+        return {'status': kerastuner.engine.trial.TrialStatus.STOPPED,
+                'values': None}
+
+    def _generate_stage_values(self, stage):
+        best_trials = self.get_best_trials()
+        if best_trials:
+            best_values = best_trials[0].hyperparameters.values
+        else:
+            best_values = self.hyperparameters.values
+        collisions = 0
+        while True:
+            # Generate new values for the current stage.
+            values = {}
+            for p in self.hyperparameters.space:
+                if p.name in self._hp_names[stage]:
+                    values[p.name] = p.random_sample(self._seed_state)
+                    self._seed_state += 1
+            values = {**best_values, **values}
+            # Keep trying until the set of values is unique,
+            # or until we exit due to too many collisions.
+            values_hash = self._compute_values_hash(values)
+            if values_hash not in self._tried_so_far:
+                self._tried_so_far.add(values_hash)
+                break
+            collisions += 1
+            if collisions > self._max_collisions:
+                # Reached max collisions. No value to return.
+                return None
+        return values
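
The fixed HYPER to PREPROCESS to OPT to ARCH rotation with per-stage quotas (removed from tuner.py below) is replaced by _generate_stage, which samples a category with probability proportional to the square of the number of hyperparameters it owns. The rule in isolation, with hypothetical counts:

    import numpy as np

    hp_counts = {'HYPER': 1, 'PREPROCESS': 0, 'OPT': 1, 'ARCH': 10}
    probabilities = np.array([count ** 2 for count in hp_counts.values()])
    probabilities = probabilities / np.sum(probabilities)
    stage = np.random.choice(list(hp_counts.keys()), p=probabilities)
    # ARCH is drawn with probability 100/102, so most trials mutate
    # architecture hyperparameters.
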
diff --git a/autokeras/tuner.py b/autokeras/tuner.py
index 678dae42b..cb183a927 100644
--- a/autokeras/tuner.py
+++ b/autokeras/tuner.py
@@ -1,12 +1,12 @@
 import copy
 import os
-import random
 
 import kerastuner
 import kerastuner.engine.hypermodel as hm_module
 import tensorflow as tf
 
-from autokeras.hypermodel import base
+from autokeras import const
+from autokeras import oracle as oracle_module
 
 
 class AutoTuner(kerastuner.engine.multi_execution_tuner.MultiExecutionTuner):
@@ -182,159 +182,23 @@ class BayesianOptimization(AutoTuner, kerastuner.BayesianOptimization):
     pass
 
 
-class GreedyOracle(kerastuner.Oracle):
-    """An oracle combining random search and greedy algorithm.
-
-    It groups the HyperParameters into several categories, namely, HyperGraph,
-    Preprocessor, Architecture, and Optimization. The oracle tunes each group
-    separately using random search. In each trial, it use a greedy strategy to
-    generate new values for one of the categories of HyperParameters and use the best
-    trial so far for the rest of the HyperParameters values.
-
-    # Arguments
-        hyper_graph: HyperGraph. The hyper_graph model to be tuned.
-        seed: Int. Random seed.
-    """
-
-    HYPER = 'HYPER'
-    PREPROCESS = 'PREPROCESS'
-    OPT = 'OPT'
-    ARCH = 'ARCH'
-    STAGES = [HYPER, PREPROCESS, OPT, ARCH]
-
-    @staticmethod
-    def next_stage(stage):
-        stages = GreedyOracle.STAGES
-        return stages[(stages.index(stage) + 1) % len(stages)]
-
-    def __init__(self, seed=None, **kwargs):
-        super().__init__(**kwargs)
-        self.hyper_graph = None
-        # Start from tuning the hyper block hps.
-        self._stage = GreedyOracle.HYPER
-        # Sets of HyperParameter names.
-        self._hp_names = {
-            GreedyOracle.HYPER: set(),
-            GreedyOracle.PREPROCESS: set(),
-            GreedyOracle.OPT: set(),
-            GreedyOracle.ARCH: set(),
-        }
-        # The quota used to tune each category of hps.
-        self._capacity = {
-            GreedyOracle.HYPER: 1,
-            GreedyOracle.PREPROCESS: 1,
-            GreedyOracle.OPT: 1,
-            GreedyOracle.ARCH: 4,
-        }
-        self._stage_trial_count = 0
-        self.seed = seed or random.randint(1, 1e4)
-        # Incremented at every call to `populate_space`.
-        self._seed_state = self.seed
-        self._tried_so_far = set()
-        self._max_collisions = 5
-
-    def set_state(self, state):
-        super().set_state(state)
-        self._stage = state['stage']
-        self._capacity = state['capacity']
-
-    def get_state(self):
-        state = super().get_state()
-        state.update({
-            'stage': self._stage,
-            'capacity': self._capacity,
-        })
-        return state
-
-    def update_space(self, hyperparameters):
-        # Get the block names.
-        preprocess_graph, keras_graph = self.hyper_graph.build_graphs(
-            hyperparameters)
-
-        # Add the new Hyperparameters to different categories.
-        ref_names = {hp.name for hp in self.hyperparameters.space}
-        for hp in hyperparameters.space:
-            if hp.name not in ref_names:
-                hp_type = None
-                if any([hp.name.startswith(block.name)
-                        for block in self.hyper_graph.blocks
-                        if isinstance(block, base.HyperBlock)]):
-                    hp_type = GreedyOracle.HYPER
-                elif any([hp.name.startswith(block.name)
-                          for block in preprocess_graph.blocks]):
-                    hp_type = GreedyOracle.PREPROCESS
-                elif any([hp.name.startswith(block.name)
-                          for block in keras_graph.blocks]):
-                    hp_type = GreedyOracle.ARCH
-                else:
-                    hp_type = GreedyOracle.OPT
-                self._hp_names[hp_type].add(hp.name)
-
-        super().update_space(hyperparameters)
-
-    def _populate_space(self, trial_id):
-        for _ in range(len(GreedyOracle.STAGES)):
-            values = self._generate_stage_values()
-            # Reached max collisions.
-            if values is None:
-                # Try next stage.
-                self._stage = GreedyOracle.next_stage(self._stage)
-                self._stage_trial_count = 0
-                continue
-            # Values found.
-            self._stage_trial_count += 1
-            if self._stage_trial_count == self._capacity[self._stage]:
-                self._stage = GreedyOracle.next_stage(self._stage)
-                self._stage_trial_count = 0
-            return {'status': kerastuner.engine.trial.TrialStatus.RUNNING,
-                    'values': values}
-        # All stages reached max collisions.
-        return {'status': kerastuner.engine.trial.TrialStatus.STOPPED,
-                'values': None}
-
-    def _generate_stage_values(self):
-        best_trials = self.get_best_trials()
-        if best_trials:
-            best_values = best_trials[0].hyperparameters.values
-        else:
-            best_values = self.hyperparameters.values
-        collisions = 0
-        while 1:
-            # Generate new values for the current stage.
-            values = {}
-            for p in self.hyperparameters.space:
-                if p.name in self._hp_names[self._stage]:
-                    values[p.name] = p.random_sample(self._seed_state)
-                    self._seed_state += 1
-            values = {**best_values, **values}
-            # Keep trying until the set of values is unique,
-            # or until we exit due to too many collisions.
-            values_hash = self._compute_values_hash(values)
-            if values_hash not in self._tried_so_far:
-                self._tried_so_far.add(values_hash)
-                break
-            collisions += 1
-            if collisions > self._max_collisions:
-                # Reached max collisions. No value to return.
-                return None
-        return values
-
-
 class Greedy(AutoTuner):
 
     def __init__(self,
                  hypermodel,
                  objective,
                  max_trials,
+                 initial_hps=None,
                  seed=None,
                  hyperparameters=None,
                  tune_new_entries=True,
                  allow_new_entries=True,
                  **kwargs):
         self.seed = seed
-        oracle = GreedyOracle(
+        oracle = oracle_module.GreedyOracle(
             objective=objective,
             max_trials=max_trials,
+            initial_hps=initial_hps,
             seed=seed,
             hyperparameters=hyperparameters,
             tune_new_entries=tune_new_entries,
@@ -349,12 +213,19 @@ def search(self, hyper_graph, **kwargs):
         super().search(hyper_graph=hyper_graph, **kwargs)
 
 
+class ImageClassifierTuner(Greedy):
+    def __init__(self, **kwargs):
+        super().__init__(
+            initial_hps=const.INITIAL_HPS['image_classifier'],
+            **kwargs)
+
+
 TUNER_CLASSES = {
     'bayesian': BayesianOptimization,
     'random': RandomSearch,
     'hyperband': Hyperband,
     'greedy': Greedy,
-    'image_classifier': Greedy,
+    'image_classifier': ImageClassifierTuner,
     'image_regressor': Greedy,
     'text_classifier': Greedy,
     'text_regressor': Greedy,
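
With the mapping change, the task APIs pick up the seeded tuner transparently. A minimal sketch, assuming get_tuner_class (already used by AutoModel above) simply looks up TUNER_CLASSES:

    from autokeras import tuner as tuner_module

    tuner_class = tuner_module.get_tuner_class('image_classifier')
    assert tuner_class is tuner_module.ImageClassifierTuner  # seeds INITIAL_HPS
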
diff --git a/examples/cifar10.py b/examples/cifar10.py
deleted file mode 100644
index 0f171b890..000000000
--- a/examples/cifar10.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from tensorflow.keras.datasets import cifar10
-
-import autokeras as ak
-
-
-def task_api():
-    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
-    clf = ak.ImageClassifier(seed=5, max_trials=10)
-    clf.fit(x_train, y_train, validation_split=0.2)
-    return clf.evaluate(x_test, y_test)
-
-
-def io_api():
-    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
-    clf = ak.AutoModel(ak.ImageInput(),
-                       ak.ClassificationHead(),
-                       seed=5,
-                       max_trials=3)
-    clf.fit(x_train, y_train, validation_split=0.2)
-    return clf.evaluate(x_test, y_test)
-
-
-def functional_api():
-    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
-    input_node = ak.ImageInput()
-    output_node = input_node
-    output_node = ak.Normalization()(output_node)
-    output_node = ak.ImageAugmentation()(output_node)
-    output_node = ak.ResNetBlock(version='next')(output_node)
-    output_node = ak.SpatialReduction()(output_node)
-    output_node = ak.DenseBlock()(output_node)
-    output_node = ak.ClassificationHead()(output_node)
-    clf = ak.AutoModel(input_node, output_node, seed=5, max_trials=3)
-    clf.fit(x_train, y_train, validation_split=0.2)
-    return clf.evaluate(x_test, y_test)
-
-
-if __name__ == '__main__':
-    functional_api()
diff --git a/tests/autokeras/oracle_test.py b/tests/autokeras/oracle_test.py
new file mode 100644
index 000000000..2a2ee80d4
--- /dev/null
+++ b/tests/autokeras/oracle_test.py
@@ -0,0 +1,43 @@
+from unittest import mock
+
+import kerastuner
+
+from autokeras import oracle as oracle_module
+from tests import common
+
+
+def test_random_oracle_state():
+    hyper_graph = common.build_hyper_graph()
+    oracle = oracle_module.GreedyOracle(
+        objective='val_loss',
+    )
+    oracle.hyper_graph = hyper_graph
+    oracle.set_state(oracle.get_state())
+    assert oracle.hyper_graph is hyper_graph
+
+
+@mock.patch('autokeras.oracle.GreedyOracle.get_best_trials')
+def test_random_oracle(fn):
+    hyper_graph = common.build_hyper_graph()
+    oracle = oracle_module.GreedyOracle(
+        objective='val_loss',
+    )
+    hp = kerastuner.HyperParameters()
+    preprocess_graph, keras_graph = hyper_graph.build_graphs(hp)
+    preprocess_graph.build(hp)
+    keras_graph.inputs[0].shape = hyper_graph.inputs[0].shape
+    keras_graph.build(hp)
+    oracle.hyper_graph = hyper_graph
+    trial = mock.Mock()
+    trial.hyperparameters = hp
+    fn.return_value = [trial]
+
+    oracle.update_space(hp)
+    for i in range(2000):
+        oracle._populate_space(str(i))
+
+    assert 'optimizer' in oracle._hp_names[oracle_module.GreedyOracle.OPT]
+    assert 'classification_head_1/dropout_rate' in oracle._hp_names[
+        oracle_module.GreedyOracle.ARCH]
+    assert 'image_block_1/block_type' in oracle._hp_names[
+        oracle_module.GreedyOracle.HYPER]
diff --git a/tests/autokeras/tuner_test.py b/tests/autokeras/tuner_test.py
index c40fbc7eb..158e6c74c 100644
--- a/tests/autokeras/tuner_test.py
+++ b/tests/autokeras/tuner_test.py
@@ -4,7 +4,6 @@
 import pytest
 import tensorflow as tf
 
-import autokeras as ak
 from autokeras import tuner as tuner_module
 from tests import common
 
@@ -14,22 +13,10 @@ def tmp_dir(tmpdir_factory):
     return tmpdir_factory.mktemp('test_auto_model')
 
 
-def build_hyper_graph():
-    tf.keras.backend.clear_session()
-    image_input = ak.ImageInput(shape=(32, 32, 3))
-    merged_outputs = ak.ImageBlock()(image_input)
-    head = ak.ClassificationHead(num_classes=10)
-    head.output_shape = (10,)
-    classification_outputs = head(merged_outputs)
-    return ak.hypermodel.graph.HyperGraph(
-        inputs=image_input,
-        outputs=classification_outputs)
-
-
 @mock.patch('kerastuner.engine.base_tuner.BaseTuner.search')
 @mock.patch('autokeras.tuner.Greedy._prepare_run')
 def test_add_early_stopping(_, base_tuner_search, tmp_dir):
-    hyper_graph = build_hyper_graph()
+    hyper_graph = common.build_hyper_graph()
     hp = kerastuner.HyperParameters()
     preprocess_graph, keras_graph = hyper_graph.build_graphs(hp)
     preprocess_graph.build(hp)
@@ -53,43 +40,6 @@ def test_add_early_stopping(_, base_tuner_search, tmp_dir):
                 for callback in callbacks])
 
 
-def test_random_oracle_state():
-    hyper_graph = build_hyper_graph()
-    oracle = tuner_module.GreedyOracle(
-        objective='val_loss',
-    )
-    oracle.hyper_graph = hyper_graph
-    oracle.set_state(oracle.get_state())
-    assert oracle.hyper_graph is hyper_graph
-
-
-@mock.patch('autokeras.tuner.GreedyOracle.get_best_trials')
-def test_random_oracle(fn):
-    hyper_graph = build_hyper_graph()
-    oracle = tuner_module.GreedyOracle(
-        objective='val_loss',
-    )
-    hp = kerastuner.HyperParameters()
-    preprocess_graph, keras_graph = hyper_graph.build_graphs(hp)
-    preprocess_graph.build(hp)
-    keras_graph.inputs[0].shape = hyper_graph.inputs[0].shape
-    keras_graph.build(hp)
-    oracle.hyper_graph = hyper_graph
-    trial = mock.Mock()
-    trial.hyperparameters = hp
-    fn.return_value = [trial]
-
-    oracle.update_space(hp)
-    for i in range(2000):
-        oracle._populate_space(str(i))
-
-    assert 'optimizer' in oracle._hp_names[tuner_module.GreedyOracle.OPT]
-    assert 'classification_head_1/dropout_rate' in oracle._hp_names[
-        tuner_module.GreedyOracle.ARCH]
-    assert 'image_block_1/block_type' in oracle._hp_names[
-        tuner_module.GreedyOracle.HYPER]
-
-
 @mock.patch('kerastuner.engine.base_tuner.BaseTuner.__init__')
 @mock.patch('autokeras.tuner.Greedy._prepare_run')
 def test_overwrite_init(_, base_tuner_init, tmp_dir):
@@ -106,7 +56,7 @@ def test_overwrite_init(_, base_tuner_init, tmp_dir):
 @mock.patch('kerastuner.engine.base_tuner.BaseTuner.search')
 @mock.patch('autokeras.tuner.Greedy._prepare_run')
 def test_overwrite_search(_, base_tuner_search, tmp_dir):
-    hyper_graph = build_hyper_graph()
+    hyper_graph = common.build_hyper_graph()
     hp = kerastuner.HyperParameters()
     preprocess_graph, keras_graph = hyper_graph.build_graphs(hp)
     preprocess_graph.build(hp)
diff --git a/tests/common.py b/tests/common.py
index 756145bd1..81b1e95d3 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -235,3 +235,15 @@ def imdb_raw(num_instances=100):
 
 def name_in_hps(hp_name, hp):
     return any([hp_name in name for name in hp.values])
+
+
+def build_hyper_graph():
+    tf.keras.backend.clear_session()
+    image_input = ak.ImageInput(shape=(32, 32, 3))
+    merged_outputs = ak.ImageBlock()(image_input)
+    head = ak.ClassificationHead(num_classes=10)
+    head.output_shape = (10,)
+    classification_outputs = head(merged_outputs)
+    return ak.hypermodel.graph.HyperGraph(
+        inputs=image_input,
+        outputs=classification_outputs)