From 5aef5ad4d9775b979e540c075bc146369c8adc06 Mon Sep 17 00:00:00 2001
From: brunneis
Date: Mon, 31 Jul 2023 18:30:41 +0000
Subject: [PATCH] Add support for Hugging Face Transformers 4.x, upgrade libraries.

---
 ernie/.flake8                   |   3 +
 ernie/.style.yapf               |   9 ++
 ernie/__init__.py               |   4 +-
 ernie/aggregation_strategies.py |  33 ++----
 ernie/ernie.py                  | 183 +++++++++++++++++---------
 ernie/helper.py                 |  24 +++--
 ernie/models.py                 |  41 +++----
 ernie/split_strategies.py       |  39 +++----
 examples/binary_classifier.py   |  22 ++--
 requirements.txt                |  10 +-
 setup.py                        |  14 +--
 test/load_csv.py                |   6 +-
 test/load_model.py              |  15 ++-
 test/predict.py                 |   2 +-
 test/split_aggregate.py         |  22 ++--
 15 files changed, 233 insertions(+), 194 deletions(-)
 create mode 100644 ernie/.flake8
 create mode 100644 ernie/.style.yapf

diff --git a/ernie/.flake8 b/ernie/.flake8
new file mode 100644
index 0000000..a067e33
--- /dev/null
+++ b/ernie/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 79
+ignore = W503
diff --git a/ernie/.style.yapf b/ernie/.style.yapf
new file mode 100644
index 0000000..74fcbf6
--- /dev/null
+++ b/ernie/.style.yapf
@@ -0,0 +1,9 @@
+[style]
+based_on_style=pep8
+column_limit=79
+split_before_arithmetic_operator=true
+split_before_logical_operator=true
+split_before_named_assigns=true
+split_before_first_argument=true
+allow_split_before_dict_value=false
+dedent_closing_brackets=true
diff --git a/ernie/__init__.py b/ernie/__init__.py
index a2a2380..3344544 100644
--- a/ernie/__init__.py
+++ b/ernie/__init__.py
@@ -5,7 +5,7 @@
 from tensorflow.python.client import device_lib
 import logging
 
-__version__ = '1.0.1'
+__version__ = '1.2307.0'
 
 logging.getLogger().setLevel(logging.WARNING)
 logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
@@ -18,7 +18,7 @@ def _get_cpu_name():
     import cpuinfo
     cpu_info = cpuinfo.get_cpu_info()
-    cpu_name = f"{cpu_info['brand']}, {cpu_info['count']} vCores"
+    cpu_name = f"{cpu_info['brand_raw']}, {cpu_info['count']} vCores"
     return cpu_name
 
 
diff --git a/ernie/aggregation_strategies.py b/ernie/aggregation_strategies.py
index d6bb578..1360f95 100644
--- a/ernie/aggregation_strategies.py
+++ b/ernie/aggregation_strategies.py
@@ -6,11 +6,7 @@
 
 class AggregationStrategy:
     def __init__(
-        self,
-        method,
-        max_items=None,
-        top_items=True,
-        sorting_class_index=1
+        self, method, max_items=None, top_items=True, sorting_class_index=1
     ):
         self.method = method
         self.max_items = max_items
@@ -36,8 +32,11 @@ def aggregate(self, softmax_tuples):
 
         softmax_list = []
         for key in softmax_dicts[0].keys():
-            softmax_list.append(self.method(
-                [probabilities[key] for probabilities in softmax_dicts]))
+            softmax_list.append(
+                self.method(
+                    [probabilities[key] for probabilities in softmax_dicts]
+                )
+            )
         softmax_tuple = tuple(softmax_list)
         return softmax_tuple
@@ -45,26 +44,14 @@
 class AggregationStrategies:
     Mean = AggregationStrategy(method=mean)
     MeanTopFiveBinaryClassification = AggregationStrategy(
-        method=mean,
-        max_items=5,
-        top_items=True,
-        sorting_class_index=1
+        method=mean, max_items=5, top_items=True, sorting_class_index=1
     )
     MeanTopTenBinaryClassification = AggregationStrategy(
-        method=mean,
-        max_items=10,
-        top_items=True,
-        sorting_class_index=1
+        method=mean, max_items=10, top_items=True, sorting_class_index=1
     )
     MeanTopFifteenBinaryClassification = AggregationStrategy(
-        method=mean,
-        max_items=15,
-        top_items=True,
-        sorting_class_index=1
+        method=mean, max_items=15, top_items=True, sorting_class_index=1
     )
     MeanTopTwentyBinaryClassification = AggregationStrategy(
-        method=mean,
-        max_items=20,
-        top_items=True,
-        sorting_class_index=1
+        method=mean, max_items=20, top_items=True, sorting_class_index=1
     )
diff --git a/ernie/ernie.py b/ernie/ernie.py
index a61c01d..e4a7923 100644
--- a/ernie/ernie.py
+++ b/ernie/ernie.py
@@ -23,13 +23,7 @@
     AggregationStrategy,
     AggregationStrategies
 )
-from .helper import (
-    get_features,
-    softmax,
-    remove_dir,
-    make_dir,
-    copy_dir
-)
+from .helper import (get_features, softmax, remove_dir, make_dir, copy_dir)
 
 AUTOSAVE_PATH = './ernie-autosave/'
 
@@ -39,13 +33,15 @@ def clean_autosave():
 
 
 class SentenceClassifier:
-    def __init__(self,
-                 model_name=Models.BertBaseUncased,
-                 model_path=None,
-                 max_length=64,
-                 labels_no=2,
-                 tokenizer_kwargs=None,
-                 model_kwargs=None):
+    def __init__(
+        self,
+        model_name=Models.BertBaseUncased,
+        model_path=None,
+        max_length=64,
+        labels_no=2,
+        tokenizer_kwargs=None,
+        model_kwargs=None,
+    ):
         self._loaded_data = False
         self._model_path = None
 
@@ -60,7 +56,11 @@ def __init__(self,
         if model_path is not None:
             self._load_local_model(model_path)
         else:
-            self._load_remote_model(model_name, tokenizer_kwargs, model_kwargs)
+            self._load_remote_model(
+                model_name,
+                tokenizer_kwargs,
+                model_kwargs,
+            )
 
     @property
     def model(self):
@@ -70,13 +70,15 @@ def tokenizer(self):
         return self._tokenizer
 
-    def load_dataset(self,
-                     dataframe=None,
-                     validation_split=0.1,
-                     random_state=None,
-                     stratify=None,
-                     csv_path=None,
-                     read_csv_kwargs=None):
+    def load_dataset(
+        self,
+        dataframe=None,
+        validation_split=0.1,
+        random_state=None,
+        stratify=None,
+        csv_path=None,
+        read_csv_kwargs=None,
+    ):
         if dataframe is None and csv_path is None:
             raise ValueError
 
@@ -88,9 +90,7 @@ def load_dataset(self,
         labels = dataframe[dataframe.columns[1]].values
 
         (
-            training_sentences,
-            validation_sentences,
-            training_labels,
+            training_sentences, validation_sentences, training_labels,
             validation_labels
         ) = train_test_split(
             sentences,
@@ -102,14 +102,13 @@ def load_dataset(self,
         )
 
         self._training_features = get_features(
-            self._tokenizer, training_sentences, training_labels)
+            self._tokenizer, training_sentences, training_labels
+        )
 
         self._training_size = len(training_sentences)
 
         self._validation_features = get_features(
-            self._tokenizer,
-            validation_sentences,
-            validation_labels
+            self._tokenizer, validation_sentences, validation_labels
         )
         self._validation_split = len(validation_sentences)
 
@@ -118,20 +117,22 @@ def load_dataset(self,
 
         self._loaded_data = True
 
-    def fine_tune(self,
-                  epochs=4,
-                  learning_rate=2e-5,
-                  epsilon=1e-8,
-                  clipnorm=1.0,
-                  optimizer_function=keras.optimizers.Adam,
-                  optimizer_kwargs=None,
-                  loss_function=keras.losses.SparseCategoricalCrossentropy,
-                  loss_kwargs=None,
-                  accuracy_function=keras.metrics.SparseCategoricalAccuracy,
-                  accuracy_kwargs=None,
-                  training_batch_size=32,
-                  validation_batch_size=64,
-                  **kwargs):
+    def fine_tune(
+        self,
+        epochs=4,
+        learning_rate=2e-5,
+        epsilon=1e-8,
+        clipnorm=1.0,
+        optimizer_function=keras.optimizers.Adam,
+        optimizer_kwargs=None,
+        loss_function=keras.losses.SparseCategoricalCrossentropy,
+        loss_kwargs=None,
+        accuracy_function=keras.metrics.SparseCategoricalAccuracy,
+        accuracy_kwargs=None,
+        training_batch_size=32,
+        validation_batch_size=64,
+        **kwargs,
+    ):
         if not self._loaded_data:
             raise Exception('Data has not been loaded.')
 
@@ -154,9 +155,11 @@ def fine_tune(self,
         self._model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])
 
         training_features = self._training_features.shuffle(
-            self._training_size).batch(training_batch_size).repeat(-1)
+            self._training_size
+        ).batch(training_batch_size).repeat(-1)
         validation_features = self._validation_features.batch(
-            validation_batch_size)
+            validation_batch_size
+        )
 
         training_steps = self._training_size // training_batch_size
         if training_steps == 0:
@@ -169,12 +172,14 @@ def fine_tune(self,
         logging.info(f'validation_steps: {validation_steps}')
 
         for i in range(epochs):
-            self._model.fit(training_features,
-                            epochs=1,
-                            validation_data=validation_features,
-                            steps_per_epoch=training_steps,
-                            validation_steps=validation_steps,
-                            **kwargs)
+            self._model.fit(
+                training_features,
+                epochs=1,
+                validation_data=validation_features,
+                steps_per_epoch=training_steps,
+                validation_steps=validation_steps,
+                **kwargs
+            )
 
         # The fine-tuned model does not have the same input interface
         # after being exported and loaded again.
@@ -184,20 +189,23 @@ def predict_one(
         self,
         text,
         split_strategy=None,
-        aggregation_strategy=None
+        aggregation_strategy=None,
     ):
         return next(
-            self.predict([text],
-                         batch_size=1,
-                         split_strategy=split_strategy,
-                         aggregation_strategy=aggregation_strategy))
+            self.predict(
+                [text],
+                batch_size=1,
+                split_strategy=split_strategy,
+                aggregation_strategy=aggregation_strategy,
+            )
+        )
 
     def predict(
         self,
         texts,
         batch_size=32,
         split_strategy=None,
-        aggregation_strategy=None
+        aggregation_strategy=None,
     ):
         if split_strategy is None:
             yield from self._predict_batch(texts, batch_size)
@@ -235,7 +243,11 @@ def _dump(self, path):
         self._tokenizer.save_pretrained(path + '/tokenizer')
         self._config.save_pretrained(path + '/tokenizer')
 
-    def _predict_batch(self, sentences: list, batch_size: int):
+    def _predict_batch(
+        self,
+        sentences: list,
+        batch_size: int,
+    ):
         sentences_number = len(sentences)
         if batch_size > sentences_number:
             batch_size = sentences_number
@@ -252,17 +264,17 @@ def _predict_batch(self, sentences: list, batch_size: int):
                 features = self._tokenizer.encode_plus(
                     sentences[j],
                     add_special_tokens=True,
-                    max_length=self._tokenizer.max_len
+                    max_length=self._tokenizer.model_max_length,
                 )
                 input_ids, _, attention_mask = (
-                    features['input_ids'],
-                    features['token_type_ids'],
+                    features['input_ids'], features['token_type_ids'],
                     features['attention_mask']
                 )
 
                 input_ids = self._list_to_padded_array(features['input_ids'])
                 attention_mask = self._list_to_padded_array(
-                    features['attention_mask'])
+                    features['attention_mask']
+                )
                 input_ids_list.append(input_ids)
                 attention_mask_list.append(attention_mask)
@@ -273,13 +285,15 @@ def _predict_batch(self, sentences: list, batch_size: int):
             }
             logit_predictions = self._model.predict_on_batch(input_dict)
             yield from (
-                [softmax(logit_prediction)
-                 for logit_prediction in logit_predictions[0]]
+                [
+                    softmax(logit_prediction)
+                    for logit_prediction in logit_predictions[0]
+                ]
             )
 
     def _list_to_padded_array(self, items):
         array = np.array(items)
-        padded_array = np.zeros(self._tokenizer.max_len, dtype=np.int)
+        padded_array = np.zeros(self._tokenizer.model_max_length, dtype=np.int)
         padded_array[:array.shape[0]] = array
         return padded_array
@@ -288,38 +302,47 @@ def _get_temporary_path(self, name=''):
 
     def _reload_model(self):
         self._model_path = self._get_temporary_path(
-            name=self._get_model_family())
+            name=self._get_model_family()
+        )
         self._dump(self._model_path)
         self._load_local_model(self._model_path)
 
     def _load_local_model(self, model_path):
         try:
             self._tokenizer = AutoTokenizer.from_pretrained(
-                model_path + '/tokenizer')
+                model_path + '/tokenizer'
+            )
             self._config = AutoConfig.from_pretrained(
-                model_path + '/tokenizer')
+                model_path + '/tokenizer'
+            )
 
         # Old models didn't use to have a tokenizer folder
         except OSError:
             self._tokenizer = AutoTokenizer.from_pretrained(model_path)
             self._config = AutoConfig.from_pretrained(model_path)
 
         self._model = TFAutoModelForSequenceClassification.from_pretrained(
-            model_path,
-            from_pt=False
+            model_path, from_pt=False
         )
 
     def _get_model_family(self):
         model_family = ''.join(self._model.name[2:].split('_')[:2])
         return model_family
 
-    def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
+    def _load_remote_model(
+        self,
+        model_name,
+        tokenizer_kwargs,
+        model_kwargs,
+    ):
         do_lower_case = False
         if 'uncased' in model_name.lower():
             do_lower_case = True
         tokenizer_kwargs.update({'do_lower_case': do_lower_case})
 
         self._tokenizer = AutoTokenizer.from_pretrained(
-            model_name, **tokenizer_kwargs)
+            model_name,
+            **tokenizer_kwargs,
+        )
         self._config = AutoConfig.from_pretrained(model_name)
 
         temporary_path = self._get_temporary_path()
@@ -329,11 +352,11 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
         try:
             self._model = TFAutoModelForSequenceClassification.from_pretrained(
                 model_name,
-                from_pt=False
+                from_pt=False,
             )
 
         # PyTorch model
-        except TypeError:
+        except OSError:
             try:
                 self._model = \
                     TFAutoModelForSequenceClassification.from_pretrained(
@@ -368,9 +391,7 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
                 getattr(self._model, self._get_model_family()
                         ).save_pretrained(temporary_path)
                 self._model = self._model.__class__.from_pretrained(
-                    temporary_path,
-                    from_pt=False,
-                    **model_kwargs
+                    temporary_path, from_pt=False, **model_kwargs
                 )
 
         # The model is itself the main layer
@@ -378,9 +399,7 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
             # TensorFlow model
             try:
                 self._model = self._model.__class__.from_pretrained(
-                    model_name,
-                    from_pt=False,
-                    **model_kwargs
+                    model_name, from_pt=False, **model_kwargs
                 )
 
             # PyTorch Model
@@ -388,9 +407,7 @@ def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
                 model = AutoModel.from_pretrained(model_name)
                 model.save_pretrained(temporary_path)
                 self._model = self._model.__class__.from_pretrained(
-                    temporary_path,
-                    from_pt=True,
-                    **model_kwargs
+                    temporary_path, from_pt=True, **model_kwargs
                 )
 
         remove_dir(temporary_path)
diff --git a/ernie/helper.py b/ernie/helper.py
index aa2483c..754aaf9 100644
--- a/ernie/helper.py
+++ b/ernie/helper.py
@@ -13,11 +13,13 @@ def get_features(tokenizer, sentences, labels):
         inputs = tokenizer.encode_plus(
             sentence,
             add_special_tokens=True,
-            max_length=tokenizer.max_len
+            max_length=tokenizer.model_max_length,
         )
-        input_ids, token_type_ids = \
-            inputs['input_ids'], inputs['token_type_ids']
-        padding_length = tokenizer.max_len - len(input_ids)
+        input_ids, token_type_ids = (
+            inputs['input_ids'],
+            inputs['token_type_ids'],
+        )
+        padding_length = tokenizer.model_max_length - len(input_ids)
 
         if tokenizer.padding_side == 'right':
             attention_mask = [1] * len(input_ids) + [0] * padding_length
@@ -30,7 +32,7 @@
             token_type_ids = \
                 [tokenizer.pad_token_type_id] * padding_length + token_type_ids
 
-        assert tokenizer.max_len \
+        assert tokenizer.model_max_length \
             == len(attention_mask) \
             == len(input_ids) \
             == len(token_type_ids)
@@ -57,11 +59,13 @@ def gen():
 
     dataset = data.Dataset.from_generator(
         gen,
-        ({
-            'input_ids': int32,
-            'attention_mask': int32,
-            'token_type_ids': int32
-        }, int64),
+        (
+            {
+                'input_ids': int32,
+                'attention_mask': int32,
+                'token_type_ids': int32
+            }, int64
+        ),
         (
             {
                 'input_ids': TensorShape([None]),
diff --git a/ernie/models.py b/ernie/models.py
index 81596b8..e59b996 100644
--- a/ernie/models.py
+++ b/ernie/models.py
@@ -29,23 +29,28 @@ class Models:
 
 
 class ModelsByFamily:
-    Bert = set([Models.BertBaseUncased, Models.BertBaseCased,
-                Models.BertLargeUncased, Models.BertLargeCased])
+    Bert = set(
+        [
+            Models.BertBaseUncased, Models.BertBaseCased,
+            Models.BertLargeUncased, Models.BertLargeCased
+        ]
+    )
     Roberta = set([Models.RobertaBaseCased, Models.RobertaLargeCased])
     XLNet = set([Models.XLNetBaseCased, Models.XLNetLargeCased])
-    DistilBert = set([Models.DistilBertBaseUncased,
-                      Models.DistilBertBaseMultilingualCased])
-    Albert = set([
-        Models.AlbertBaseCased,
-        Models.AlbertLargeCased,
-        Models.AlbertXLargeCased,
-        Models.AlbertXXLargeCased,
-        Models.AlbertBaseCased2,
-        Models.AlbertLargeCased2,
-        Models.AlbertXLargeCased2,
-        Models.AlbertXXLargeCased2
-    ])
-    Supported = set([
-        getattr(Models, model_type) for model_type
-        in filter(lambda x: x[:2] != '__', Models.__dict__.keys())
-    ])
+    DistilBert = set(
+        [Models.DistilBertBaseUncased, Models.DistilBertBaseMultilingualCased]
+    )
+    Albert = set(
+        [
+            Models.AlbertBaseCased, Models.AlbertLargeCased,
+            Models.AlbertXLargeCased, Models.AlbertXXLargeCased,
+            Models.AlbertBaseCased2, Models.AlbertLargeCased2,
+            Models.AlbertXLargeCased2, Models.AlbertXXLargeCased2
+        ]
+    )
+    Supported = set(
+        [
+            getattr(Models, model_type) for model_type in
+            filter(lambda x: x[:2] != '__', Models.__dict__.keys())
+        ]
+    )
diff --git a/ernie/split_strategies.py b/ernie/split_strategies.py
index 4a0b47f..5c88b16 100644
--- a/ernie/split_strategies.py
+++ b/ernie/split_strategies.py
@@ -50,7 +50,7 @@ def len_in_tokens(text_):
             return no_tokens
 
         no_special_tokens = len(tokenizer.encode('', add_special_tokens=True))
-        max_tokens = tokenizer.max_len - no_special_tokens
+        max_tokens = tokenizer.model_max_length - no_special_tokens
 
         if self.remove_patterns is not None:
             for remove_pattern in self.remove_patterns:
@@ -67,7 +67,8 @@ def len_in_tokens(text_):
             if len_in_tokens(split) > max_tokens:
                 if len(split_patterns) > 1:
                     sub_splits = self.split(
-                        split, tokenizer, split_patterns[1:])
+                        split, tokenizer, split_patterns[1:]
+                    )
                     selected_splits.extend(sub_splits)
                 else:
                     selected_splits.append(split)
@@ -94,8 +95,8 @@ def len_in_tokens(text_):
         if not remove_too_short_groups:
             final_splits = selected_splits
         else:
-            final_splits = []
-            min_length = tokenizer.max_len / 2
+            final_splits = []
+            min_length = tokenizer.model_max_length / 2
             for split in selected_splits:
                 if len_in_tokens(split) >= min_length:
                     final_splits.append(split)
@@ -104,22 +105,22 @@
 
 
 class SplitStrategies:
-    SentencesWithoutUrls = SplitStrategy(split_patterns=[
-        RegexExpressions.split_by_dot,
-        RegexExpressions.split_by_semicolon,
-        RegexExpressions.split_by_colon,
-        RegexExpressions.split_by_comma
-    ],
+    SentencesWithoutUrls = SplitStrategy(
+        split_patterns=[
+            RegexExpressions.split_by_dot, RegexExpressions.split_by_semicolon,
+            RegexExpressions.split_by_colon, RegexExpressions.split_by_comma
+        ],
         remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
         remove_too_short_groups=False,
-        group_splits=False)
-
-    GroupedSentencesWithoutUrls = SplitStrategy(split_patterns=[
-        RegexExpressions.split_by_dot,
-        RegexExpressions.split_by_semicolon,
-        RegexExpressions.split_by_colon,
-        RegexExpressions.split_by_comma
-    ],
+        group_splits=False
+    )
+
+    GroupedSentencesWithoutUrls = SplitStrategy(
+        split_patterns=[
+            RegexExpressions.split_by_dot, RegexExpressions.split_by_semicolon,
+            RegexExpressions.split_by_colon, RegexExpressions.split_by_comma
+        ],
         remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
         remove_too_short_groups=True,
-        group_splits=True)
+        group_splits=True
+    )
diff --git a/examples/binary_classifier.py b/examples/binary_classifier.py
index a3fe821..4809e48 100644
--- a/examples/binary_classifier.py
+++ b/examples/binary_classifier.py
@@ -6,20 +6,26 @@
 tuples = [
     ("This is a positive example. I'm very happy today.", 1),
-    ("This is a negative sentence. Everything was wrong today at work.", 0)
+    ("This is a negative sentence. Everything was wrong today at work.", 0),
 ]
 df = pd.DataFrame(tuples)
 
 classifier = SentenceClassifier(
-    model_name=Models.BertBaseUncased, max_length=128, labels_no=2)
+    model_name=Models.BertBaseUncased,
+    max_length=128,
+    labels_no=2,
+)
+
 classifier.load_dataset(df, validation_split=0.2)
-classifier.fine_tune(epochs=4, learning_rate=2e-5,
-                     training_batch_size=32, validation_batch_size=64)
+classifier.fine_tune(
+    epochs=4,
+    learning_rate=2e-5,
+    training_batch_size=32,
+    validation_batch_size=64,
+)
 
 sentence = "Oh, that's great!"
 probability = classifier.predict_one(sentence)[1]
-print(
-    f"\"{sentence}\": {probability} "
-    f"[{'positive' if probability >= 0.5 else 'negative'}]"
-)
+print(f"\"{sentence}\": {probability} "
+      f"[{'positive' if probability >= 0.5 else 'negative'}]")
diff --git a/requirements.txt b/requirements.txt
index efa724d..d15951e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-transformers>=2.4.1, < 2.5.0
-scikit-learn>=0.22.1, < 1.0.0
-pandas>=0.25.3, < 1.0.0
-tensorflow>=2.5.1, < 2.6.0
-py-cpuinfo>=5.0.0, < 6.0.0
+transformers>=4.24.0, <5.0.0
+scikit-learn>=1.2.1, <2.0.0
+pandas>=1.5.3, <2.0.0
+tensorflow>=2.5.1, <2.11.0
+py-cpuinfo>=9.0.0, <10.0.0
diff --git a/setup.py b/setup.py
index e20f825..a5ffaaa 100755
--- a/setup.py
+++ b/setup.py
@@ -12,9 +12,9 @@
         'An Accessible Python Library for State-of-the-art '
         'Natural Language Processing. Built with HuggingFace\'s Transformers.'
     ),
-    url='https://github.com/brunneis/ernie',
+    url='https://github.com/labteral/ernie',
     author='Rodrigo Martínez Castaño',
-    author_email='rodrigo@martinez.gal',
+    author_email='dev@brunneis.com',
     license='Apache License (Version 2.0)',
     packages=find_packages(),
     zip_safe=False,
@@ -30,9 +30,9 @@
     ],
     python_requires=">=3.6",
     install_requires=[
-        'transformers>=2.4.1, < 2.5.0',
-        'scikit-learn>=0.22.1, < 1.0.0',
-        'pandas>=0.25.3, < 1.0.0',
-        'tensorflow>=2.5.1, < 2.6.0',
-        'py-cpuinfo>=5.0.0, < 6.0.0'
+        'transformers>=4.24.0, <5.0.0',
+        'scikit-learn>=1.2.1, <2.0.0',
+        'pandas>=1.5.3, <2.0.0',
+        'tensorflow>=2.5.1, <2.11.0',
+        'py-cpuinfo>=9.0.0, <10.0.0',
     ])
diff --git a/test/load_csv.py b/test/load_csv.py
index 2b12165..00a7c81 100644
--- a/test/load_csv.py
+++ b/test/load_csv.py
@@ -9,18 +9,18 @@ class TestLoadCsv(unittest.TestCase):
     classifier = SentenceClassifier(
         model_name=Models.BertBaseUncased,
         max_length=128,
-        labels_no=2
+        labels_no=2,
     )
     classifier.load_dataset(
         validation_split=0.2,
         csv_path="example.csv",
-        read_csv_kwargs={"header": None}
+        read_csv_kwargs={"header": None},
     )
     classifier.fine_tune(
         epochs=4,
         learning_rate=2e-5,
         training_batch_size=32,
-        validation_batch_size=64
+        validation_batch_size=64,
     )
 
     def test_predict(self):
diff --git a/test/load_model.py b/test/load_model.py
index 19851d5..6449425 100644
--- a/test/load_model.py
+++ b/test/load_model.py
@@ -10,15 +10,22 @@ class TestLoadModel(unittest.TestCase):
     tuples = [
         ("This is a negative sentence. Everything was wrong today.", 0),
         ("This is a positive example. I'm very happy today.", 1),
-        ("This is a neutral sentence. That's normal.", 2)
+        ("This is a neutral sentence. That's normal.", 2),
     ]
     df = pd.DataFrame(tuples)
 
     classifier = SentenceClassifier(
-        model_name='xlm-roberta-large', max_length=128, labels_no=3)
+        model_name='xlm-roberta-large',
+        max_length=128,
+        labels_no=3,
+    )
     classifier.load_dataset(df, validation_split=0.2)
-    classifier.fine_tune(epochs=4, learning_rate=2e-5,
-                         training_batch_size=32, validation_batch_size=64)
+    classifier.fine_tune(
+        epochs=4,
+        learning_rate=2e-5,
+        training_batch_size=32,
+        validation_batch_size=64,
+    )
 
     def test_predict(self):
         text = "Oh, that's great!"
diff --git a/test/predict.py b/test/predict.py
index b8b2034..90d88fe 100644
--- a/test/predict.py
+++ b/test/predict.py
@@ -9,7 +9,7 @@ class TestPredict(unittest.TestCase):
     classifier = SentenceClassifier(
         model_name=Models.BertBaseUncased,
         max_length=128,
-        labels_no=2
+        labels_no=2,
     )
 
     def test_batch_predict(self):
diff --git a/test/split_aggregate.py b/test/split_aggregate.py
index db7a01c..5b70849 100644
--- a/test/split_aggregate.py
+++ b/test/split_aggregate.py
@@ -7,7 +7,7 @@
     Models,
     AggregationStrategy,
     SplitStrategy,
-    RegexExpressions
+    RegexExpressions,
 )
 from statistics import mean
 import logging
@@ -66,7 +66,7 @@ def test_split_groups(self):
                 RegexExpressions.split_by_dot,
                 RegexExpressions.split_by_semicolon,
                 RegexExpressions.split_by_colon,
-                RegexExpressions.split_by_comma
+                RegexExpressions.split_by_comma,
             ],
             remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
             remove_too_short_groups=False,
@@ -85,7 +85,7 @@ def test_split_groups(self):
         sentence = "0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62. 63."  # noqa: E501
         expected_sentences = [
             "0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62.",  # noqa: E501
-            "63."
+            "63.",
         ]
         sentences = splitter.split(sentence, self.classifier.tokenizer)
         self.assertEqual(sentences, expected_sentences)
@@ -94,7 +94,7 @@ def test_split_groups(self):
         sentence = "0; 1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11; 12; 13; 14; 15; 16; 17; 18; 19; 20; 21; 22; 23; 24; 25; 26; 27; 28; 29; 30; 31; 32; 33; 34; 35; 36; 37; 38; 39; 40; 41; 42; 43; 44; 45; 46; 47; 48; 49; 50; 51; 52; 53; 54; 55; 56; 57; 58; 59; 60; 61; 62; 63;"  # noqa: E501
         expected_sentences = [
             "0; 1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11; 12; 13; 14; 15; 16; 17; 18; 19; 20; 21; 22; 23; 24; 25; 26; 27; 28; 29; 30; 31; 32; 33; 34; 35; 36; 37; 38; 39; 40; 41; 42; 43; 44; 45; 46; 47; 48; 49; 50; 51; 52; 53; 54; 55; 56; 57; 58; 59; 60; 61; 62;",  # noqa: E501
-            "63;"
+            "63;",
         ]
         sentences = splitter.split(sentence, self.classifier.tokenizer)
         self.assertEqual(sentences, expected_sentences)
@@ -103,7 +103,7 @@ def test_split_groups(self):
         sentence = "0: 1: 2: 3: 4: 5: 6: 7: 8: 9: 10: 11: 12: 13: 14: 15: 16: 17: 18: 19: 20: 21: 22: 23: 24: 25: 26: 27: 28: 29: 30: 31: 32: 33: 34: 35: 36: 37: 38: 39: 40: 41: 42: 43: 44: 45: 46: 47: 48: 49: 50: 51: 52: 53: 54: 55: 56: 57: 58: 59: 60: 61: 62: 63: "  # noqa: E501
         expected_sentences = [
             "0: 1: 2: 3: 4: 5: 6: 7: 8: 9: 10: 11: 12: 13: 14: 15: 16: 17: 18: 19: 20: 21: 22: 23: 24: 25: 26: 27: 28: 29: 30: 31: 32: 33: 34: 35: 36: 37: 38: 39: 40: 41: 42: 43: 44: 45: 46: 47: 48: 49: 50: 51: 52: 53: 54: 55: 56: 57: 58: 59: 60: 61: 62:",  # noqa: E501
-            "63:"
+            "63:",
         ]
         sentences = splitter.split(sentence, self.classifier.tokenizer)
         self.assertEqual(sentences, expected_sentences)
@@ -112,7 +112,7 @@ def test_split_groups(self):
         sentence = "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, "  # noqa: E501
         expected_sentences = [
             "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62,",  # noqa: E501
-            "63,"
+            "63,",
         ]
         sentences = splitter.split(sentence, self.classifier.tokenizer)
         self.assertEqual(sentences, expected_sentences)
@@ -121,7 +121,7 @@ def test_split_groups(self):
         sentence = "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63, 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127"  # noqa: E501
         expected_sentences = [
             "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63,",  # noqa: E501
-            "64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127"  # noqa: E501
+            "64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127",  # noqa: E501
         ]
         sentences = splitter.split(sentence, self.classifier.tokenizer)
         self.assertEqual(sentences, expected_sentences)
@@ -146,7 +146,7 @@ def test_split_sentences(self):
             ],
             remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
             remove_too_short_groups=False,
-            group_splits=False
+            group_splits=False,
         )
 
         # 128 tokens + 2 special tokens =>
@@ -175,7 +175,7 @@ def test_split_sentences(self):
         expected_sentences = [
             "0 1 2 3 4 5 6,",
             "7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127.",  # noqa: E501
-            "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255"  # noqa: E501
+            "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255",  # noqa: E501
         ]
         sentences = splitter.split(sentence, self.classifier.tokenizer)
         self.assertEqual(sentences, expected_sentences)
@@ -186,7 +186,7 @@ def test_split_groups_remove_too_short(self):
                 RegexExpressions.split_by_dot,
                 RegexExpressions.split_by_semicolon,
                 RegexExpressions.split_by_colon,
-                RegexExpressions.split_by_comma
+                RegexExpressions.split_by_comma,
             ],
             remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
             remove_too_short_groups=True,
@@ -199,7 +199,7 @@ def test_split_groups_remove_too_short(self):
         sentence = "0 1 2 3 4 5 6, 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127. 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255"  # noqa: E501
         expected_sentences = [
             "7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127.",  # noqa: E501
-            "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255"  # noqa: E501
+            "128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255",  # noqa: E501
         ]
         sentences = splitter.split(sentence, self.classifier.tokenizer)
         self.assertEqual(sentences, expected_sentences)
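
Usage note (reviewer addition, not applied by git am): a minimal smoke-test sketch for the upgraded stack, adapted from examples/binary_classifier.py in this patch. The tiny two-row DataFrame, the single epoch, and the batch sizes are illustrative assumptions only; the point is that with Transformers 4.x the tokenizer exposes model_max_length instead of the removed max_len attribute that the pre-patch code relied on.

# Illustrative only: assumes ernie installed from this revision together with
# transformers>=4.24 and tensorflow<2.11, as pinned in requirements.txt.
import pandas as pd
from ernie import SentenceClassifier, Models

df = pd.DataFrame([
    ("This is a positive example. I'm very happy today.", 1),
    ("This is a negative sentence. Everything was wrong today at work.", 0),
])

classifier = SentenceClassifier(
    model_name=Models.BertBaseUncased,
    max_length=128,
    labels_no=2,
)

# Transformers 4.x renamed tokenizer.max_len to tokenizer.model_max_length,
# which is the rename this patch applies throughout ernie.
print(classifier.tokenizer.model_max_length)

classifier.load_dataset(df, validation_split=0.2)
classifier.fine_tune(epochs=1, training_batch_size=2, validation_batch_size=1)

# predict_one yields a softmax tuple; index 1 is the positive class.
probability = classifier.predict_one("Oh, that's great!")[1]
print(f"positive probability: {probability}")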