Drop in timing performance from Keras 2 to Keras 3 with TensorFlow as backend #19953
Let me clarify that all the devices are provisioned as virtual machines, with OpenStack as the hypervisor.
Note that any benchmark should only measure time starting after the first training step, otherwise you are including the initial startup overhead -- but what you need to measure is the step time. If you care about performance you should:
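As an illustration of that point -- a minimal sketch with toy stand-ins for the model and data, not the original script -- one can absorb the startup overhead in a single warm-up epoch and time only the remaining ones:

```python
import numpy as np
import keras
from time import perf_counter

# Toy stand-ins for the model and data of the original script.
x = np.random.rand(10_000, 16).astype("float32")
y = np.random.rand(10_000, 1).astype("float32")
model = keras.Sequential([keras.layers.Dense(64, activation="relu"), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Warm-up epoch: absorbs tracing / compilation overhead of the train step.
model.fit(x, y, batch_size=500, epochs=1, verbose=0)

# Time only the steady-state epochs.
start = perf_counter()
model.fit(x, y, batch_size=500, epochs=19, verbose=0)
print(f"Steady-state training time: {perf_counter() - start:.4f} s")
```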
Hi @fchollet, you are right. My naive assumption was that, given enough epochs, such overhead could be neglected, but that is probably not the case with only 20 epochs. To address the problem I have prepared this custom callback, intended to measure the average step time in each epoch:

```python
import keras
import numpy as np
from time import time

class BatchTime(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []  # one average step time per epoch

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self.start = time()

    def on_train_batch_end(self, batch, logs=None):
        self.epoch_times.append(time() - self.start)

    def on_epoch_end(self, epoch, logs=None):
        try:
            step_time = keras.ops.mean(self.epoch_times)  # Keras 3
        except AttributeError:
            step_time = np.mean(self.epoch_times)  # Keras 2 has no keras.ops
        self.times.append(step_time)
```

This callback has been passed to `model.fit()`:

```python
# [...]
batch_time = BatchTime()

start = time()
model.fit(x=x, y=y, batch_size=500, epochs=20, validation_split=0.2, callbacks=[batch_time])
stop = time()

step_time = 1e3 * np.mean(batch_time.times[1:])  # in ms, skipping the first epoch
print(f"Average step time: {step_time:.4f} ms")
print(f"Total training time: {stop - start:.4f} s")
```

This "new configuration" has been used to repeat the exercise originally described, and the results are reported as a reference in the following table:
*The average step time is measured excluding the first epoch for a more robust performance evaluation.
From @fchollet:
While I agree about the role of batch size and number of steps per epoch in performance optimization, this "study" aimed to compare Keras 2 with Keras 3 in terms of timing under common (and "randomly chosen") hyperparameters (i.e., batch size, number of epochs, dataset size). For completeness, I have also tried to repeat the usual exercises using
*The average step time is measured excluding the first epoch for a more robust performance evaluation.
One more thing: you should pick the right backend. If you care a lot about overhead timing, you can use the torch backend, which has minimal overhead (but typically executes more slowly). For large workloads, JAX is often the fastest backend.
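As a practical note (not part of the original comment): Keras 3 selects its backend via the `KERAS_BACKEND` environment variable, which must be set before `keras` is first imported. A minimal sketch:

```python
import os

# Must be set before the first `import keras`;
# one of "tensorflow", "jax", "torch".
os.environ["KERAS_BACKEND"] = "jax"

import keras

print(keras.backend.backend())  # confirms the active backend, e.g. "jax"
```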
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
After completing the upgrade of a Python package for High Energy Physics flash-simulation (mbarbetti/pidgan) to be compatible with Keras 3, I have noticed a significant drop in timing performance when moving from Keras 2 to Keras 3, as also reported in mbarbetti/pidgan#10.
After some iterations, I have been able to reproduce the problem, which seems strictly related to Keras 3 (with the TensorFlow backend).
The code that I have used as a reference follows:
The above Python script has been executed within two different `conda` environments: one based on TensorFlow 2.14.1 (with Keras 2.14.0), the other based on TensorFlow 2.16.1 (with Keras 3.3.3). This exercise has been repeated on three different devices (CPU-only + 2 different GPU cards), and in all cases I have observed a drop in timing performance when moving from Keras 2 to Keras 3. The device details and the measured times are reported in the following table:
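For reference, a minimal check (not from the original report) to confirm which TensorFlow/Keras versions a given environment actually loads:

```python
import tensorflow as tf
import keras

# In the first environment this should report TF 2.14.1 / Keras 2.14.0,
# in the second TF 2.16.1 / Keras 3.3.3.
print("TensorFlow:", tf.__version__)
print("Keras:", keras.__version__)
```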