Please support RL and custom training loops #46
@bionicles this is definitely on the roadmap, on the Keras Tuner side as well as on the Keras side (making it easier to plug a custom training step in). On the Keras Tuner side, for complicated use cases, the current best way in my opinion is to override the Tuner class.
I'd like to use keras-tuner, but I'm unclear on this codebase and there's zero documentation on how to use the callback hooks. Custom != complicated; I'll just use Optuna for now.
@bionicles There will be more documentation ahead of the 1.0 launch (coming soon!). The codebase is still in active development, so some things for advanced use cases are still internal, but they will be exposed externally soon. For now, here's an example of how to implement a custom training loop that should remain stable:

class MyTuner(kerastuner.engine.Tuner):

    def run_trial(self, trial, x, y, val_x, val_y):
        # Build a model from this trial's hyperparameters, then run a
        # fully custom training loop, reporting progress via the hooks.
        model = self.hypermodel.build(trial.hyperparameters)
        for epoch in range(10):
            self.on_epoch_begin(trial, model, epoch, logs={})
            for batch in range(100):
                self.on_batch_begin(trial, model, batch, logs={})
                outputs = model(x)
                loss = myloss(outputs, y)
                self.on_batch_end(trial, model, batch, logs={'loss': loss})
            val_loss = ...
            # Report the tuning objective ('val_loss') at the end of each epoch.
            self.on_epoch_end(trial, model, epoch, logs={'val_loss': val_loss})


def build_model(hp):
    # Pull hyperparameters from hp while building the model.
    num_layers = hp.Int('num_layers', 1, 10)
    ...
    return model


tuner = MyTuner(
    oracle=kerastuner.tuners.randomsearch.RandomSearchOracle(
        objective='val_loss',
        max_trials=30),
    hypermodel=build_model)

tuner.search(x, y, val_x, val_y)
I'll update this thread again when we have documentation available for best practices for these use cases.
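Since the original request is about a GradientTape-based loop, here is a minimal sketch (not from the thread) of how the bare forward pass inside run_trial's batch loop could be turned into an actual optimization step. The optimizer, loss_fn, and per-batch tensors are assumed to be set up elsewhere, and the names are illustrative only:

import tensorflow as tf

def train_step(model, optimizer, loss_fn, x_batch, y_batch):
    # Forward pass and loss are recorded on the tape so gradients can be taken.
    with tf.GradientTape() as tape:
        outputs = model(x_batch, training=True)
        loss = loss_fn(y_batch, outputs)
    # Backward pass: compute gradients and apply them to the model's weights.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return float(loss)

Inside run_trial, a step like this could replace the outputs/loss lines, e.g. loss = train_step(model, optimizer, loss_fn, x, y) before passing that loss to self.on_batch_end.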
Closing as there is now a pending PR for a tutorial on custom training loops: #136
To "pull" hyperparameters is a nice way to write less code, I really like the API for this repo, but we use a custom GradientTape training loop for multiple different multi-step tasks, model.compile/model.fit doesnt work for us
how do you use this repo for custom training loops?