Skip to content

Commit

Permalink
fix: Remove check on keras version
Browse files Browse the repository at this point in the history
* To support Python 3.7+, require tensorflow v2.5.0+ in the extras, which will
  correctly manage compatible versions of keras.

* The check on the version of keras was required to deal with issues between
  Keras v2.2.4 and v2.2.5. As the minimum required Keras version for
  tensorflow v2.5.0 is keras v2.5.0.dev then all Keras versions installed
  through the EnergyFlow extras will be newer than v2.2.5 and so the
  guard check is no longer necessary.

   - c.f. https://github.com/tensorflow/tensorflow/blob/a4dfb8d1a71385bd6d122e4f27f86dcebb96712d/tensorflow/tools/pip_package/setup.py#L107
  • Loading branch information
matthewfeickert committed Mar 21, 2024
1 parent 8751f78 commit 7b07c0e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 19 deletions.
19 changes: 5 additions & 14 deletions energyflow/archs/efn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import numpy as np

import tensorflow.keras.backend as K
from keras import __version__ as __keras_version__
from tensorflow.keras.layers import Concatenate, Dense, Dot, Dropout, Input, Lambda, TimeDistributed
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
Expand All @@ -33,20 +32,13 @@
# weight mask constructor functions
#'construct_efn_weight_mask', 'construct_pfn_weight_mask',

# network consstructor functions
#'construct_distributed_dense', 'construct_latent', 'construct_dense',
# network constructor functions
#'construct_distributed_dense', 'construct_latent', 'construct_dense',

# full model classes
'EFN', 'PFN'
]

################################################################################
# Keras 2.2.5 fixes bug in 2.2.4 that affects our usage of the Dot layer
################################################################################

if __keras_version__.endswith('-tf'):
__keras_version__ = __keras_version__[:-3]
keras_version_tuple = tuple(map(int, __keras_version__.split('.')))
DOT_AXIS = 1

################################################################################
Expand Down Expand Up @@ -533,20 +525,19 @@ def eval_filters(self, patch, n=100, prune=True):
XY = np.asarray([X, Y]).reshape((1, 2, nx*ny)).transpose((0, 2, 1))

# handle weirdness of Keras/tensorflow
old_keras = (keras_version_tuple <= (2, 2, 5))
s = self.Phi_sizes[-1] if len(self.Phi_sizes) else self.input_dim
in_t, out_t = self.inputs[1], self._tensors[self._tensor_inds['latent'][0] - 1]

# construct function
kf = K.function([in_t] if old_keras else in_t, [out_t] if old_keras else out_t)
kf = K.function(in_t, out_t)

# evaluate function
Z = kf([XY] if old_keras else XY)[0].reshape(nx, ny, s).transpose((2, 0, 1))
Z = kf(XY)[0].reshape(nx, ny, s).transpose((2, 0, 1))

# prune filters that are off
if prune:
return X, Y, Z[[not (z == 0).all() for z in Z]]

return X, Y, Z


Expand Down
10 changes: 5 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ long_description = file: README.md
long_description_content_type = text/markdown

url = https://energyflow.network
project_urls =
project_urls =
Source Code = https://github.com/pkomiske/EnergyFlow
Issues = https://github.com/pkomiske/EnergyFlow/issues

Expand Down Expand Up @@ -98,25 +98,25 @@ generation =
python-igraph

examples =
tensorflow > 2.0.0
tensorflow >= 2.5.0
scikit-learn
matplotlib

archs =
tensorflow > 2.0.0
tensorflow >= 2.5.0
scikit-learn

tests =
pot >= 0.8.0
pytest
python-igraph
python-igraph == 0.8.3; python_version=='2.7'
tensorflow > 2.0.0
tensorflow >= 2.5.0
scikit-learn

all =
python-igraph
tensorflow > 2.0.0
tensorflow >= 2.5.0
scikit-learn

[bdist_wheel]
Expand Down

0 comments on commit 7b07c0e

Please sign in to comment.