
Drop in timing performance from Keras 2 to Keras 3 with TensorFlow as backend #19953

Closed

mbarbetti opened this issue Jul 3, 2024 · 8 comments
@mbarbetti

After upgrading a Python package for High Energy Physics flash-simulation (mbarbetti/pidgan) to be compatible with Keras 3, I noticed a significant drop in timing performance when moving from Keras 2 to Keras 3, as also reported in mbarbetti/pidgan#10.

After some iterations, I was able to reproduce the problem, which appears to be strictly related to Keras 3 (with the TensorFlow backend).

The reference code I used follows:

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np

from time import time

# Synthetic binary-classification dataset
chunk_size = 250_000
x = np.random.normal(size=(chunk_size, 4))
y = np.random.choice([0.0, 1.0], size=chunk_size, p=[0.5, 0.5])

# Simple MLP: 10 Dense+Dropout blocks followed by a sigmoid output
model = keras.Sequential()
for _ in range(10):
    model.add(keras.layers.Dense(128, activation="relu"))
    model.add(keras.layers.Dropout(rate=0.1))
model.add(keras.layers.Dense(1, activation="sigmoid"))

model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=keras.losses.BinaryCrossentropy(label_smoothing=0.05),
    metrics=[keras.metrics.AUC(name="auc")],
    jit_compile=False,
)

# Measure the total wall-clock training time
start = time()
model.fit(x=x, y=y, batch_size=500, epochs=20, validation_split=0.2)
stop = time()

print(f"{stop - start} s")

The above Python script was executed within two different conda environments: one based on TensorFlow 2.14.1 (with Keras 2.14.0), the other on TensorFlow 2.16.1 (with Keras 3.3.3). The exercise was repeated on three different devices (CPU-only plus two different GPU cards), and in all cases I observed a drop in timing performance when moving from Keras 2 to Keras 3.

The device details and the measured times are reported in the following table:

| CPU model | # cores | RAM | GPU model | GPU partition | Time on TF 2.14 | Time on TF 2.16 |
| --- | --- | --- | --- | --- | --- | --- |
| AMD EPYC 7282 | 8 | 8 GB | - | - | 73.1612 s | 79.3706 s |
| Intel Xeon Gold 5218 | 8 | 8 GB | Quadro RTX 5000 | 1/1 | 51.1899 s | 65.2110 s |
| AMD EPYC 7513 | 8 | 8 GB | NVIDIA A100 80GB | 1/7 | 46.0913 s | 63.2520 s |
@mbarbetti (Author)

Let me point out that all the devices are provisioned as virtual machines managed by OpenStack.

@fchollet (Collaborator) commented Jul 4, 2024

Note that any benchmark should only measure time starting after the first training step; otherwise you are including the initial startup overhead -- what you need to measure is the step time.

If you care about performance you should:

  • Use jit_compile=True
  • Tune your batch size
  • Tune the value of steps_per_execution (see the sketch below)
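
For reference, here is a minimal sketch of how these knobs could be applied to the script above; it reuses model, x, and y from the original code, and the specific values of steps_per_execution and batch_size are illustrative, not tuned:

model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=keras.losses.BinaryCrossentropy(label_smoothing=0.05),
    metrics=[keras.metrics.AUC(name="auc")],
    jit_compile=True,        # compile the train step with XLA
    steps_per_execution=8,   # run several batches per outer call to amortize per-step overhead
)

# Larger batches often improve throughput; the right value depends on the device and model.
model.fit(x=x, y=y, batch_size=2000, epochs=20, validation_split=0.2)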

@mbarbetti (Author)

Hi @fchollet,

You are right. My naive assumption was that, with enough epochs, the startup overhead could be neglected, but that is probably not the case with only 20 epochs. To address the problem, I have prepared this custom callback to measure the average step time in each epoch:

import keras
import numpy as np
from time import time

class BatchTime(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self.start = time()

    def on_train_batch_end(self, batch, logs=None):
        self.epoch_times.append(time() - self.start)

    def on_epoch_end(self, epoch, logs=None):
        # keras.ops is only available in Keras 3; fall back to NumPy on Keras 2
        try:
            step_time = keras.ops.mean(self.epoch_times)
        except AttributeError:
            step_time = np.mean(self.epoch_times)
        self.times.append(step_time)

This callback is passed to the fit() method, and the average step time over the whole training is computed after excluding the first epoch to remove the aforementioned overhead:

# [...]

batch_time = BatchTime()

start = time()
model.fit(x=x, y=y, batch_size=500, epochs=20, validation_split=0.2, callbacks=[batch_time])
stop = time()

step_time = 1e3 * np.mean(batch_time.times[1:])  # in ms, skipping the first epoch

print(f"Average step time: {step_time:.4f} ms")
print(f"Total training time: {stop - start:.4f} s")

This "new configuration" has been used to repeat the exercise originally described and the results are reported as a reference in the following table:

| CPU model | GPU model | Average step time* on TF 2.14 | Total time on TF 2.14 | Average step time* on TF 2.16 | Total time on TF 2.16 |
| --- | --- | --- | --- | --- | --- |
| AMD EPYC 7282 | - | 8.0969 ms | 75.5528 s | 8.4981 ms | 81.4652 s |
| Intel Xeon Gold 5218 | Quadro RTX 5000 | 4.8602 ms | 52.8518 s | 6.3447 ms | 64.3022 s |
| AMD EPYC 7513 | NVIDIA A100 80GB | 4.0156 ms | 43.7599 s | 5.7462 ms | 56.0178 s |

*the average step time is measured excluding the first epoch for a more robust performance evaluation

@mbarbetti (Author)

From @fchollet:

If you care about performance you should:

  • Use jit_compile=True
  • Tune your batch size
  • Tune the value for steps_per_execution

While I agree on the role of the batch size and the number of steps per execution in performance optimization, this study aimed to compare Keras 2 and Keras 3 in terms of timing with common (and somewhat arbitrarily chosen) hyperparameters (i.e., batch size, number of epochs, dataset size).

For completeness, I have also repeated the usual exercise with jit_compile=True. In this configuration, Keras 3 outperforms Keras 2 in step time on each of the tested devices. Not too surprisingly, this result emerges only when looking at the step time, since the total time points in the opposite direction. As usual, all the details are reported in the following table:

| CPU model | GPU model | Average step time* on TF 2.14 | Total time on TF 2.14 | Average step time* on TF 2.16 | Total time on TF 2.16 |
| --- | --- | --- | --- | --- | --- |
| AMD EPYC 7282 | - | 12.3829 ms | 112.9680 s | 11.9046 ms | 113.3379 s |
| Intel Xeon Gold 5218 | Quadro RTX 5000 | 1.9773 ms | 26.9512 s | 1.6777 ms | 32.8582 s |
| AMD EPYC 7513 | NVIDIA A100 80GB | 1.9033 ms | 25.7018 s | 1.5725 ms | 32.1113 s |

*the average step time is measured excluding the first epoch for a more robust performance evaluation

@fchollet (Collaborator) commented Jul 7, 2024

One more thing: you should pick the right backend.

If you care a lot about per-step overhead, you can use the torch backend, which has minimal overhead (but typically executes more slowly). For large workloads, JAX is often the fastest backend.
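
As a minimal sketch, the backend is selected via an environment variable before keras is imported; the choice of "jax" here is just an example, and the corresponding backend package must be installed:

import os
os.environ["KERAS_BACKEND"] = "jax"  # or "torch" / "tensorflow"

import keras  # the backend is fixed at import time

print(keras.backend.backend())  # confirm which backend is active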


This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

The github-actions bot added the stale label on Jul 25, 2024.

github-actions bot commented Aug 8, 2024

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

The github-actions bot closed this as completed on Aug 8, 2024.
