Skip to content

Commit

Permalink
Merge pull request #1 from Sagacify/embedding_metrics
Browse files Browse the repository at this point in the history
feat(embedding_metrics): implement BERTScore and MAUVE and add unit tests
  • Loading branch information
LucieNvz authored Oct 18, 2023
2 parents 6e256c5 + db39ad3 commit 77ed4d2
Show file tree
Hide file tree
Showing 7 changed files with 1,929 additions and 65 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ jobs:
- name: Run formatter
run: |
poetry run black --check saga_predictor tests
poetry run black --check saga_llm_evaluation_ml tests
- name: Run linter
run: |
poetry run pylint saga_predictor tests
poetry run pylint saga_llm_evaluation_ml tests
1,851 changes: 1,788 additions & 63 deletions poetry.lock

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ python = "^3.8,<3.11"
transformers = "^4.21.1"
numpy = "^1.21.2"
spacy = "^3.1.3"
evaluate = "^0.4.1"
scikit-learn = "^1.3.1"
mauve-text = "^0.3.0"
bert-score = "^0.3.13"
torch = ">=2.0.0, !=2.0.1, !=2.1.0"

[tool.poetry.dev-dependencies]
pylint = "^2.13"
Expand All @@ -23,6 +28,10 @@ url = "https://pypiserver.sagacify.com/"
default = false
secondary = true


[tool.poetry.group.dev.dependencies]
pytest = "^7.4.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Expand Down
64 changes: 64 additions & 0 deletions saga_llm_evaluation_ml/model/helpers/embedding_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Embedding-based evaluation metrics: wrappers around the HuggingFace
`evaluate` implementations of BERTScore and MAUVE."""
from evaluate import load


class BERTScore:
    def __init__(self, model_type="distilbert-base-uncased"):
        """
        BERTScore computes a similarity score for each token in the candidate
        sentence with each token in the reference sentence. The final score is
        the average of the similarity scores of all tokens in the candidate
        sentence.

        Args:
            model_type (str, optional): Model type to use.
                Defaults to "distilbert-base-uncased".
        """
        self.model_type = model_type
        # HuggingFace `evaluate` metric handle; model weights are fetched on
        # the first call to compute().
        self.metric = load("bertscore")

    def compute(self, references, predictions, **kwargs):
        """
        Compute BERTScore for each (reference, prediction) pair.

        Args:
            references (list): List of reference sentences.
            predictions (list): List of candidate sentences.
            **kwargs: Extra keyword arguments forwarded to the underlying
                `evaluate` bertscore metric.

        Returns:
            dict: Mapping with "precision", "recall" and "f1" lists, one
                entry per candidate sentence.

        Raises:
            ValueError: If the inputs are not lists or their lengths differ.
        """
        # Validate explicitly rather than with `assert`, which is stripped
        # when Python runs with optimizations (-O).
        if not isinstance(references, list):
            raise ValueError("References must be a list.")
        if not isinstance(predictions, list):
            raise ValueError("Predictions must be a list.")
        if len(references) != len(predictions):
            raise ValueError("Number of references and predictions must be equal.")

        return self.metric.compute(
            predictions=predictions,
            references=references,
            model_type=self.model_type,
            **kwargs
        )


class MAUVE:
    def __init__(self, featurize_model_name="gpt2"):
        """
        MAUVE score measures the gap between the candidate sentence
        distribution and the reference sentence distribution. The bigger
        the MAUVE score, the better.

        Args:
            featurize_model_name (str, optional): Name of the model used to
                featurize the sentences. Defaults to "gpt2".
        """
        # HuggingFace `evaluate` metric handle; model weights are fetched on
        # the first call to compute().
        self.metric = load("mauve")
        self.featurize_model_name = featurize_model_name

    def compute(self, references, predictions, **kwargs):
        """
        Compute the MAUVE score between the reference and candidate sets.

        Args:
            references (list): List of reference sentences.
            predictions (list): List of candidate sentences.
            **kwargs: Extra keyword arguments forwarded to the underlying
                `evaluate` mauve metric.

        Returns:
            object: Result object exposing the score via its `mauve`
                attribute.

        Raises:
            ValueError: If the inputs are not lists or their lengths differ.
        """
        # Validate explicitly rather than with `assert`, which is stripped
        # when Python runs with optimizations (-O).
        if not isinstance(references, list):
            raise ValueError("References must be a list.")
        if not isinstance(predictions, list):
            raise ValueError("Predictions must be a list.")
        if len(references) != len(predictions):
            raise ValueError("Number of references and predictions must be equal.")

        return self.metric.compute(
            predictions=predictions,
            references=references,
            featurize_model_name=self.featurize_model_name,
            **kwargs
        )
8 changes: 8 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import os
import sys

# Make the application sources importable when tests are collected from
# outside the installed package.
MODULE_ROOT = os.path.abspath("/www/app/src")
PROJ_ROOT = os.path.abspath("/www/app")
sys.path.extend((MODULE_ROOT, PROJ_ROOT))
Empty file removed tests/test.py
Empty file.
58 changes: 58 additions & 0 deletions tests/test_embedding_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest

from saga_llm_evaluation_ml.model.helpers.embedding_metrics import BERTScore, MAUVE


class TestBERTScore(unittest.TestCase):
    def test_compute(self):
        """BERTScore returns one precision/recall/f1 entry per input pair,
        and repeated computation on identical inputs gives identical scores."""
        references = ["The cat sat on the mat.", "The dog sat on the log."]
        predictions = ["The cat sat on the mat.", "The dog sat on the log."]
        bertscore = BERTScore()
        scores = bertscore.compute(references, predictions)
        # One score per candidate sentence for each statistic.
        self.assertEqual(len(scores["precision"]), len(references))
        self.assertEqual(len(scores["recall"]), len(references))
        self.assertEqual(len(scores["f1"]), len(references))

        # Determinism: the same inputs must yield the same scores.
        scores_2 = bertscore.compute(references, predictions)
        self.assertEqual(scores["precision"], scores_2["precision"])
        self.assertEqual(scores["recall"], scores_2["recall"])
        self.assertEqual(scores["f1"], scores_2["f1"])

    def test_compute_improved_input(self):
        """F1 must increase when the prediction matches the reference better."""
        reference = "The cat sat on the mat."
        prediction = "The dog sat on the mat."
        better_prediction = "The cat sat on the mat."

        bertscore = BERTScore()

        scores = bertscore.compute([reference], [prediction])
        better_scores = bertscore.compute([reference], [better_prediction])

        self.assertGreater(better_scores["f1"][0], scores["f1"][0])


class TestMAUVE(unittest.TestCase):
    def test_compute(self):
        """MAUVE is deterministic: identical inputs yield identical scores."""
        metric = MAUVE()
        targets = ["The cat sat on the mat.", "The dog sat on the log."]
        candidates = ["The cat sat on the mat.", "The dog sat on the log."]
        first = metric.compute(targets, candidates)
        second = metric.compute(targets, candidates)
        self.assertEqual(first.mauve, second.mauve)

    def test_compute_improved_input(self):
        """MAUVE must increase when the prediction matches the reference better."""
        target = "The cat sat on the mat."
        weak_candidate = "The dog sat on the mat."
        strong_candidate = "The cat sat on the mat."

        metric = MAUVE()

        weak_scores = metric.compute([target], [weak_candidate])
        strong_scores = metric.compute([target], [strong_candidate])

        self.assertGreater(strong_scores.mauve, weak_scores.mauve)

0 comments on commit 77ed4d2

Please sign in to comment.