Skip to content

Commit

Permalink
Merge pull request #21 from labteral/fix-install
Browse files Browse the repository at this point in the history
1.0.0
  • Loading branch information
brunneis authored Dec 22, 2021
2 parents f598d5c + 219f2e1 commit 172f72f
Show file tree
Hide file tree
Showing 16 changed files with 451 additions and 227 deletions.
66 changes: 46 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
<p align="center">
<br>
<a href="https://github.com/brunneis/ernie#stickers-by-sticker-mule" alt="Stickers section"><img src="misc/ernie-sticker-diecut.png" alt="Ernie Logo" width="150"/></a>
<a href="https://github.com/labteral/ernie#stickers-by-sticker-mule" alt="Stickers section"><img src="misc/ernie-sticker-diecut.png" alt="Ernie Logo" width="150"/></a>
<br>
<p>

<p align="center">
<a href="https://pepy.tech/project/ernie/"><img alt="Downloads" src="https://img.shields.io/badge/dynamic/json?style=flat-square&maxAge=3600&label=downloads&query=$.total_downloads&url=https://api.pepy.tech/api/projects/ernie"></a>
<a href="https://pypi.python.org/pypi/ernie/"><img alt="PyPi" src="https://img.shields.io/pypi/v/ernie.svg?style=flat-square"></a>
<!--<a href="https://github.com/brunneis/ernie/releases"><img alt="GitHub releases" src="https://img.shields.io/github/release/brunneis/ernie.svg?style=flat-square"></a>-->
<a href="https://github.com/brunneis/ernie/blob/master/LICENSE"><img alt="License" src="https://img.shields.io/github/license/brunneis/ernie.svg?style=flat-square"></a>
<a href="https://github.com/labteral/ernie/releases"><img alt="GitHub releases" src="https://img.shields.io/github/release/labteral/ernie.svg?style=flat-square"></a>
<a href="https://github.com/labteral/ernie/blob/master/LICENSE"><img alt="License" src="https://img.shields.io/github/license/labteral/ernie.svg?style=flat-square"></a>
</p>

<h3 align="center">
Expand All @@ -32,13 +32,24 @@ pip install ernie
from ernie import SentenceClassifier, Models
import pandas as pd

tuples = [("This is a positive example. I'm very happy today.", 1),
("This is a negative sentence. Everything was wrong today at work.", 0)]

tuples = [
("This is a positive example. I'm very happy today.", 1),
("This is a negative sentence. Everything was wrong today at work.", 0)
]
df = pd.DataFrame(tuples)
classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=64, labels_no=2)

classifier = SentenceClassifier(
model_name=Models.BertBaseUncased,
max_length=64,
labels_no=2
)
classifier.load_dataset(df, validation_split=0.2)
classifier.fine_tune(epochs=4, learning_rate=2e-5, training_batch_size=32, validation_batch_size=64)
classifier.fine_tune(
epochs=4,
learning_rate=2e-5,
training_batch_size=32,
validation_batch_size=64
)
```

# Prediction
Expand Down Expand Up @@ -76,18 +87,30 @@ If the length in tokens of the texts is greater than the `max_length` with which
from ernie import SplitStrategies, AggregationStrategies

texts = ["Oh, that's great!", "That's really bad"]
probabilities = classifier.predict(texts,
split_strategy=SplitStrategies.GroupedSentencesWithoutUrls,
aggregation_strategy=AggregationStrategies.Mean)
probabilities = classifier.predict(
texts,
split_strategy=SplitStrategies.GroupedSentencesWithoutUrls,
aggregation_strategy=AggregationStrategies.Mean
)
```


You can define your custom strategies through `AggregationStrategy` and `SplitStrategy` classes.
```python
from ernie import SplitStrategy, AggregationStrategy

my_split_strategy = SplitStrategy(split_patterns: list, remove_patterns: list, remove_too_short_groups: bool, group_splits: bool)
my_aggregation_strategy = AggregationStrategy(method: function, max_items: int, top_items: bool, sorting_class_index: int)
my_split_strategy = SplitStrategy(
split_patterns: list,
remove_patterns: list,
remove_too_short_groups: bool,
group_splits: bool
)
my_aggregation_strategy = AggregationStrategy(
method: function,
max_items: int,
top_items: bool,
sorting_class_index: int
)
```

# Save and restore a fine-tuned model
Expand All @@ -105,15 +128,18 @@ classifier = SentenceClassifier(model_path='./model')
Since the execution may break during training (especially if you are using Google Colab), you can opt to secure every new trained epoch, so the training can be resumed without losing all the progress.

```python
classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=64)
classifier = SentenceClassifier(
model_name=Models.BertBaseUncased,
max_length=64
)
classifier.load_dataset(df, validation_split=0.2)

for epoch in range(1, 5):
if epoch == 3:
raise Exception("Forced crash")
if epoch == 3:
raise Exception("Forced crash")

classifier.fine_tune(epochs=1)
classifier.dump(f'./my-model/{epoch}')
classifier.fine_tune(epochs=1)
classifier.dump(f'./my-model/{epoch}')
```

```python
Expand All @@ -123,8 +149,8 @@ classifier = SentenceClassifier(model_path=f'./my-model/{last_training_epoch}')
classifier.load_dataset(df, validation_split=0.2)

for epoch in range(last_training_epoch + 1, 5):
classifier.fine_tune(epochs=1)
classifier.dump(f'./my-model/{epoch}')
classifier.fine_tune(epochs=1)
classifier.dump(f'./my-model/{epoch}')
```

# Autosave
Expand Down
17 changes: 13 additions & 4 deletions ernie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .ernie import *
from .ernie import * # noqa: F401, F403
from tensorflow.python.client import device_lib
import logging

__version__ = '0.0.33b0'
__version__ = '1.0.0'

logging.getLogger().setLevel(logging.WARNING)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
logging.basicConfig(format='%(asctime)-15s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logging.basicConfig(
format='%(asctime)-15s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)


def _get_cpu_name():
Expand All @@ -20,7 +23,13 @@ def _get_cpu_name():


def _get_gpu_name():
gpu_name = device_lib.list_local_devices()[3].physical_device_desc.split(',')[1].split('name:')[1].strip()
gpu_name = \
device_lib\
.list_local_devices()[3]\
.physical_device_desc\
.split(',')[1]\
.split('name:')[1]\
.strip()
return gpu_name


Expand Down
57 changes: 38 additions & 19 deletions ernie/aggregation_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@


class AggregationStrategy:
def __init__(self, method, max_items=None, top_items=True, sorting_class_index=1):
def __init__(
self,
method,
max_items=None,
top_items=True,
sorting_class_index=1
):
self.method = method
self.max_items = max_items
self.top_items = top_items
Expand All @@ -20,32 +26,45 @@ def aggregate(self, softmax_tuples):
softmax_dicts.append(softmax_dict)

if self.max_items is not None:
softmax_dicts = sorted(softmax_dicts, key=lambda x: x[self.sorting_class_index], reverse=self.top_items)
softmax_dicts = sorted(
softmax_dicts,
key=lambda x: x[self.sorting_class_index],
reverse=self.top_items
)
if self.max_items < len(softmax_dicts):
softmax_dicts = softmax_dicts[:self.max_items]

softmax_list = []
for key in softmax_dicts[0].keys():
softmax_list.append(self.method([probabilities[key] for probabilities in softmax_dicts]))
softmax_list.append(self.method(
[probabilities[key] for probabilities in softmax_dicts]))
softmax_tuple = tuple(softmax_list)
return softmax_tuple


class AggregationStrategies:
Mean = AggregationStrategy(method=mean)
MeanTopFiveBinaryClassification = AggregationStrategy(method=mean,
max_items=5,
top_items=True,
sorting_class_index=1)
MeanTopTenBinaryClassification = AggregationStrategy(method=mean,
max_items=10,
top_items=True,
sorting_class_index=1)
MeanTopFifteenBinaryClassification = AggregationStrategy(method=mean,
max_items=15,
top_items=True,
sorting_class_index=1)
MeanTopTwentyBinaryClassification = AggregationStrategy(method=mean,
max_items=20,
top_items=True,
sorting_class_index=1)
MeanTopFiveBinaryClassification = AggregationStrategy(
method=mean,
max_items=5,
top_items=True,
sorting_class_index=1
)
MeanTopTenBinaryClassification = AggregationStrategy(
method=mean,
max_items=10,
top_items=True,
sorting_class_index=1
)
MeanTopFifteenBinaryClassification = AggregationStrategy(
method=mean,
max_items=15,
top_items=True,
sorting_class_index=1
)
MeanTopTwentyBinaryClassification = AggregationStrategy(
method=mean,
max_items=20,
top_items=True,
sorting_class_index=1
)
Loading

0 comments on commit 172f72f

Please sign in to comment.