Skip to content

Commit

Permalink
Set env var to hold Keras at Keras 2 (huggingface#29598)
Browse files Browse the repository at this point in the history
* Set env var to hold Keras at Keras 2

* Add Amy's update

* make fixup

* Use a warning instead
  • Loading branch information
Rocketknight1 authored Mar 12, 2024
1 parent b640486 commit df15425
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@
if TYPE_CHECKING:
from . import PreTrainedTokenizerBase

logger = logging.get_logger(__name__)

if "TF_USE_LEGACY_KERAS" not in os.environ:
os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2
elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
logger.warning(
"Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
"This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models."
)

try:
import tf_keras as keras
from tf_keras import backend as K
Expand All @@ -93,7 +103,6 @@
)


logger = logging.get_logger(__name__)
tf_logger = tf.get_logger()

TFModelInputType = Union[
Expand Down

0 comments on commit df15425

Please sign in to comment.