diff --git a/angle_emb/angle.py b/angle_emb/angle.py
index 1445816..5ed2394 100644
--- a/angle_emb/angle.py
+++ b/angle_emb/angle.py
@@ -4,7 +4,6 @@
 import re
 import sys
 import json
-import copy
 import math
 from functools import partial
 from typing import Any, Dict, Optional, List, Union, Tuple, Callable
@@ -1213,8 +1212,6 @@ def __init__(self,
             pooling_strategy=self.pooling_strategy,
             padding_strategy=self.tokenizer.padding_side)
 
-        # full_backbone is used to Espresso inference
-        self.full_backbone = None
         self.__cfg = {
             'model_name_or_path': model_name_or_path,
             'max_length': max_length,
@@ -1474,12 +1471,22 @@ def evaluate(self, data: Dataset, batch_size: int = 32, metric: str = 'spearman_
             batch_size=batch_size,
         )(self)[metric]
 
+    def truncate_layer(self, layer_index: int):
+        """ Truncate the backbone to keep only the first layer_index encoder layers.
+
+        :param layer_index: int. Layers after layer_index will be truncated.
+        :return: self
+        """
+        if len(self.backbone.encoder.layer) < layer_index:
+            logger.info('layer_index is larger than the number of encoder layers; please check whether it is correct')
+        self.backbone.encoder.layer = self.backbone.encoder.layer[:layer_index]
+        return self
+
     def encode(self,
                inputs: Union[List[str], Tuple[str], List[Dict], str],
                max_length: Optional[int] = None,
                end_with_eos: bool = False,
                to_numpy: bool = True,
-               layer_index: int = -1,
                embedding_start: int = 0,
                embedding_size: Optional[int] = None,
                device: Optional[Any] = None,
@@ -1491,7 +1498,6 @@ def encode(self,
         :param inputs: Union[List[str], Tuple[str], List[Dict], str]. Input texts. Required.
         :param max_length: Optional[int]. Default None.
         :param to_numpy: bool. Default True.
-        :param layer_index: int. Obtain specific layer's sentence embeddings (for Espresso).
         :param embedding_start: int. Specify the start position of the embedding (for Espresso).
         :param embedding_size: Optional[int]. Specify embedding size (for Espresso).
             The embeddings from embedding_start to embedding_start+embedding_size will be returned.
@@ -1499,15 +1505,10 @@ def encode(self,
         :param prompt: Optional[str]. Default None.
         :param normalize_embedding: bool. Default False.
         """
-        if layer_index != -1 and self.full_backbone is None:
-            self.full_backbone = copy.deepcopy(self.backbone)
-
-        if layer_index != -1:
-            self.backbone.encoder.layer = self.full_backbone.encoder.layer[:layer_index]
+        self.backbone.eval()
 
         if device is None:
             device = self.device
-        self.backbone.eval()
         if not isinstance(inputs, (tuple, list)):
             inputs = [inputs]
         if prompt is not None:
@@ -1537,7 +1538,6 @@
             tok.to(device)
         with torch.no_grad():
             output = self.pooler(tok,
-                                 layer_index=layer_index,
                                  embedding_start=embedding_start,
                                  embedding_size=embedding_size)
         if normalize_embedding:
diff --git a/tests/test_loadding.py b/tests/test_loadding.py
index 4686410..216942e 100644
--- a/tests/test_loadding.py
+++ b/tests/test_loadding.py
@@ -16,14 +16,14 @@ def test_loadding():
     assert isinstance(vecs, np.ndarray)
 
 
-def test_2dmse_loadding():
+def test_ese_loadding():
    import numpy as np
 
    from angle_emb import AnglE
 
-    angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1')
-    vecs = angle.encode('hello world', layer_index=20)
+    angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').truncate_layer(20)
+    vecs = angle.encode('hello world')
     assert isinstance(vecs, np.ndarray)
-    vecs = angle.encode(['hello world', 'hi there👋'], layer_index=20, embedding_size=512)
+    vecs = angle.encode(['hello world', 'hi there👋'], embedding_size=512)
     assert isinstance(vecs, np.ndarray)
 
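
For reviewers, a minimal usage sketch of the API after this change: `truncate_layer()` replaces the old `layer_index=` argument of `encode()`. The checkpoint name and the 20/512 values below are simply the ones used in tests/test_loadding.py, not recommendations.

```python
import numpy as np

from angle_emb import AnglE

# Drop encoder layers once, up front, instead of passing layer_index to every encode() call.
angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').truncate_layer(layer_index=20)

# Full-size embedding from the truncated backbone.
vec = angle.encode('hello world')
assert isinstance(vec, np.ndarray)

# Espresso-style reduction: keep only the first 512 embedding dimensions.
vecs = angle.encode(['hello world', 'hi there👋'], embedding_size=512)
assert isinstance(vecs, np.ndarray)
```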