fix layer index (#93)
* specify layer_index via truncate_layer

* update doc

* fix test cases
SeanLee97 authored Jul 28, 2024
1 parent 5f54c15 commit d445356
Showing 2 changed files with 16 additions and 16 deletions.
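
For readers skimming the diff, a minimal before/after sketch of the API change, using the checkpoint and values that appear in the updated test below (any other checkpoint should work the same way):

from angle_emb import AnglE

# Before this commit: the target layer was passed to encode() on every call.
# angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1')
# vecs = angle.encode('hello world', layer_index=20)

# After this commit: truncate the backbone once via truncate_layer(), then encode as usual.
angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').truncate_layer(20)
vecs = angle.encode('hello world')
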
24 changes: 12 additions & 12 deletions angle_emb/angle.py
@@ -4,7 +4,6 @@
import re
import sys
import json
-import copy
import math
from functools import partial
from typing import Any, Dict, Optional, List, Union, Tuple, Callable
@@ -1213,8 +1212,6 @@ def __init__(self,
pooling_strategy=self.pooling_strategy,
padding_strategy=self.tokenizer.padding_side)

-# full_backbone is used to Espresso inference
-self.full_backbone = None
self.__cfg = {
'model_name_or_path': model_name_or_path,
'max_length': max_length,
@@ -1474,12 +1471,22 @@ def evaluate(self, data: Dataset, batch_size: int = 32, metric: str = 'spearman_
batch_size=batch_size,
)(self)[metric]

+def truncate_layer(self, layer_index: int):
+    """ truncate layer
+    :param layer_index: int. layers after layer_index will be truncated.
+    :return: self
+    """
+    if len(self.backbone.encoder.layer) < layer_index:
+        logger.info('current layer_index is larger than the number of layers, please check whether it is correct')
+    self.backbone.encoder.layer = self.backbone.encoder.layer[:layer_index]
+    return self

def encode(self,
inputs: Union[List[str], Tuple[str], List[Dict], str],
max_length: Optional[int] = None,
end_with_eos: bool = False,
to_numpy: bool = True,
-layer_index: int = -1,
embedding_start: int = 0,
embedding_size: Optional[int] = None,
device: Optional[Any] = None,
@@ -1491,23 +1498,17 @@ def encode(self,
:param inputs: Union[List[str], Tuple[str], List[Dict], str]. Input texts. Required.
:param max_length: Optional[int]. Default None.
:param to_numpy: bool. Default True.
-:param layer_index: int. Obtain specific layer's sentence embeddings (for Espresso).
:param embedding_start: int. Specify the start position of the embedding (for Espresso).
:param embedding_size: Optional[int]. Specify embedding size (for Espresso).
The embeddings from embedding_start to embedding_start+embedding_size will be returned.
:param device: Optional[Any]. Default None.
:param prompt: Optional[str]. Default None.
:param normalize_embedding: bool. Default False.
"""
-if layer_index != -1 and self.full_backbone is None:
-    self.full_backbone = copy.deepcopy(self.backbone)
-
-if layer_index != -1:
-    self.backbone.encoder.layer = self.full_backbone.encoder.layer[:layer_index]
-    self.backbone.eval()

if device is None:
device = self.device
+self.backbone.eval()
if not isinstance(inputs, (tuple, list)):
inputs = [inputs]
if prompt is not None:
@@ -1537,7 +1538,6 @@ def encode(self,
tok.to(device)
with torch.no_grad():
output = self.pooler(tok,
-layer_index=layer_index,
embedding_start=embedding_start,
embedding_size=embedding_size)
if normalize_embedding:
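
A short sketch of what the new truncate_layer amounts to, based on the method body above; the 24-layer count is an assumption about this particular checkpoint (a BERT-large sized encoder), so treat it as illustrative:

from angle_emb import AnglE

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1')
print(len(angle.backbone.encoder.layer))  # e.g. 24 for a BERT-large style encoder (assumed)

angle.truncate_layer(20)                  # keeps encoder.layer[:20] and returns self, so calls can be chained
print(len(angle.backbone.encoder.layer))  # 20

# Note: the truncation is applied in place. Unlike the removed layer_index path,
# there is no full_backbone deep copy anymore, so the dropped layers cannot be
# restored on this instance; reload the model to get the full stack back.
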
8 changes: 4 additions & 4 deletions tests/test_loadding.py
@@ -16,14 +16,14 @@ def test_loadding():
assert isinstance(vecs, np.ndarray)


-def test_2dmse_loadding():
+def test_ese_loadding():
import numpy as np
from angle_emb import AnglE

-angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1')
-vecs = angle.encode('hello world', layer_index=20)
+angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').truncate_layer(20)
+vecs = angle.encode('hello world')
assert isinstance(vecs, np.ndarray)
-vecs = angle.encode(['hello world', 'hi there👋'], layer_index=20, embedding_size=512)
+vecs = angle.encode(['hello world', 'hi there👋'], embedding_size=512)
assert isinstance(vecs, np.ndarray)


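The updated test also shows the two reduction knobs composing: layer truncation via truncate_layer and dimension slicing via embedding_size, which per the encode docstring keeps dimensions embedding_start through embedding_start + embedding_size. A sketch with the same values as the test; the printed shape is an inference, not something the test itself asserts:

from angle_emb import AnglE

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').truncate_layer(20)
vecs = angle.encode(['hello world', 'hi there👋'], embedding_size=512)
print(vecs.shape)  # expected: (2, 512)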
