Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor evaluation #90

Merged
merged 10 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ ds = ds.select_columns(["text1", "text2", "label"])
# 3. transform data
train_ds = ds['train'].shuffle().map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
valid_ds = ds['validation'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
test_ds = ds['test'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)

# 4. fit
angle.fit(
Expand All @@ -325,8 +324,8 @@ angle.fit(
)

# 5. evaluate
corrcoef, accuracy = angle.evaluate(test_ds, device=angle.device)
print('corrcoef:', corrcoef)
corrcoef = angle.evaluate(ds['test'])
print('Spearman\'s corrcoef:', corrcoef)
```

### 💡 Others
Expand Down
3 changes: 2 additions & 1 deletion angle_emb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

from .angle import *
from .angle import * # NOQA
from .evaluation import * # NOQA


__version__ = '0.4.8'
97 changes: 21 additions & 76 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,10 @@
from typing import Any, Dict, Optional, List, Union, Tuple, Callable
from dataclasses import dataclass

import scipy
import scipy.stats
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
from tqdm import tqdm
from boltons.iterutils import chunked_iter
from datasets import Dataset
from transformers import (
AutoModelForCausalLM, AutoModel, AutoTokenizer,
Expand All @@ -34,7 +29,9 @@
)
from peft.tuners.lora import LoraLayer

from .base import AngleBase
from .utils import logger
from .evaluation import CorrelationEvaluator


DEFAULT_LLM_PATTERNS = [r'.*llama.*', r'.*qwen.*', r'.*baichuan.*', r'.*mistral.*']
Expand Down Expand Up @@ -237,44 +234,6 @@ def contrastive_with_negative_loss(
return nn.CrossEntropyLoss()(scores, labels)


def compute_corrcoef(x: np.ndarray, y: np.ndarray) -> float:
    """
    Compute Spearman's rank correlation coefficient between two arrays.

    :param x: np.ndarray, first array of scores
    :param y: np.ndarray, second array of scores

    :return: float, the `correlation` field of `scipy.stats.spearmanr(x, y)`
    """
    spearman_result = scipy.stats.spearmanr(x, y)
    return spearman_result.correlation


def l2_normalize(arr: np.ndarray) -> np.ndarray:
    """
    Divide an array by its L2 norm taken along axis 1.

    :param arr: np.ndarray, input array (must have an axis 1)

    :return: np.ndarray, `arr` scaled so that each slice along axis 1 has unit L2 norm
    """
    squared_sums = np.square(arr).sum(axis=1, keepdims=True)
    row_norms = np.sqrt(squared_sums)
    # Floor the norms at 1e-8 so that all-zero rows do not divide by zero.
    safe_norms = np.clip(row_norms, 1e-8, np.inf)
    return arr / safe_norms


def optimal_threshold(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, float]:
    """
    Search for the decision threshold on `y_pred` that maximizes accuracy.

    Gold labels are binarized with `y_true > 0.5`. The search variable `t` is
    passed through `tanh`, so the candidate threshold always lies in [-1, 1].

    :param y_true: np.ndarray, gold labels
    :param y_pred: np.ndarray, predicted scores compared against the threshold

    :return: Tuple[float, float], (best threshold, accuracy at that threshold)
    """
    # Imported locally: the module only imports `scipy` and `scipy.stats`, which
    # does not guarantee that the `scipy.optimize` submodule is loaded.
    from scipy.optimize import minimize

    y_pred = np.asarray(y_pred)
    is_positive = np.asarray(y_true) > 0.5

    def negative_accuracy(t):
        # Negated because `minimize` minimizes and we want the highest accuracy.
        return -np.mean(is_positive == (y_pred > np.tanh(t)))

    result = minimize(negative_accuracy, 1, method='Powell')
    # `result.x` is a numpy array (0-d or length 1); return plain floats as annotated.
    threshold = float(np.ravel(np.tanh(result.x))[0])
    accuracy = float(-result.fun)
    return threshold, accuracy


def check_llm(model_name_or_path: str, llm_regex_patterns: List[str] = None) -> bool:
if llm_regex_patterns is not None:
llm_regex_patterns += DEFAULT_LLM_PATTERNS
Expand Down Expand Up @@ -1036,7 +995,7 @@ def __call__(self,
return loss


class AnglE:
class AnglE(AngleBase):
"""
AnglE. Everything is here👋

Expand Down Expand Up @@ -1449,7 +1408,7 @@ def fit(self,
if output_dir is not None:
best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint')
evaluate_callback = EvaluateCallback(self, valid_ds,
partial(self.evaluate, batch_size=batch_size, device=self.device),
partial(self.evaluate, batch_size=batch_size),
save_dir=best_ckpt_dir,
push_to_hub=push_to_hub,
hub_model_id=hub_model_id,
Expand Down Expand Up @@ -1499,35 +1458,21 @@ def fit(self,
trainer.push_to_hub()
self.backbone.save_pretrained(output_dir)

def evaluate(self, data: Dataset, batch_size: int = 32, threshold: Optional[float] = None, device: Any = None):
    """ evaluate

    :param data: Dataset, samples accepted by `AngleDataCollator`; consumed in chunks of `batch_size`
    :param batch_size: int. Default 32.
    :param threshold: Optional[float]. Threshold on the predicted score used for accuracy.
        When None, `optimal_threshold` searches for the best one.
    :param device: Any. Device the batches are moved to. Falls back to `self.device`.

    :return: Tuple of (Spearman corrcoef, accuracy).
    """
    self.backbone.eval()
    collate_fn = AngleDataCollator(
        self.tokenizer,
        return_tensors="pt",
        max_length=self.max_length,
        filter_duplicate=False,
    )
    gold_scores = []
    pred_scores = []
    for batch in tqdm(chunked_iter(data, batch_size), desc='Evaluate'):
        inputs = collate_fn(batch)
        labels = inputs.pop('labels', None)
        # NOTE(review): presumably rows come in interleaved pairs (text1, text2, ...),
        # so every second row carries the label of a pair -- confirm against the collator.
        gold_scores.extend(labels[::2, 0].detach().cpu().numpy())
        with torch.no_grad():
            inputs.to(device or self.device)
            pooled = self.pooler(inputs, pooling_strategy=self.pooling_strategy)
            vectors = pooled.detach().float().cpu().numpy()
        vectors = l2_normalize(vectors)
        # Dot product of unit vectors from even and odd rows, i.e. their cosine similarity.
        pair_scores = (vectors[::2] * vectors[1::2]).sum(1)
        pred_scores.extend(pair_scores)

    gold = np.array(gold_scores)
    predicted = np.array(pred_scores)
    corrcoef = compute_corrcoef(gold, predicted)
    if threshold is not None:
        accuracy = np.mean((gold > 0.5) == (predicted > threshold))
    else:
        _, accuracy = optimal_threshold(gold, predicted)
    return corrcoef, accuracy
def evaluate(self, data: Dataset, batch_size: int = 32, metric: str = 'spearman_cosine') -> float:
    """ evaluate

    Builds a `CorrelationEvaluator` from the `text1`, `text2` and `label`
    columns of `data`, runs it on this model and returns one metric.

    :param data: Dataset, DatasetFormats.A is required
    :param batch_size: int. Default 32.
    :param metric: str. Key of the evaluator result to return. Default 'spearman_cosine'.

    :return: float.
    """
    evaluator = CorrelationEvaluator(
        text1=data['text1'],
        text2=data['text2'],
        labels=data['label'],
        batch_size=batch_size,
    )
    results = evaluator(self)
    return results[metric]

def encode(self,
inputs: Union[List[str], Tuple[str], List[Dict], str],
Expand Down Expand Up @@ -1656,7 +1601,7 @@ def __init__(self,
self.hub_private_repo = hub_private_repo

def on_epoch_end(self, args, state, control, **kwargs):
corrcoef, accuracy = self.evaluate_fn(self.valid_ds)
corrcoef = self.evaluate_fn(self.valid_ds)
if corrcoef > self.best_corrcoef:
self.best_corrcoef = corrcoef
print('new best corrcoef!')
Expand All @@ -1669,4 +1614,4 @@ def on_epoch_end(self, args, state, control, **kwargs):
private=self.hub_private_repo,
exist_ok=True,
commit_message='new best checkpoint')
print(f'corrcoef: {corrcoef}, accuracy: {accuracy}, best corrcoef: {self.best_corrcoef}')
logger.info(f'corrcoef: {corrcoef}, best corrcoef: {self.best_corrcoef}')
12 changes: 12 additions & 0 deletions angle_emb/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABCMeta, abstractmethod


class AngleBase(metaclass=ABCMeta):
    """Abstract interface that embedding models must implement.

    Subclasses must provide `encode` and `fit`. The abstract signatures accept
    arbitrary arguments so that concrete overrides taking inputs and options
    remain signature-compatible with this base class.
    """

    @abstractmethod
    def encode(self, *args, **kwargs):
        """Encode inputs into embeddings. Must be overridden by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def fit(self, *args, **kwargs):
        """Train the model. Must be overridden by subclasses."""
        raise NotImplementedError
97 changes: 97 additions & 0 deletions angle_emb/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-

from typing import List

import numpy as np
from boltons.iterutils import chunked_iter
from tqdm import tqdm
from sklearn.metrics.pairwise import (
paired_cosine_distances,
paired_euclidean_distances,
paired_manhattan_distances
)
from scipy.stats import pearsonr, spearmanr

from .base import AngleBase


class CorrelationEvaluator(object):
    """Correlate pairwise embedding similarities with gold labels.

    For each (text1, text2) pair, both texts are encoded with the model and
    four similarity scores are computed (cosine, negative Manhattan distance,
    negative Euclidean distance, dot product). Each score is then correlated
    with the labels using Pearson's and Spearman's coefficients.
    """

    # Single source of truth for the reported metrics and their order.
    _CORRELATIONS = ("pearson", "spearman")
    _SIMILARITIES = ("cosine", "manhattan", "euclidean", "dot")

    def __init__(
        self,
        text1: List[str],
        text2: List[str],
        labels: List[float],
        batch_size: int = 32
    ):
        """
        :param text1: List[str], first text of each pair.
        :param text2: List[str], second text of each pair.
        :param labels: List[float], gold score of each pair.
        :param batch_size: int, number of pairs encoded per step. Default 32.
        """
        assert len(text1) == len(text2) == len(labels), "text1, text2, and labels must have the same length"
        self.text1 = text1
        self.text2 = text2
        self.labels = labels
        self.batch_size = batch_size

    def __call__(self, model: "AngleBase", **kwargs) -> dict:
        """ Evaluate the model on the given dataset.

        :param model: AnglE, the model to evaluate.
        :param kwargs: Additional keyword arguments to pass to the `encode` method of the model.

        :return: dict, The evaluation results, keyed by the names in `list_all_metrics()`.
        """
        embeddings1 = []
        embeddings2 = []
        for chunk in tqdm(chunked_iter(range(len(self.text1)), self.batch_size)):
            batch_text1 = [self.text1[i] for i in chunk]
            batch_text2 = [self.text2[i] for i in chunk]

            embeddings1.append(model.encode(batch_text1, **kwargs))
            embeddings2.append(model.encode(batch_text2, **kwargs))

        embeddings1 = np.concatenate(embeddings1, axis=0)
        embeddings2 = np.concatenate(embeddings2, axis=0)

        # Distances are negated so that, for every score, larger means "more similar".
        similarities = {
            "cosine": 1 - paired_cosine_distances(embeddings1, embeddings2),
            "manhattan": -paired_manhattan_distances(embeddings1, embeddings2),
            "euclidean": -paired_euclidean_distances(embeddings1, embeddings2),
            # Row-wise dot product, computed in one vectorized pass.
            "dot": np.sum(embeddings1 * embeddings2, axis=1),
        }

        metrics = {}
        for name in self._SIMILARITIES:
            scores = similarities[name]
            metrics[f"pearson_{name}"], _ = pearsonr(self.labels, scores)
            metrics[f"spearman_{name}"], _ = spearmanr(self.labels, scores)
        return metrics

    def list_all_metrics(self) -> List[str]:
        """ Get a list of all the metrics that can be computed by this evaluator.

        :return: List[str], A list of all the metrics that can be computed by this evaluator.
        """
        return [
            f"{correlation}_{similarity}"
            for similarity in self._SIMILARITIES
            for correlation in self._CORRELATIONS
        ]
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ AnglE is also a general sentence embedding inference framework, allowing for inf
notes/installation.rst
notes/pretrained_models.rst
notes/training.rst
notes/evaluation.rst
notes/citation.rst


Expand Down
54 changes: 54 additions & 0 deletions docs/notes/evaluation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
🎯 Evaluation
============================


🎯 Spearman and Pearson Correlation
-----------------------------------------

Spearman's and Pearson's correlation coefficients are commonly used to evaluate text embedding quality.

We provide two ways to evaluate text embeddings with Spearman's and Pearson's correlation:

1) Use the `angle.evaluate(dataset)` function


.. code-block:: python

from angle_emb import AnglE
from datasets import load_dataset


angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').cuda()


ds = load_dataset('mteb/stsbenchmark-sts', split='test')
ds = ds.map(lambda obj: {"text1": str(obj["sentence1"]), "text2": str(obj['sentence2']), "label": obj['score']})
ds = ds.select_columns(["text1", "text2", "label"])

angle.evaluate(ds, metric='spearman_cosine')



2) Use the `angle_emb.CorrelationEvaluator` evaluator


.. code-block:: python

from angle_emb import AnglE, CorrelationEvaluator
from datasets import load_dataset


angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').cuda()

ds = load_dataset('mteb/stsbenchmark-sts', split='test')
ds = ds.map(lambda obj: {"text1": str(obj["sentence1"]), "text2": str(obj['sentence2']), "label": obj['score']})
ds = ds.select_columns(["text1", "text2", "label"])

metric = CorrelationEvaluator(
text1=ds['text1'],
text2=ds['text2'],
labels=ds['label']
)(angle)

print(metric)

2 changes: 1 addition & 1 deletion docs/notes/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ You can also train a sentence embedding model using the `angle_emb` library. Her
:alt: Open In Colab


💡 4. Fine-tuning Tips
💡 Fine-tuning Tips
-------------------------

1. If your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`.
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ prettytable
transformers>=4.32.1
scipy
einops
wandb
wandb
scikit-learn
20 changes: 20 additions & 0 deletions tests/test_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-


def test_eval():
    """Check that both evaluation entry points reach Spearman (cosine) > 0.89 on the STS-B test split."""
    from datasets import load_dataset
    from angle_emb import AnglE, CorrelationEvaluator

    angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls')
    eval_dataset = load_dataset('sentence-transformers/stsb', split="test")

    # Path 1: use the evaluator directly on the raw columns.
    evaluator = CorrelationEvaluator(
        text1=eval_dataset["sentence1"],
        text2=eval_dataset["sentence2"],
        labels=eval_dataset["score"],
    )
    metrics = evaluator(angle)
    spearman = metrics['spearman_cosine']
    assert spearman > 0.89

    # Path 2: use the model's own `evaluate`, which reads text1/text2/label columns.
    column_mapping = {'sentence1': 'text1', 'sentence2': 'text2', 'score': 'label'}
    renamed_dataset = eval_dataset.rename_columns(column_mapping)
    spearman = angle.evaluate(renamed_dataset)
    assert spearman > 0.89
Loading