Skip to content

Commit

Permalink
Merge pull request #99 from serengil/feat-task-3103-tf-216
Browse files Browse the repository at this point in the history
make retinaface compatible with tf2.16 and later
  • Loading branch information
serengil authored Mar 31, 2024
2 parents 059fcdf + 175fdb8 commit a4fd2b8
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 2 deletions.
2 changes: 1 addition & 1 deletion package_info.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.0.15"
"version": "0.0.16"
}
8 changes: 8 additions & 0 deletions retinaface/RetinaFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
import logging
from typing import Union, Any, Optional, Dict

# this has to be set before importing tf
os.environ["TF_USE_LEGACY_KERAS"] = "1"

# pylint: disable=wrong-import-position
import numpy as np
import tensorflow as tf

from retinaface import __version__
from retinaface.model import retinaface_model
from retinaface.commons import preprocess, postprocess
from retinaface.commons.logger import Logger
from retinaface.commons import package_utils

# users should install tf_keras package if they are using tf 2.16 or later versions
package_utils.validate_for_keras3()

logger = Logger(module="retinaface/RetinaFace.py")

Expand Down
2 changes: 1 addition & 1 deletion retinaface/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.15"
__version__ = "0.0.16"
28 changes: 28 additions & 0 deletions retinaface/commons/package_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# 3rd party dependencies
import tensorflow as tf

# project dependencies
from retinaface.commons.logger import Logger

logger = Logger(module="retinaface/commons/package_utils.py")


def validate_for_keras3():
tf_major = int(tf.__version__.split(".", maxsplit=1)[0])
tf_minor = int(tf.__version__.split(".", maxsplit=-1)[1])

# tf_keras is a must dependency after tf 2.16
if tf_major == 1 or (tf_major == 2 and tf_minor < 16):
return

try:
import tf_keras

logger.debug(f"tf_keras is already available - {tf_keras.__version__}")
except ImportError as err:
# you may consider to install that package here
raise ValueError(
f"You have tensorflow {tf.__version__} and this requires "
"tf-keras package. Please run `pip install tf-keras` "
"or downgrade your tensorflow."
) from err

0 comments on commit a4fd2b8

Please sign in to comment.