Enabling JAX as backend for the GAN training step #8
Labels: enhancement, help wanted, python
Starting from the v0.2.0 release, PIDGAN is compatible with the new multi-backend Keras 3.
At the moment, GAN models can only be trained with the TensorFlow backend. For example, lines 173-183 of the Keras 3-based GAN class contain the following:
The goal of this issue is to implement `train_step()` for the JAX backend as well. In addition to the "plain" training step, the Lipschitz regularization functions should also be adapted to run on the JAX backend.