Skip to content

Commit

Permalink
Wrap Keras methods to support BatchEncoding (huggingface#28734)
Browse files Browse the repository at this point in the history
* Shim the Keras methods to support BatchEncoding

* Extract everything to a convert_batch_encoding function

* Convert BatchFeature too (thanks Amy)

* tf.keras -> keras
  • Loading branch information
Rocketknight1 authored Jan 31, 2024
1 parent 721e2d9 commit 7a49610
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import (
convert_batch_encoding,
expand_1d,
load_attributes_from_hdf5_group,
save_attributes_to_hdf5_group,
Expand Down Expand Up @@ -1155,6 +1156,36 @@ def __init__(self, config, *inputs, **kwargs):
def get_config(self):
    """Return the result of calling `to_dict()` on this model's `config` attribute."""
    model_config = self.config
    return model_config.to_dict()

@functools.wraps(keras.Model.fit)
def fit(self, *args, **kwargs):
    # NOTE: functools.wraps copies the docstring of keras.Model.fit onto this
    # method, so the behaviour is documented in comments rather than a docstring.
    #
    # Run the call arguments through convert_batch_encoding first, then hand the
    # (possibly rewritten) positional and keyword arguments to the parent `fit`.
    converted_args, converted_kwargs = convert_batch_encoding(*args, **kwargs)
    return super().fit(*converted_args, **converted_kwargs)

@functools.wraps(keras.Model.train_on_batch)
def train_on_batch(self, *args, **kwargs):
    # NOTE: functools.wraps copies the docstring of keras.Model.train_on_batch
    # onto this method, so the behaviour is documented in comments instead.
    #
    # Run the call arguments through convert_batch_encoding first, then hand the
    # (possibly rewritten) arguments to the parent `train_on_batch`.
    converted_args, converted_kwargs = convert_batch_encoding(*args, **kwargs)
    return super().train_on_batch(*converted_args, **converted_kwargs)

@functools.wraps(keras.Model.test_on_batch)
def test_on_batch(self, *args, **kwargs):
    # NOTE: functools.wraps copies the docstring of keras.Model.test_on_batch
    # onto this method, so the behaviour is documented in comments instead.
    #
    # Run the call arguments through convert_batch_encoding first, then hand the
    # (possibly rewritten) arguments to the parent `test_on_batch`.
    converted_args, converted_kwargs = convert_batch_encoding(*args, **kwargs)
    return super().test_on_batch(*converted_args, **converted_kwargs)

@functools.wraps(keras.Model.predict_on_batch)
def predict_on_batch(self, *args, **kwargs):
    # NOTE: functools.wraps copies the docstring of keras.Model.predict_on_batch
    # onto this method, so the behaviour is documented in comments instead.
    #
    # Run the call arguments through convert_batch_encoding first, then hand the
    # (possibly rewritten) arguments to the parent `predict_on_batch`.
    converted_args, converted_kwargs = convert_batch_encoding(*args, **kwargs)
    return super().predict_on_batch(*converted_args, **converted_kwargs)

@functools.wraps(keras.Model.predict)
def predict(self, *args, **kwargs):
    # NOTE: functools.wraps copies the docstring of keras.Model.predict onto this
    # method, so the behaviour is documented in comments rather than a docstring.
    #
    # Run the call arguments through convert_batch_encoding first, then hand the
    # (possibly rewritten) arguments to the parent `predict`.
    converted_args, converted_kwargs = convert_batch_encoding(*args, **kwargs)
    return super().predict(*converted_args, **converted_kwargs)

@functools.wraps(keras.Model.evaluate)
def evaluate(self, *args, **kwargs):
    # NOTE: functools.wraps copies the docstring of keras.Model.evaluate onto this
    # method, so the behaviour is documented in comments rather than a docstring.
    #
    # Run the call arguments through convert_batch_encoding first, then hand the
    # (possibly rewritten) arguments to the parent `evaluate`.
    converted_args, converted_kwargs = convert_batch_encoding(*args, **kwargs)
    return super().evaluate(*converted_args, **converted_kwargs)

@classmethod
def from_config(cls, config, **kwargs):
if isinstance(config, PretrainedConfig):
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import numpy as np
import tensorflow as tf

from .feature_extraction_utils import BatchFeature
from .tokenization_utils_base import BatchEncoding
from .utils import logging


Expand Down Expand Up @@ -253,3 +255,13 @@ def _expand_single_1d_tensor(t):
return t

return tf.nest.map_structure(_expand_single_1d_tensor, data)


def convert_batch_encoding(*args, **kwargs):
    """
    Replace a leading `BatchEncoding`/`BatchFeature` input with a plain `dict`.

    Only one location is converted per call: the first positional argument if it
    is a `BatchEncoding` or `BatchFeature`, otherwise the `x` keyword argument if
    it is one. Every other argument is passed through untouched.

    Returns:
        A `(args, kwargs)` pair. `args` is a list when the first positional
        argument was converted, and the original tuple otherwise.
    """

    def _is_hf_batch(candidate):
        # Keras does not accept these HF containers directly, hence the conversion.
        return isinstance(candidate, (BatchEncoding, BatchFeature))

    if len(args) > 0 and _is_hf_batch(args[0]):
        # Positional input takes precedence; the `x` keyword is not inspected here.
        head, *tail = args
        return [dict(head), *tail], kwargs

    if "x" in kwargs and _is_hf_batch(kwargs["x"]):
        kwargs["x"] = dict(kwargs["x"])
    return args, kwargs

0 comments on commit 7a49610

Please sign in to comment.