-
Notifications
You must be signed in to change notification settings - Fork 396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tutorial for Tuner Subclassing #136
Changes from all commits
174b611
75397ed
4a0c351
509dcb5
8210eb0
0380f4f
28fe5a7
6fb0660
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# Subclassing Tuner for Custom Training Loops | ||
|
||
The `Tuner` class at `kerastuner.engine.tuner.Tuner` can be subclassed to support advanced uses such as: | ||
|
||
- Custom training loops (GANs, reinforement learning, etc.) | ||
- Adding hyperparameters outside of the model builing function (preprocessing, data augmentation, test time augmentation, etc.) | ||
|
||
This tutorial will not cover subclassing to support non-Keras models. To accomplish this, you can subclass the `kerastuner.engine.base_tuner.BaseTuner` class (See `kerastuner.tuners.sklearn.Sklearn` for an example). | ||
|
||
### Understanding the search process. | ||
|
||
`Tuner.search` can be passed any arguments. These arguments will be passed directly to `Tuner.run_trial`, along with a `Trial` object that contains information about the current trial, including hyperparameters and the status of the trial. Typically, `Tuner.run_trial` is the only method that users need to override when subclassing `Tuner`. | ||
|
||
### Overriding `run_trial`. | ||
|
||
There are two ways to write `run_trial`. One is to leverage `Tuner`'s built-in callback hooks, which send the value of the `objective` to the `Oracle` and save the latest state of the Model. These hooks are: | ||
|
||
* `self.on_epoch_end`: Must be called. Reports results to the `Oracle` and saves the Model. The `logs` dictionary passed to this method must contain the `objective` name. | ||
* `self.on_epoch_begin`, `self.on_batch_begin`, `self.on_batch_end`: Optional. These methods do nothing in `Tuner`, but are useful to provide as hooks if you expect users of your subclass to create their own subclasses that override these parts of the training process. | ||
|
||
```python | ||
class MyTuner(kt.Tuner): | ||
|
||
def run_trial(self, trial, ...): | ||
model = self.hypermodel.build(trial.hyperparameters) | ||
for epoch in range(10): | ||
epoch_loss = ... | ||
self.on_epoch_end(trial, model, epoch, logs={'loss': epoch_loss}) | ||
``` | ||
|
||
Alternatively, you can instead directly call the methods used to report results to the `Oracle` and save the Model. This can allow more flexibility for use cases where there is no natural concept of epoch or where you do not want to report results to the `Oracle` after each epoch. These methods are: | ||
|
||
* `self.oracle.update_trial`: Reports current results to the `Oracle`. The `metrics` dictionary passed to this method must contain the `objective` name. | ||
* `self.save_model`: Saves the trained model. | ||
|
||
```python | ||
class MyTuner(kt.Tuner): | ||
|
||
def run_trial(self, trial, ...): | ||
model = self.hypermodel.build(trial.hyperparameters) | ||
score = ... | ||
self.oracle.update_trial(trial.trial_id, {'score': score}) | ||
self.oracle.save_model(trail.trial_id, model) | ||
``` | ||
|
||
### Adding HyperParameters during preprocessing, evaluation, etc. | ||
|
||
New `HyperParameter`s can be defined anywhere in `run_trial`, in the same way that `HyperParameter`s are defined in a `HyperModel`. These hyperparameters take on their default value the first time they are encountered, and thereafter are tuned by the `Oracle`. | ||
|
||
```python | ||
class MyTuner(kt.Tuner): | ||
|
||
def run_trial(self, trial, ...): | ||
hp = trial.hyperparameters | ||
model = self.hypermodel.build(hp) | ||
|
||
batch_size = hp.Int('batch_size', 32, 128, step=32) | ||
random_flip = hp.Boolean('random_flip') | ||
... | ||
``` | ||
|
||
### End-to-end Example: | ||
|
||
```python | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a complex example. We should consider using a format that enables us to make the example directly runnable in Colab in one click, and easily testable in an automatically generated integration test. For instance, examples on tensorflow.org are typically shipped as runnable notebooks. The notebooks are used to generated a HTML webpage (it goes through markdown first), can be launched in Colab, and are run as part of integration tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to do the equivalent on GH? I couldn't find anything from googling around a bit I ran this example on my computer locally to make sure it worked, is that ok for now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no ready-made solution that I'm aware of, we'd have to roll out our own (and make it shared cross all of our repos). We'll do this eventually, but we can't block tutorials based on this, so let's ship the tutorial as is for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, made a tracking bug here: keras-team/keras-autodoc#35 |
||
import kerastuner as kt | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
|
||
|
||
def build_model(hp): | ||
"""Builds a convolutional model.""" | ||
inputs = tf.keras.Input(shape=(28, 28, 1)) | ||
x = inputs | ||
for i in range(hp.Int('conv_layers', 1, 3, default=3)): | ||
x = tf.keras.layers.Conv2D( | ||
filters=hp.Int('filters_' + str(i), 4, 32, step=4, default=8), | ||
kernel_size=hp.Int('kernel_size_' + str(i), 3, 5), | ||
activation='relu', | ||
padding='same')(x) | ||
|
||
if hp.Choice('pooling' + str(i), ['max', 'avg']) == 'max': | ||
x = tf.keras.layers.MaxPooling2D()(x) | ||
else: | ||
x = tf.keras.layers.AveragePooling2D()(x) | ||
|
||
x = tf.keras.layers.BatchNormalization()(x) | ||
x = tf.keras.layers.ReLU()(x) | ||
|
||
if hp.Choice('global_pooling', ['max', 'avg']) == 'max': | ||
x = tf.keras.layers.GlobalMaxPooling2D()(x) | ||
else: | ||
x = tf.keras.layers.GlobalAveragePooling2D()(x) | ||
outputs = tf.keras.layers.Dense(10, activation='softmax')(x) | ||
|
||
model = tf.keras.Model(inputs, outputs) | ||
|
||
optimizer = hp.Choice('optimizer', ['adam', 'sgd']) | ||
model.compile(optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy']) | ||
return model | ||
|
||
|
||
class MyTuner(kt.Tuner): | ||
|
||
def run_trial(self, trial, train_ds): | ||
hp = trial.hyperparameters | ||
|
||
# Hyperparameters can be added anywhere inside `run_trial`. | ||
# When the first trial is run, they will take on their default values. | ||
# Afterwards, they will be tuned by the `Oracle`. | ||
train_ds = train_ds.batch( | ||
hp.Int('batch_size', 32, 128, step=32, default=64)) | ||
|
||
model = self.hypermodel.build(trial.hyperparameters) | ||
lr = hp.Float('learning_rate', 1e-4, 1e-2, sampling='log', default=1e-3) | ||
optimizer = tf.keras.optimizers.Adam(lr) | ||
epoch_loss_metric = tf.keras.metrics.Mean() | ||
|
||
@tf.function | ||
def run_train_step(data): | ||
images = tf.dtypes.cast(data['image'], 'float32') / 255. | ||
labels = data['label'] | ||
with tf.GradientTape() as tape: | ||
logits = model(images) | ||
loss = tf.keras.losses.sparse_categorical_crossentropy( | ||
labels, logits) | ||
# Add any regularization losses. | ||
if model.losses: | ||
loss += tf.math.add_n(model.losses) | ||
gradients = tape.gradient(loss, model.trainable_variables) | ||
optimizer.apply_gradients(zip(gradients, model.trainable_variables)) | ||
epoch_loss_metric.update_state(loss) | ||
return loss | ||
|
||
# `self.on_epoch_end` reports results to the `Oracle` and saves the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be specified in the text above ("Typically, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
# current state of the Model. The other hooks called here only log values | ||
# for display but can also be overridden. For use cases where there is no | ||
# natural concept of epoch, you do not have to call any of these hooks. In | ||
# this case you should instead call `self.oracle.update_trial` and | ||
# `self.oracle.save_model` manually. | ||
for epoch in range(10): | ||
print('Epoch: {}'.format(epoch)) | ||
|
||
self.on_epoch_begin(trial, model, epoch, logs={}) | ||
for batch, data in enumerate(train_ds): | ||
self.on_batch_begin(trial, model, batch, logs={}) | ||
batch_loss = float(run_train_step(data)) | ||
self.on_batch_end(trial, model, batch, logs={'loss': batch_loss}) | ||
|
||
if batch % 100 == 0: | ||
loss = epoch_loss_metric.result().numpy() | ||
print('Batch: {}, Average Loss: {}'.format(batch, loss)) | ||
|
||
epoch_loss = epoch_loss_metric.result().numpy() | ||
self.on_epoch_end(trial, model, epoch, logs={'loss': epoch_loss}) | ||
epoch_loss_metric.reset_states() | ||
|
||
|
||
def main(): | ||
tuner = MyTuner( | ||
oracle=kt.oracles.BayesianOptimization( | ||
objective=kt.Objective('loss', 'min'), | ||
max_trials=2), | ||
hypermodel=build_model, | ||
directory='results', | ||
project_name='mnist_custom_training') | ||
|
||
mnist_data = tfds.load('mnist') | ||
mnist_train, mnist_test = mnist_data['train'], mnist_data['test'] | ||
mnist_train = mnist_train.shuffle(1000) | ||
|
||
tuner.search(train_ds=mnist_train) | ||
|
||
best_hps = tuner.get_best_hyperparameters()[0] | ||
print(best_hps.values) | ||
|
||
best_model = tuner.get_best_models()[0] | ||
|
||
if __name__ == '__main__': | ||
main() | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Besides the detailed example below, I would recommend having a sketch of the code here, e.g.
Otherwise it's hard to parse what's essential and what's not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done