diff --git a/angle_emb/angle.py b/angle_emb/angle.py index a641f6c..1cee203 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -29,6 +29,7 @@ ) from peft.tuners.lora import LoraLayer +from .base import AngleBase from .utils import logger from .evaluation import CorrelationEvaluator @@ -994,7 +995,7 @@ def __call__(self, return loss -class AnglE: +class AnglE(AngleBase): """ AnglE. Everything is heređź‘‹ diff --git a/angle_emb/base.py b/angle_emb/base.py new file mode 100644 index 0000000..6618da2 --- /dev/null +++ b/angle_emb/base.py @@ -0,0 +1,12 @@ +from abc import ABCMeta, abstractmethod + + +class AngleBase(metaclass=ABCMeta): + + @abstractmethod + def encode(self): + raise NotImplementedError + + @abstractmethod + def fit(self): + raise NotImplementedError diff --git a/angle_emb/evaluation.py b/angle_emb/evaluation.py index 75b1749..3d21201 100644 --- a/angle_emb/evaluation.py +++ b/angle_emb/evaluation.py @@ -12,6 +12,8 @@ ) from scipy.stats import pearsonr, spearmanr +from .base import AngleBase + class CorrelationEvaluator(object): def __init__( @@ -27,7 +29,7 @@ def __init__( self.labels = labels self.batch_size = batch_size - def __call__(self, model, **kwargs) -> dict: + def __call__(self, model: AngleBase, **kwargs) -> dict: """ Evaluate the model on the given dataset. :param model: AnglE, the model to evaluate.