From 830e7bc25822b03ac402e2d456e859af5e2560bb Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Wed, 24 Apr 2024 13:51:12 +0800 Subject: [PATCH 01/11] add espresso & improve code --- README.md | 164 +++---- README_2DMSE.md | 26 +- README_Espresso.md | 0 angle_emb/angle.py | 515 +++++++++++++-------- angle_emb/train_cli.py | 146 +++--- angle_emb/utils.py | 13 + docs/notes/installation.rst | 2 +- docs/notes/quick_start.rst | 4 +- scripts/convert_to_sentence_transformer.py | 32 ++ 9 files changed, 509 insertions(+), 393 deletions(-) create mode 100644 README_Espresso.md create mode 100644 scripts/convert_to_sentence_transformer.py diff --git a/README.md b/README.md index f2545b1..134132a 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,12 @@ > It is Angle ๐Ÿ“, not Angel ๐Ÿ‘ผ. -๐Ÿ”ฅ **A New SOTA** for Semantic Textual Similarity! +๐Ÿ“ข **Train/Infer Powerful Sentence Embedding Models with AnglE.** +AnglE enables you to train state-of-the-art BERT-based or LLM-based sentence embeddings with just a few lines of code. +AnglE is also a general inference framework for sentence embedding, allowing you to infer a variety of transformer-based sentence embeddings. -๐Ÿ”ฅ **Our universal sentence embedding [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) achieves SOTA on the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) with an average score of 64.64!** - +## ๐Ÿ† Achievements https://arxiv.org/abs/2309.12871 @@ -32,31 +33,27 @@ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts-benchmark)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts-benchmark?p=angle-optimized-text-embeddings) -
-๐Ÿ“Š Results on MTEB Leaderboard [click to expand] -

- -

-
+๐Ÿ“… Mar 13, 2024 | Paper "[BeLLM: Backward Dependency Enhanced Large Language Model for Sentence Embeddings](https://arxiv.org/abs/2311.05296)" accepted by NAACL 2024 Main Conference. -
-๐Ÿ“Š Results on STS benchmark [click to expand] -

- -

-
-## ๐Ÿค— Pretrained Models +๐Ÿ“… Mar 8, 2024 | ๐Ÿž [mixedbread's embedding](https://www.mixedbread.ai/blog/mxbai-embed-large-v1) ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)) achieves SOTA on the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) with an average score of **64.68**! The model is trained using AnglE. + + +๐Ÿ“… Dec 4, 2023 | Our universal sentence embedding [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) achieves SOTA on the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) with an average score of **64.64**! The model is trained using AnglE. + + +๐Ÿ“… Dec, 2023 | **A New SOTA** for Semantic Textual Similarity! + + +## ๐Ÿค— Official Pretrained Models | ๐Ÿค— HF | LoRA Weight | Dependent Backbone | LLM | Language | Prompt | Pooling Strategy | Examples | |----|------|------|------|------|------|------|------| | [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) | N | N | N | EN | `Prompts.C` for retrieval purposes, `None` for others | cls | [![Seach Demo](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WOYD6f8gb_wpkUm_57K8pEDgjlGJd6oB?usp=drive_link) | | [SeanLee97/angle-llama-13b-nli](https://huggingface.co/SeanLee97/angle-llama-13b-nli) | Y | NousResearch/Llama-2-13b-hf | Y | EN | `Prompts.A` | last token | / | | [SeanLee97/angle-llama-7b-nli-v2](https://huggingface.co/SeanLee97/angle-llama-7b-nli-v2) | Y | NousResearch/Llama-2-7b-hf | Y | EN | `Prompts.A` | last token | / | -| [SeanLee97/angle-llama-7b-nli-20231027](https://huggingface.co/SeanLee97/angle-llama-7b-nli-20231027) | Y | NousResearch/Llama-2-7b-hf | Y | EN | `Prompts.A` | last token | / | | [SeanLee97/angle-bert-base-uncased-nli-en-v1](https://huggingface.co/SeanLee97/angle-bert-base-uncased-nli-en-v1) | N | N | N | EN | N | `cls_avg` | / | -| [SeanLee97/angle-roberta-wwm-base-zhnli-v1](https://huggingface.co/SeanLee97/angle-roberta-wwm-base-zhnli-v1) | N | N | N | ZH-CN | N | `cls` | / | -| [SeanLee97/angle-llama-7b-zhnli-v1](https://huggingface.co/SeanLee97/angle-llama-7b-zhnli-v1) | Y | NousResearch/Llama-2-7b-hf | Y | ZH-CN | `Prompts.B` | last token | / | +
๐Ÿ’ก Tips ๐Ÿ’ก If the selected model is a LoRA weight, it must specify the corresponding dependent backbone. For our STS Experiment, please refer to https://github.com/SeanLee97/AnglE/tree/main/examples/NLI @@ -75,19 +72,7 @@ For our STS Experiment, please refer to https://github.com/SeanLee97/AnglE/tree/ | [SeanLee97/angle-bert-base-uncased-nli-en-v1](https://huggingface.co/SeanLee97/angle-bert-base-uncased-nli-en-v1) | 75.09 | 85.56 | 80.66 | 86.44 | 82.47 | 85.16 | 81.23 | 82.37 | -### Chinese STS Results - -| Model | ATEC | BQ | LCQMC | PAWSX | STS-B | SOHU-dd | SOHU-dc | Avg. | -| ------- |-------|-------|-------|-------|-------|--------------|-----------------|-------| -| ^[shibing624/text2vec-bge-large-chinese](https://huggingface.co/shibing624/text2vec-bge-large-chinese) | 38.41 | 61.34 | 71.72 | 35.15 | 76.44 | 71.81 | 63.15 | 59.72 | -| ^[shibing624/text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase) | 44.89 | 63.58 | 74.24 | 40.90 | 78.93 | 76.70 | 63.30 | 63.08 | -| [SeanLee97/angle-roberta-wwm-base-zhnli-v1](https://huggingface.co/SeanLee97/angle-roberta-wwm-base-zhnli-v1) | 49.49 | 72.47 | 78.33 | 59.13 | 77.14 | 72.36 | 60.53 | **67.06** | -| [SeanLee97/angle-llama-7b-zhnli-v1](https://huggingface.co/SeanLee97/angle-llama-7b-zhnli-v1) | 50.44 | 71.95 | 78.90 | 56.57 | 81.11 | 68.11 | 52.02 | 65.59 | - -^ denotes baselines, their results are retrieved from: https://github.com/shibing624/text2vec - - -## Usage +## ๐Ÿš€ Quick Start AnglE supports two APIs, one is the `transformers` API, the other is the `AnglE` API. If you want to use the `AnglE` API, please install AnglE first: @@ -95,33 +80,47 @@ AnglE supports two APIs, one is the `transformers` API, the other is the `AnglE` python -m pip install -U angle-emb ``` -### UAE +### 1. Load BERT-based Models 1) For Retrieval Purposes -For retrieval purposes, please use the prompt `Prompts.C` for the query (โš ๏ธ๏ผšno need to apply prompt for documents). +For retrieval purposes, please use the prompt `Prompts.C` for query (not document). ```python from angle_emb import AnglE, Prompts +from angle_emb.utils import cosine_similarity + angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda() -angle.set_prompt(prompt=Prompts.C) -vec = angle.encode({'text': 'hello world'}, to_numpy=True) -print(vec) -vecs = angle.encode([{'text': 'hello world1'}, {'text': 'hello world2'}], to_numpy=True) -print(vecs) +# when specify prompt, the inputs should be a list of dict with key 'text' +qv = angle.encode({'text': 'what is the weather?'}, to_numpy=True, prompt=Prompts.C) +doc_vecs = angle.encode([ + 'The weather is great!', + 'it is rainy today.', + 'i am going to bed' +], to_numpy=True) + +for dv in doc_vecs: + print(cosine_similarity(qv[0], dv)) ``` 2) For non-Retrieval Purposes ```python from angle_emb import AnglE +from angle_emb.utils import cosine_similarity + angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda() -vec = angle.encode('hello world', to_numpy=True) -print(vec) -vecs = angle.encode(['hello world1', 'hello world2'], to_numpy=True) -print(vecs) +doc_vecs = angle.encode([ + 'The weather is great!', + 'The weather is very good!', + 'i am going to bed' +]) + +for i, dv1 in enumerate(doc_vecs): + for dv2 in doc_vecs[i+1:]: + print(cosine_similarity(dv1, dv2)) ```
@@ -146,78 +145,37 @@ For non-retrieval tasks, we set the prompt to empty, i.e., just input your text So, if your scenario is retrieval-related, it is highly recommended to set the prompt with angle.set_prompt(prompt=Prompts.C). If not, leave the prompt empty or use angle.set_prompt(prompt=None).
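+
+The predefined prompts are plain Python format strings with a `{text}` placeholder; `encode` fills the placeholder before tokenization, which is why prompted inputs are passed as dicts with a `text` key. A minimal sketch for inspecting them (using the `Prompts.list_prompts()` helper also shown in the next section):
+
+```python
+from angle_emb import Prompts
+
+print(Prompts.list_prompts())                # list all predefined prompts
+print(Prompts.A.format(text='hello world'))  # Summarize sentence "hello world" in one word:"
+```
+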
-### Angle-LLaMA +### 2. Load LoRA-based Models 1) AnglE ```python from angle_emb import AnglE, Prompts -angle = AnglE.from_pretrained('NousResearch/Llama-2-7b-hf', pretrained_lora_path='SeanLee97/angle-llama-7b-nli-v2') +angle = AnglE.from_pretrained('NousResearch/Llama-2-7b-hf', + pretrained_lora_path='SeanLee97/angle-llama-7b-nli-v2', + pooling_strategy='last') print('All predefined prompts:', Prompts.list_prompts()) -angle.set_prompt(prompt=Prompts.A) -print('prompt:', angle.prompt) -vec = angle.encode({'text': 'hello world'}, to_numpy=True) +vec = angle.encode({'text': 'hello world'}, to_numpy=True, prompt=Prompts.A) print(vec) -vecs = angle.encode([{'text': 'hello world1'}, {'text': 'hello world2'}], to_numpy=True) +vecs = angle.encode([{'text': 'hello world1'}, {'text': 'hello world2'}], to_numpy=True, prompt=Prompts.A) print(vecs) ``` -2) transformers +### 3. Load Third-party Models w/ angle_emb -```python -from angle_emb import AnglE, Prompts -from transformers import AutoModelForCausalLM, AutoTokenizer -from peft import PeftModel, PeftConfig - -peft_model_id = 'SeanLee97/angle-llama-7b-nli-v2' -config = PeftConfig.from_pretrained(peft_model_id) -tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) -model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path).bfloat16().cuda() -model = PeftModel.from_pretrained(model, peft_model_id).cuda() - -def decorate_text(text: str): - return Prompts.A.format(text=text) - -inputs = 'hello world!' -tok = tokenizer([decorate_text(inputs)], return_tensors='pt') -for k, v in tok.items(): - tok[k] = v.cuda() -vec = model(output_hidden_states=True, **tok).hidden_states[-1][:, -1].float().detach().cpu().numpy() -print(vec) -``` +You can load any transformer-based third-party models such as `mixedbread-ai/mxbai-embed-large-v1`, `sentence-transformers/all-MiniLM-L6-v2`, and `BAAI/bge-large-en-v1.5` using `angle_emb`. -### Angle-BERT +Here is an example: -1) AnglE ```python from angle_emb import AnglE -angle = AnglE.from_pretrained('SeanLee97/angle-bert-base-uncased-nli-en-v1', pooling_strategy='cls_avg').cuda() -vec = angle.encode('hello world', to_numpy=True) +model = AnglE.from_pretrained('mixedbread-ai/mxbai-embed-large-v1', pooling_strategy='cls').cuda() +vec = model.encode('hello world', to_numpy=True) print(vec) -vecs = angle.encode(['hello world1', 'hello world2'], to_numpy=True) -print(vecs) ``` -2) transformers - -```python -import torch -from transformers import AutoModel, AutoTokenizer - -model_id = 'SeanLee97/angle-bert-base-uncased-nli-en-v1' -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModel.from_pretrained(model_id).cuda() - -inputs = 'hello world!' -tok = tokenizer([inputs], return_tensors='pt') -for k, v in tok.items(): - tok[k] = v.cuda() -hidden_state = model(**tok).last_hidden_state -vec = (hidden_state[:, 0] + torch.mean(hidden_state, dim=1)) / 2.0 -print(vec) -``` ## Custom Train @@ -275,12 +233,12 @@ angle.fit( warmup_steps=0, gradient_accumulation_steps=1, loss_kwargs={ - 'w1': 1.0, - 'w2': 1.0, - 'w3': 1.0, + 'cosine_w': 1.0, + 'ibn_w': 1.0, + 'angle_w': 1.0, 'cosine_tau': 20, 'ibn_tau': 20, - 'angle_tau': 1.0 + 'angle_tau': 20.0 }, fp16=True, logging_steps=100 @@ -293,13 +251,13 @@ print('corrcoef:', corrcoef) ### 4. Fine-tuning Tips ๐Ÿ’ก -1) if your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `w1` or slightly decrease the weight for `w2`. 
+1) if your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`. -2) if your dataset format is `DatasetFormats.B`, it is recommended to set `w1` to 0, and increase the weight for `w2` such as 10 and 20. The `angle_tau` can be set to 20.0. +2) if your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0, and increase the weight for `ibn_w` such as 10 and 20. The `angle_tau` can be set to 20.0. -3) if your dataset format is `DatasetFormats.C`, only `w2` and `ibn_tau` are effective. You don't need to tune other parameters. +3) if your dataset format is `DatasetFormats.C`, only `ibn_w` and `ibn_tau` are effective. You don't need to tune other parameters. -4) To alleviate information forgetting in fine-tuning, it is better to specify the `fixed_teacher_name_or_path`. If the `fixed_teacher_name_or_path` equals `model_name_or_path`, it will conduct self-distillation. **It is worth to note that** `fixed_teacher_name_or_path` has to have the same tokenizer as `model_name_or_path`. Or it will lead to unexpected results. +4) To alleviate information forgetting in fine-tuning, it is better to specify the `teacher_name_or_path`. If the `teacher_name_or_path` equals `model_name_or_path`, it will conduct self-distillation. **It is worth to note that** `teacher_name_or_path` has to have the same tokenizer as `model_name_or_path`. Or it will lead to unexpected results. # Citation diff --git a/README_2DMSE.md b/README_2DMSE.md index 8982624..4d7ccb8 100644 --- a/README_2DMSE.md +++ b/README_2DMSE.md @@ -2,31 +2,9 @@ > Paper: https://arxiv.org/abs/2402.14776 -# Usage +"๐Ÿช† 2D Matryoshka Sentence Embeddings" has been renamed to "โ˜•๏ธ Espresso Sentence Embeddings". -**โš ๏ธ The Document is Working in Progress!** - - -Example: - -```bash -WANDB_MODE=disabled CUDA_VISIBLE_DEVICES=0 angle-trainer \ ---model_name_or_path WhereIsAI/UAE-Large-V1 \ ---train_name_or_path data.jsonl --save_dir ckpts/custom-UAE-2dmse \ ---w2 20.0 --w1 1. --w3 1. --angle_tau 20.0 --learning_rate 1e-5 --maxlen 128 \ ---workers 16 \ ---pooling_strategy all \ ---epochs 1 \ ---batch_size 16 \ ---apply_tdmse 1 \ ---fixed_teacher_name_or_path WhereIsAI/UAE-Large-V1 \ ---logging_steps 1000 \ ---warmup_steps 100 \ ---is_llm 0 \ ---save_steps 1000 --seed -1 --gradient_accumulation_steps 6 --fp16 1 -``` - -The `--apply_tdmse 1` is required. 
+Please find the document in [โ˜•๏ธ Espresso](README_Espresso.md) # Citation diff --git a/README_Espresso.md b/README_Espresso.md new file mode 100644 index 0000000..e69de29 diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 6764059..e2ccada 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -5,7 +5,7 @@ import sys import json import copy -import random +import math from functools import partial from typing import Any, Dict, Optional, List, Union, Tuple, Callable from dataclasses import dataclass @@ -29,7 +29,7 @@ from transformers.utils import PaddingStrategy from peft import ( get_peft_model, LoraConfig, TaskType, PeftModel, - prepare_model_for_kbit_training + prepare_model_for_kbit_training, ) from peft.tuners.lora import LoraLayer @@ -113,7 +113,7 @@ def cosine_loss(y_true: torch.Tensor, y_pred: torch.Tensor, tau: float = 20.0) - return torch.logsumexp(y_pred, dim=0) -def angle_loss(y_true: torch.Tensor, y_pred: torch.Tensor, tau: float = 1.0): +def angle_loss(y_true: torch.Tensor, y_pred: torch.Tensor, tau: float = 1.0, pooling_strategy: str = 'sum'): """ Compute angle loss @@ -148,7 +148,13 @@ def angle_loss(y_true: torch.Tensor, y_pred: torch.Tensor, tau: float = 1.0): im /= (dz / dw) y_pred = torch.concat((re, im), dim=1) - y_pred = torch.abs(torch.sum(y_pred, dim=1)) * tau # absolute delta angle + if pooling_strategy == 'sum': + pooling = torch.sum(y_pred, dim=1) + elif pooling_strategy == 'mean': + pooling = torch.mean(y_pred, dim=1) + else: + raise ValueError(f'Unsupported pooling strategy: {pooling_strategy}') + y_pred = torch.abs(pooling) * tau # absolute delta angle y_pred = y_pred[:, None] - y_pred[None, :] y_pred = (y_pred - (1 - y_true) * 1e12).view(-1) zero = torch.Tensor([0]).to(y_pred.device) @@ -283,8 +289,7 @@ def get_pooling(outputs: torch.Tensor, inputs: Dict, pooling_strategy: str, padding_strategy: str = 'right') -> torch.Tensor: - """ - get pooling + """ Pooling the model outputs. :param outputs: torch.Tensor. Model outputs (without pooling) :param inputs: Dict. Model inputs @@ -322,19 +327,6 @@ def get_pooling(outputs: torch.Tensor, return outputs -def get_geometric_hidden_sizes(base: int = 8, max_hidden: int = 768) -> List[int]: - """ - get geometric hidden size series list given a hidden size range - - """ - lst = [] - s = base - while s < max_hidden: - lst.append(s) - s *= 2 - return lst - - class Prompts: """ Predefined prompts. Follow the model usage to choose the corresponding prompt. @@ -346,8 +338,7 @@ class Prompts: # list all pre-defined prompts print(Prompts.list_prompts()) # set prompt - angle.set_prompt(prompt=Prompts.A) - + angle.encode(*, prompt=Prompts.A) """ A = 'Summarize sentence "{text}" in one word:"' @@ -590,6 +581,12 @@ class AngleDataCollator: filter_duplicate: bool = True def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]: + """ Collate function for AngleDataTokenizer. + + :param features: List[Dict]. Tokenized data + :param return_tensors: str. Default "pt" + :return: Dict[str, torch.Tensor]. 
Collated data + """ if return_tensors is None: return_tensors = self.return_tensors has_token_type_ids = "token_type_ids" in features[0] @@ -695,30 +692,28 @@ class Pooler: def __init__(self, model: PreTrainedModel, pooling_strategy: Optional[Union[int, str]] = None, - padding_strategy: Optional[str] = None, - is_llm: bool = False): + padding_strategy: Optional[str] = None): self.model = model self.pooling_strategy = pooling_strategy self.padding_strategy = padding_strategy - self.is_llm = is_llm - def __call__(self, inputs: Dict, layer_index: int = -1, embedding_size: Optional[int] = None, + def __call__(self, + inputs: Dict, + layer_index: int = -1, + embedding_start: int = 0, + embedding_size: Optional[int] = None, return_all_layer_outputs: bool = False) -> torch.Tensor: - """ + """ Get sentence embeddings. :param inputs: Dict. Model inputs. :param layer_index: int. Get embeddings from specific layer. - :param embedding_size: int. Set embedding size for sentence embeddings for 2DMSE models. + :param embedding_size: int. Set embedding size for sentence embeddings for Espresso models. """ all_layer_outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states if return_all_layer_outputs: return all_layer_outputs outputs = all_layer_outputs[layer_index] - if self.is_llm: - batch_size = inputs['input_ids'].shape[0] - sequence_lengths = -1 if self.padding_strategy == 'left' else inputs["attention_mask"].sum(dim=1) - 1 - outputs = outputs[torch.arange(batch_size, device=outputs.device), sequence_lengths] - else: - outputs = get_pooling(outputs, inputs, self.pooling_strategy, padding_strategy=self.padding_strategy) + outputs = get_pooling(outputs, inputs, self.pooling_strategy, padding_strategy=self.padding_strategy) + outputs = outputs[:, embedding_start:] if embedding_size is not None: # topk embedding size return outputs[:, :embedding_size] @@ -732,30 +727,32 @@ class AngleTrainer(Trainer): :param pooler: Pooler. Required :param loss_kwargs: Optional[Dict]. Default None. :param dataset_format: str. Default DatasetFormats.A - :param fixed_teacher_name_or_path: Optional[str]. For distribution alignment. + :param teacher_name_or_path: Optional[str]. For distribution alignment. :param **kwargs: other parameters of Trainer. """ def __init__(self, pooler: Pooler, loss_kwargs: Optional[Dict] = None, dataset_format: str = DatasetFormats.A, - fixed_teacher_name_or_path: Optional[str] = None, - alignment_pooling_strategy: str = 'cls', + teacher_name_or_path: Optional[str] = None, + teacher_pooling_strategy: str = 'cls', **kwargs): super().__init__(**kwargs) self.pooler = pooler if loss_kwargs is None: loss_kwargs = {} self.loss_fct = AngleLoss(dataset_format=dataset_format, **loss_kwargs) - self.fixed_teacher_name_or_path = fixed_teacher_name_or_path - self.alignment_pooling_strategy = alignment_pooling_strategy - if fixed_teacher_name_or_path is not None: - assert not check_llm(fixed_teacher_name_or_path), ('Currently not support LLMs alignment,' - f' teacher={fixed_teacher_name_or_path}') - assert self.pooler.pooling_strategy == 'all', ('fixed_teacher_name_or_path detected!' + self.teacher_name_or_path = teacher_name_or_path + self.teacher_pooling_strategy = teacher_pooling_strategy + if teacher_name_or_path is not None: + logger.info('fixed teacher detected! 
' + 'please ensure the fixed teacher has the same tokenizer as the backbone model!') + assert not check_llm(teacher_name_or_path), ('Currently not support LLMs alignment,' + f' teacher={teacher_name_or_path}') + assert self.pooler.pooling_strategy == 'all', ('teacher_name_or_path detected!' ' please set --pooling_strategy all') fixed_teacher_backbone = AutoModel.from_pretrained( - fixed_teacher_name_or_path, + teacher_name_or_path, trust_remote_code=True, torch_dtype="auto") @@ -763,27 +760,49 @@ def __init__(self, self.fixed_teacher_pooler = Pooler( fixed_teacher_backbone, pooling_strategy='all', - padding_strategy=self.pooler.padding_strategy, - is_llm=False) - self.kl_loss_fct = nn.KLDivLoss(reduction='batchmean') - logger.info(f'Train with alignment, teacher={fixed_teacher_name_or_path}') + padding_strategy=self.pooler.padding_strategy) + logger.info(f'Train with alignment, teacher={teacher_name_or_path}') + + def distillation_loss(self, + inputs: torch.Tensor, + targets: torch.Tensor, + kl_temperature: float = 1.0) -> torch.Tensor: + """ Compute distillation loss. + + :param inputs: torch.Tensor. Input tensor. + :param targets: torch.Tensor. Target tensor. + :param kl_temperature: float. KL temperature. Default 1.0. + :return: torch.Tensor. Distillation loss. + """ + loss = 0. + loss += nn.MSELoss()(inputs, targets) + if kl_temperature > 0: + loss += nn.KLDivLoss(reduction='batchmean')( + F.log_softmax(inputs / kl_temperature, dim=-1), + F.softmax(targets / kl_temperature, dim=-1) + ) * kl_temperature + return loss def compute_loss(self, model, inputs, return_outputs=False): + """ Compute loss for AnglE. + + :param model: Huggingface model. + :param inputs: Dict. Model inputs. + :param return_outputs: bool. Return outputs or not. Default False. + :return: torch.Tensor. Loss. + """ labels = inputs.pop("labels", None) - if self.fixed_teacher_name_or_path is not None: + if self.teacher_name_or_path is not None: all_outputs = self.pooler(inputs) outputs = get_pooling(all_outputs, inputs, - self.alignment_pooling_strategy, + self.teacher_pooling_strategy, self.pooler.padding_strategy) loss = self.loss_fct(labels, outputs) with torch.no_grad(): self.fixed_teacher_pooler.model = self.fixed_teacher_pooler.model.to(self.pooler.model.device) all_fixed_outputs = self.fixed_teacher_pooler(inputs) - alignment_loss = self.kl_loss_fct( - F.log_softmax(all_outputs, dim=-1), - F.softmax(all_fixed_outputs, dim=-1) - ) + alignment_loss = self.distillation_loss(all_outputs, all_fixed_outputs) loss += alignment_loss else: outputs = self.pooler(inputs) @@ -792,14 +811,14 @@ def compute_loss(self, model, inputs, return_outputs=False): return (loss, outputs) if return_outputs else loss -class AngleTDMSETrainer(AngleTrainer): +class AngleESETrainer(AngleTrainer): """ - Custom Huggingface Trainer for AnglE 2DMSE. + Custom Huggingface Trainer for AnglE Espresso. :param pooler: Pooler. Required :param loss_kwargs: Optional[Dict]. Default None. :param dataset_format: str. Default DatasetFormats.A - :param fixed_teacher_name_or_path: Optional[str]. For distribution alignment. + :param teacher_name_or_path: Optional[str]. For distribution alignment. :param **kwargs: other parameters of Trainer. 
""" @@ -807,32 +826,83 @@ def __init__(self, pooler: Pooler, loss_kwargs: Optional[Dict] = None, dataset_format: str = DatasetFormats.A, - fixed_teacher_name_or_path: Optional[str] = None, - tdmse_kl_temperature: float = 1.0, - tdmse_teacher_lambda: float = 1.0, - tdmse_student_lambda: float = 1.0, - apply_tdmse_kl: bool = True, + teacher_name_or_path: Optional[str] = None, + ese_kl_temperature: float = 1.0, + ese_compression_size: int = 128, + apply_ese_pca: bool = True, **kwargs): super().__init__(pooler=pooler, loss_kwargs=loss_kwargs, dataset_format=dataset_format, - fixed_teacher_name_or_path=fixed_teacher_name_or_path, + teacher_name_or_path=teacher_name_or_path, **kwargs) - self.tdmse_kl_temperature = tdmse_kl_temperature - self.tdmse_teacher_lambda = tdmse_teacher_lambda - self.tdmse_student_lambda = tdmse_student_lambda - self.apply_tdmse_kl = apply_tdmse_kl + self.ese_kl_temperature = ese_kl_temperature + self.ese_compression_size = ese_compression_size + self.apply_ese_pca = apply_ese_pca self.n_layers = self.pooler.model.config.num_hidden_layers - self.hidden_size = self.pooler.model.config.hidden_size - self.tdmse_hidden_sizes = get_geometric_hidden_sizes(base=8, max_hidden=self.hidden_size) - self.kl_loss_fct = nn.KLDivLoss(reduction='batchmean') - logger.info('Train with 2DMSE!') + logger.info('Train with Espresso v5!') + + @torch.no_grad() + def pca_compress(self, m: torch.Tensor, k: int) -> torch.Tensor: + """ Get topk feature via quasi-SVD. + :param m: torch.Tensor. Input tensor. + :param k: int. Top-k feature size. + :return: torch.Tensor. Top-k feature. + """ + A = F.softmax(m.T @ m / m.shape[-1]**0.5, dim=-1) + u, s, _ = torch.svd_lowrank(A, q=k) + # a = u @ torch.diag(F.softmax(s, dim=-1)) @ (v.T)[:, :k] + # top-k principal components + topk_deps = u @ torch.diag(s) + return m @ topk_deps + + @torch.no_grad() + def pca_compress_old(self, m: torch.Tensor, k: int) -> torch.Tensor: + """ Get topk feature via quasi-SVD. + :param m: torch.Tensor. Input tensor. + :param k: int. Top-k feature size. + :return: torch.Tensor. Top-k feature. + """ + u, s, _ = torch.svd_lowrank(m, q=k) + # top-k principal components + return u @ torch.diag(s) + + def compute_student_loss(self, + inputs: Dict, + all_layer_outputs: torch.Tensor, + labels: torch.Tensor, + pooling_strategy: str, + padding_strategy: str) -> torch.Tensor: + loss = 0. + compression_loss = 0. + for i in range(self.n_layers - 1): + division = (1. + math.log(1 + i)) + all_student_outputs = all_layer_outputs[i] + student_outputs = get_pooling(all_student_outputs, + inputs, + pooling_strategy, + padding_strategy) + + slimmed_outputs = student_outputs[:, :self.ese_compression_size] + loss += self.loss_fct(labels, slimmed_outputs) / division + if self.apply_ese_pca: + compression_loss += self.distillation_loss( + slimmed_outputs, + self.pca_compress(student_outputs, self.ese_compression_size), + kl_temperature=self.ese_kl_temperature + ) / division + return (loss + compression_loss) / (self.n_layers - 1) def compute_loss(self, model, inputs, return_outputs=False): + """ Compute loss for Espresso. + :param model: Huggingface model. + :param inputs: Dict. Model inputs. + :param return_outputs: bool. Return outputs or not. Default False. + :return: torch.Tensor. Loss. 
+ """ labels = inputs.pop("labels", None) # layer - sample_layer = random.randint(1, self.n_layers - 1) - pooling_strategy = (self.alignment_pooling_strategy + pooling_strategy = (self.teacher_pooling_strategy if self.pooler.pooling_strategy == 'all' else self.pooler.pooling_strategy) all_layer_outputs = self.pooler(inputs, layer_index=-1, return_all_layer_outputs=True) @@ -840,51 +910,36 @@ def compute_loss(self, model, inputs, return_outputs=False): teacher_outputs = get_pooling(all_teacher_outputs, inputs, pooling_strategy, self.pooler.padding_strategy) - all_student_outputs = all_layer_outputs[sample_layer] - student_outputs = get_pooling(all_student_outputs, - inputs, - pooling_strategy, - self.pooler.padding_strategy) - teacher_kl_outputs = teacher_outputs - if self.fixed_teacher_name_or_path is not None: + loss = self.loss_fct(labels, teacher_outputs) + + slimmed_outputs = teacher_outputs[:, :self.ese_compression_size] + loss += self.loss_fct(labels, slimmed_outputs) + if self.apply_ese_pca: + loss += self.distillation_loss( + slimmed_outputs, + self.pca_compress(teacher_outputs, self.ese_compression_size), + kl_temperature=self.ese_kl_temperature + ) + + # student loss + loss += self.compute_student_loss( + inputs, + all_layer_outputs, + labels, + pooling_strategy, + self.pooler.padding_strategy, + ) + + # alignment loss + if self.teacher_name_or_path is not None: with torch.no_grad(): self.fixed_teacher_pooler.model = self.fixed_teacher_pooler.model.to(self.pooler.model.device) all_fixed_outputs = self.fixed_teacher_pooler(inputs) - teacher_kl_outputs = get_pooling(all_fixed_outputs, - inputs, - self.alignment_pooling_strategy, - self.pooler.padding_strategy) - - teacher_loss = self.loss_fct(labels, teacher_outputs) - loss1 = teacher_loss - student_loss = self.loss_fct(labels, student_outputs) - loss1 += student_loss / sample_layer - if self.apply_tdmse_kl and self.tdmse_student_lambda > 0: - kl_loss = self.kl_loss_fct( - F.log_softmax(student_outputs / self.tdmse_kl_temperature, dim=-1), - F.softmax(teacher_kl_outputs / self.tdmse_kl_temperature, dim=-1) - ) * self.tdmse_kl_temperature - loss1 += kl_loss - - # feature - hidden_size = random.choice(self.tdmse_hidden_sizes) - slimmed_teacher_outputs = teacher_outputs[:, :hidden_size] - slimmed_student_outputs = student_outputs[:, :hidden_size] - - slimmed_teacher_loss = self.loss_fct(labels, slimmed_teacher_outputs) - loss2 = slimmed_teacher_loss - slimmed_student_loss = self.loss_fct(labels, slimmed_student_outputs) - loss2 += slimmed_student_loss / sample_layer - - loss = loss1 + loss2 - - if self.fixed_teacher_name_or_path is not None: - alignment_loss = self.kl_loss_fct( - F.log_softmax(all_teacher_outputs, dim=-1), - F.softmax(all_fixed_outputs, dim=-1) - ) - loss += alignment_loss + alignment_loss = self.distillation_loss( + all_teacher_outputs, all_fixed_outputs, + ) + loss += alignment_loss return (loss, teacher_outputs) if return_outputs else loss @@ -892,42 +947,55 @@ class AngleLoss: """ Configure AngleLoss. - :param w1: float. weight for cosine_loss. Default 1.0 - :param w2: float. weight for contrastive loss. Default 1.0 - :param w3: float. weight for angle loss. Default 1.0 + :param cosine_w: float. weight for cosine_loss. Default 1.0 + :param ibn_w: float. weight for contrastive loss. Default 1.0 + :param angle_w: float. weight for angle loss. Default 1.0 :param cosine_tau: float. tau for cosine loss. Default 20.0 :param ibn_tau: float. tau for contrastive loss. Default 20.0 - :param angle_tau: float. 
tau for angle loss. Default 1.0 + :param angle_tau: float. tau for angle loss. Default 20.0 + :param angle_pooling_strategy: str. pooling strategy for angle loss. Default'sum'. :param dataset_format: Optional[str]. Default None. """ def __init__(self, - w1: float = 1.0, - w2: float = 1.0, - w3: float = 1.0, + cosine_w: float = 1.0, + ibn_w: float = 1.0, + angle_w: float = 1.0, cosine_tau: float = 20.0, ibn_tau: float = 20.0, - angle_tau: float = 1.0, + angle_tau: float = 20.0, + angle_pooling_strategy: str = 'sum', dataset_format: Optional[str] = None, **kwargs): - self.w1 = w1 - self.w2 = w2 - self.w3 = w3 + if 'w1' in kwargs or 'w2' in kwargs or 'w3' in kwargs: + assert ('w1, w2, and w3 has been renamed to cosine_w, ibn_w, and angle_w, respecitvely.' + 'Please use new names instead.') + self.cosine_w = cosine_w + self.ibn_w = ibn_w + self.angle_w = angle_w self.cosine_tau = cosine_tau self.ibn_tau = ibn_tau self.angle_tau = angle_tau + self.angle_pooling_strategy = angle_pooling_strategy self.dataset_format = dataset_format def __call__(self, labels: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor: + """ Compute loss for AnglE. + + :param labels: torch.Tensor. Labels. + :param outputs: torch.Tensor. Outputs. + :return: torch.Tensor. Loss. + """ if self.dataset_format == DatasetFormats.A: loss = 0. - if self.w1 > 0: - loss += self.w1 * cosine_loss(labels, outputs, self.cosine_tau) - if self.w2 > 0: - loss += self.w2 * in_batch_negative_loss(labels, outputs, self.ibn_tau) - if self.w3 > 0: - loss += self.w3 * angle_loss(labels, outputs, self.angle_tau) + if self.cosine_w > 0: + loss += self.cosine_w * cosine_loss(labels, outputs, self.cosine_tau) + if self.ibn_w > 0: + loss += self.ibn_w * in_batch_negative_loss(labels, outputs, self.ibn_tau) + if self.angle_w > 0: + loss += self.angle_w * angle_loss(labels, outputs, self.angle_tau, + pooling_strategy=self.angle_pooling_strategy) elif self.dataset_format == DatasetFormats.B: # text,positive,negative text = outputs[::3] @@ -944,12 +1012,13 @@ def __call__(self, combined_labels = torch.cat((positive_labels, negative_labels), dim=0) loss = 0. - if self.w1 > 0: - loss += self.w1 * cosine_loss(combined_labels, combined_inputs, self.cosine_tau) - if self.w2 > 0: - loss += self.w2 * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau) - if self.w3 > 0: - loss += self.w3 * angle_loss(combined_labels, combined_inputs, self.angle_tau) + if self.cosine_w > 0: + loss += self.cosine_w * cosine_loss(combined_labels, combined_inputs, self.cosine_tau) + if self.ibn_w > 0: + loss += self.ibn_w * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau) + if self.angle_w > 0: + loss += self.angle_w * angle_loss(combined_labels, combined_inputs, self.angle_tau, + pooling_strategy=self.angle_pooling_strategy) elif self.dataset_format == DatasetFormats.C: text = outputs[::2] positive = outputs[1::2] @@ -1014,6 +1083,7 @@ class AnglE: :param device: Optional[str]. Specify device. Default None. :param kbit_kwargs: Optional[Dict]. kwargs for kbit. Default None. details refer to: https://huggingface.co/docs/peft/package_reference/peft_model#peft.prepare_model_for_kbit_training + :param tokenizer_padding_side: Optional[str]. Specify tokenizer padding side from [`left`, `right`]. Default None. :param **kwargs: Any. 
""" # NOQA cfg_file_name = 'angle.config' @@ -1034,6 +1104,9 @@ def __init__(self, torch_dtype: Optional[torch.dtype] = None, device: Optional[str] = None, kbit_kwargs: Optional[Dict] = None, + tokenizer_padding_side: Optional[str] = None, + apply_billm: bool = False, + billm_model_class: Optional[str] = None, **kwargs: Any): super().__init__() self.max_length = max_length @@ -1050,6 +1123,10 @@ def __init__(self, if self.is_llm: logger.info('LLM detected, automatically set is_llm=True.' 'If it is wrong, you can manually set `is_llm`.') + if self.is_llm and self.pooling_strategy != 'last': + logger.info(f'๐Ÿšจ LLM detected, but pooling strategy is specified to {self.pooling_strategy}.' + 'Please check whether it is correct. It is recommended to use `last` pooling strategy for LLM.') + self.apply_lora = apply_lora if self.apply_lora is None: if self.is_llm: @@ -1063,12 +1140,6 @@ def __init__(self, else: self.gpu_count = 0 - self.prompt = None - if self.is_llm: - logger.info('LLM detected, automatically set prompt. ' - 'You can change this setting by manually configuring the `set_prompt()` function.') - self.set_prompt() - self.apply_bfloat16 = apply_bfloat16 if self.apply_bfloat16 is None and 'llama' in model_name_or_path.lower(): logger.info('LLaMA detected, automatically set `apply_bfloat16=True`. ' @@ -1092,6 +1163,8 @@ def __init__(self, logger.info(f'lora_config={lora_config}') self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) + if tokenizer_padding_side is not None and self.tokenizer.padding_side != tokenizer_padding_side: + self.tokenizer.padding_side = tokenizer_padding_side if self.is_llm and self.tokenizer.pad_token_id is None: self.tokenizer.pad_token_id = 0 @@ -1099,7 +1172,18 @@ def __init__(self, kbit_kwargs = kbit_kwargs if kbit_kwargs is not None else {} if self.is_llm: device_map = "auto" - MODEL_CLASS = AutoModelForCausalLM + if apply_billm: + assert billm_model_class is not None, "billm_model_class should be specified for apply_billm=True" + try: + import billm + except ImportError as err: + print(f'Import Error: {err}') + print('Please install the latest billm via: python -m pip install -U billm') + raise + + MODEL_CLASS = getattr(billm, billm_model_class) + else: + MODEL_CLASS = AutoModelForCausalLM if train_mode and self.gpu_count > 1: device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} # LLM @@ -1137,7 +1221,8 @@ def __init__(self, ) elif train_mode: if 'target_modules' not in lora_config or lora_config.get('target_modules', None) is None: - target_modules = find_all_linear_names(model, linear_type=bnb.nn.Linear4bit) + target_modules = find_all_linear_names( + model, linear_type=bnb.nn.Linear4bit if load_kbit == 4 else nn.Linear) lora_config['target_modules'] = target_modules logger.info(f'lora target modules={target_modules}') peft_config = LoraConfig(**lora_config) @@ -1145,46 +1230,35 @@ def __init__(self, model = AnglE.kbit_post_handle(model) self.backbone = model else: - if train_mode: - model = MODEL_CLASS.from_pretrained( - model_name_or_path, + if self.apply_bfloat16: + model = MODEL_CLASS.from_pretrained(model_name_or_path, + output_hidden_states=True, + trust_remote_code=True).bfloat16() + else: + model = MODEL_CLASS.from_pretrained(model_name_or_path, + device_map=device_map, + output_hidden_states=True, + trust_remote_code=True, + torch_dtype=torch_dtype or torch.float16) + + if 'target_modules' not in lora_config or lora_config.get('target_modules', None) is None: + target_modules = 
find_all_linear_names(model) + lora_config['target_modules'] = target_modules + logger.info(f'lora target modules={target_modules}') + if pretrained_lora_path is not None: + print(f'Load lora weight from {pretrained_lora_path}') + model = PeftModel.from_pretrained( + model, + pretrained_lora_path, torch_dtype=torch.float16 if load_kbit == 16 else torch.float32, device_map=device_map, - trust_remote_code=True, + is_trainable=train_mode ) - if pretrained_lora_path is not None: - print(f'Load lora weight from {pretrained_lora_path}') - model = PeftModel.from_pretrained( - model, - pretrained_lora_path, - torch_dtype=torch.float16 if load_kbit == 16 else torch.float32, - device_map=device_map, - is_trainable=train_mode - ) - else: + else: + if train_mode: peft_config = LoraConfig(**lora_config) model = get_peft_model(model, peft_config) - else: - if self.apply_bfloat16: - model = MODEL_CLASS.from_pretrained(model_name_or_path, - output_hidden_states=True, - trust_remote_code=True).bfloat16() - else: - model = MODEL_CLASS.from_pretrained(model_name_or_path, - device_map=device_map, - output_hidden_states=True, - trust_remote_code=True, - load_in_8bit=load_kbit == 8, - torch_dtype=torch_dtype or torch.float16) - if pretrained_lora_path is not None: - logger.info(f'Load lora weight from {pretrained_lora_path}') - model = PeftModel.from_pretrained( - model, - pretrained_lora_path, - torch_dtype=torch_dtype or torch.float16, - device_map=device_map, - is_trainable=train_mode - ) + self.backbone = model else: if self.apply_bfloat16: @@ -1196,7 +1270,6 @@ def __init__(self, device_map=device_map, output_hidden_states=True, trust_remote_code=True, - load_in_8bit=load_kbit == 8, torch_dtype=torch_dtype or torch.float16) self.backbone = model else: @@ -1232,10 +1305,9 @@ def __init__(self, self.pooler = Pooler( self.backbone, pooling_strategy=self.pooling_strategy, - padding_strategy=self.tokenizer.padding_side, - is_llm=self.is_llm) + padding_strategy=self.tokenizer.padding_side) - # full_backbone is used to 2DMSE inference + # full_backbone is used to Espresso inference self.full_backbone = None self.__cfg = { 'model_name_or_path': model_name_or_path, @@ -1243,7 +1315,11 @@ def __init__(self, 'model_kwargs': model_kwargs, 'pooling_strategy': pooling_strategy, 'lora_config_kwargs': lora_config, - 'apply_lora': apply_lora, + 'is_llm': self.is_llm, + 'apply_billm': apply_billm, + 'billm_model_class': billm_model_class, + 'apply_lora': self.apply_lora, + 'tokenizer_padding_side': tokenizer_padding_side, } self.__cfg.update(kwargs) @@ -1355,8 +1431,11 @@ def fit(self, argument_kwargs: Optional[Dict] = None, trainer_kwargs: Optional[Dict] = None, loss_kwargs: Optional[Dict] = None, - apply_tdmse: bool = False, - filter_duplicate: bool = True): + apply_ese: bool = False, + filter_duplicate: bool = True, + push_to_hub: bool = False, + hub_model_id: Optional[str] = None, + hub_private_repo: bool = True): """ Fit using AnglE. @@ -1378,12 +1457,18 @@ def fit(self, refer to: https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments :param trainer_kwargs: Optional[Dict]. kwargs for AngleTrainer. :param loss_kwargs: Optional[Dict]. kwargs for AngleLoss. - :param apply_tdmse: bool, whether apply TDMSE training. + :param apply_ese: bool, whether apply ESE training. + :param filter_duplicate: bool, whether filter duplicate samples. + :param push_to_hub: bool, whether push to hub. + :param hub_model_id: Optional[str], hub model id. 
+ :param hub_private_repo: bool, whether push to private repo. """ # NOQA if output_dir is not None: os.makedirs(output_dir, exist_ok=True) # save config self.save_config(os.path.join(output_dir, AnglE.cfg_file_name)) + # save tokenizer + self.tokenizer.save_pretrained(output_dir) if self.gpu_count > 1: gradient_accumulation_steps = gradient_accumulation_steps // self.gpu_count @@ -1391,12 +1476,27 @@ def fit(self, fp16 = True else: fp16 = False + + # init argument_kwargs if argument_kwargs is None: argument_kwargs = {} + if 'push_to_hub' not in argument_kwargs: + argument_kwargs['push_to_hub'] = push_to_hub + if 'hub_model_id' not in argument_kwargs: + argument_kwargs['hub_model_id'] = hub_model_id + if 'hub_private_repo' not in argument_kwargs: + argument_kwargs['hub_private_repo'] = hub_private_repo + if trainer_kwargs is None: trainer_kwargs = {} + callbacks = None if valid_ds is not None: + # check format + for obj in valid_ds: + if obj['extra']['dataset_format'] != DatasetFormats.A: + raise ValueError('Currently only support evaluation for DatasetFormats.A.') + break best_ckpt_dir = None if output_dir is not None: best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint') @@ -1405,7 +1505,7 @@ def fit(self, save_dir=best_ckpt_dir) callbacks = [evaluate_callback] - CustomTrainer = AngleTDMSETrainer if apply_tdmse else AngleTrainer + CustomTrainer = AngleESETrainer if apply_ese else AngleTrainer trainer = CustomTrainer( pooler=self.pooler, model=self.backbone, @@ -1442,6 +1542,8 @@ def fit(self, self.backbone = torch.compile(self.backbone) trainer.train() + if argument_kwargs.get('push_to_hub', False): + trainer.push_to_hub() self.backbone.save_pretrained(output_dir) def evaluate(self, data: Dataset, batch_size: int = 32, threshold: Optional[float] = None, device: Any = None): @@ -1473,29 +1575,28 @@ def evaluate(self, data: Dataset, batch_size: int = 32, threshold: Optional[floa accuracy = np.mean((y_trues > 0.5) == (y_preds > threshold)) return corrcoef, accuracy - def set_prompt(self, prompt: str = Prompts.A): - self.prompt = prompt - if self.prompt is not None: - logger.info('Prompt is set, the prompt will be automatically applied during the encoding phase. ' - 'To disable prompt setting, please configure set_prompt(prompt=None)') - def encode(self, inputs: Union[List[str], Tuple[str], List[Dict], str], max_length: Optional[int] = None, end_with_eos: bool = False, to_numpy: bool = True, layer_index: int = -1, + embedding_start: int = 0, embedding_size: Optional[int] = None, - device: Optional[Any] = None): + device: Optional[Any] = None, + prompt: Optional[str] = None): """ encode texts. :param inputs: Union[List[str], Tuple[str], List[Dict], str]. Input texts. Required. :param max_length: Optional[int]. Default None. :param to_numpy: bool. Default True. - :param layer_index: int. Obtain specific layer's sentence embeddings (for 2DMSE). - :param embedding_size: Optional[int]. Specify embedding size (for 2DMSE). + :param layer_index: int. Obtain specific layer's sentence embeddings (for Espresso). + :param embedding_start: int. Specify the start position of the embedding (for Espresso). + :param embedding_size: Optional[int]. Specify embedding size (for Espresso). + The embeddings from embedding_start to embedding_start+embedding_size will be returned. :param device: Optional[Any]. Default None. + :param prompt: Optional[str]. Default None. 
""" if layer_index != -1 and self.full_backbone is None: self.full_backbone = copy.deepcopy(self.backbone) @@ -1508,10 +1609,10 @@ def encode(self, self.backbone.eval() if not isinstance(inputs, (tuple, list)): inputs = [inputs] - if self.prompt is not None: + if prompt is not None: for i, obj in enumerate(inputs): assert isinstance(obj, dict), 'The prompt has been set, please pass a dict like {"prompt_key": "text"}' - inputs[i] = self.prompt.format(**obj) + inputs[i] = prompt.format(**obj) max_length = max_length or self.max_length if end_with_eos: max_length -= 1 @@ -1534,10 +1635,20 @@ def encode(self, return_tensors='pt') tok.to(device) with torch.no_grad(): - output = self.pooler(tok, layer_index=layer_index, embedding_size=embedding_size) + output = self.pooler(tok, + layer_index=layer_index, + embedding_start=embedding_start, + embedding_size=embedding_size) if to_numpy: return output.float().detach().cpu().numpy() return output - def export_onnx(self): - pass + def push_to_hub(self, hub_model_id: str, private: bool = True, **kwargs): + """ push model to hub + + :param hub_model_id: str, hub model id. + :param private: bool, whether push to private repo. Default True. + :param kwargs: other kwargs for `push_to_hub` method. + """ + self.tokenizer.push_to_hub(hub_model_id, private=private, **kwargs) + self.backbone.push_to_hub(hub_model_id, private=private, **kwargs) diff --git a/angle_emb/train_cli.py b/angle_emb/train_cli.py index 6d66ddb..99d7611 100644 --- a/angle_emb/train_cli.py +++ b/angle_emb/train_cli.py @@ -29,6 +29,8 @@ help='Specify huggingface datasets name or local file path for valid set.') parser.add_argument('--valid_subset_name', type=str, default=None, help='Specify huggingface datasets subset name for valid set') +parser.add_argument('--valid_split_name', type=str, default='train', + help='Specify huggingface datasets split name for valid set') parser.add_argument('--prompt_template', type=str, default=None, help='Specify prompt_template like "Instruct: xxx\nInput: {text}", default None') parser.add_argument('--save_dir', type=str, default=None, @@ -39,20 +41,18 @@ help='Specify dataset random seed, default None') parser.add_argument('--workers', type=int, default=2, help='Specify dataset workers, default 2') -parser.add_argument('--w1', type=float, default=1.0, - help='Specify w1 (cosine), default 1.0') -parser.add_argument('--w2', type=float, default=1.0, - help='Specify w2 (ibn), default 1.0') -parser.add_argument('--w3', type=float, default=1.0, - help='Specify w3 (angle), default 1.0') +parser.add_argument('--cosine_w', type=float, default=1.0, + help='Specify weight for cosine loss, default 1.0') +parser.add_argument('--ibn_w', type=float, default=1.0, + help='Specify weight for ibn loss, default 1.0') +parser.add_argument('--angle_w', type=float, default=1.0, + help='Specify weight for angle loss, default 1.0') parser.add_argument('--angle_tau', type=float, default=20.0, help='Specify angle_tau, default 20.0') parser.add_argument('--cosine_tau', type=float, default=20.0, help='Specify cosine_tau, defaut 20.0') parser.add_argument('--ibn_tau', type=float, default=20.0, help='Specify ibn_tau, defaut 20.0') -parser.add_argument('--is_llm', type=int, default=0, choices=[0, 1], - help='Specify is_llm, choices [0, 1], defaut 0') parser.add_argument('--apply_lora', type=int, default=0, choices=[0, 1], help='Specify apply_lora, choices [0, 1], defaut 0') parser.add_argument('--load_kbit', type=int, default=None, choices=[4, 8, 16], @@ -63,16 +63,18 @@ help='Specify 
lora_alpha, defaut 32') parser.add_argument('--lora_dropout', type=float, default=0.1, help='Specify lora_dropout, defaut 0.1') +parser.add_argument('--lora_target_modules', type=str, default=None, + help='Specify lora_target_modules. comma serves as the splitter, such as W,b. Defaut None') parser.add_argument('--learning_rate', type=float, default=1e-5, help='Specify learning_rate, defaut 1e-5') -parser.add_argument('--start_bilayer_index', type=int, default=None, - help='Specify start_bilayer_index, defaut None') parser.add_argument('--warmup_steps', type=int, default=100, help='Specify warmup_steps, defaut 100') parser.add_argument('--logging_steps', type=int, default=100, help='Specify logging_steps, defaut 100') parser.add_argument('--pooling_strategy', type=str, default='cls', help='Specify pooling_strategy from [`cls`, `last`, `avg`, `cls_avg`, `max`], default `cls`') +parser.add_argument('--tokenizer_padding_side', type=str, default=None, choices=['left', 'right'], + help='specify tokenizer padding side from [`left`, `right`], default None') parser.add_argument('--epochs', type=int, default=20, help='Specify epochs, default 20') parser.add_argument('--max_steps', type=int, default=-1, help='Specify max steps, default -1 (Automatically calculated from epochs)') @@ -83,27 +85,34 @@ help='Flag to enable streaming mode (default: False)') parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='Specify gradient_accumulation_steps, default 1') -parser.add_argument('--torch_dtype', type=str, default='float32', - help='Specify torch_dtype from `auto`, `float32`, `float16`, default float32') +parser.add_argument('--torch_dtype', type=str, default=None, choices=['auto', 'float32', 'float16', 'bfloat16'], + help='Specify torch_dtype from [`auto`, `float32`, `float16`, `bfloat16`], default None') parser.add_argument('--fp16', type=bool, default=None, choices=[0, 1], help='Specify fp16, choices [0, 1], default None') parser.add_argument('--push_to_hub', type=int, default=0, choices=[0, 1], help='Specify push_to_hub, default 0') +parser.add_argument('--hub_private_repo', type=int, default=1, choices=[0, 1], + help='Specify hub_private_repo, default 1') parser.add_argument('--hub_model_id', type=str, default=None, - help='Specify push_to_hub_model_id, default None, format like organization/model_id') -# configure TDMSE -parser.add_argument('--apply_tdmse', type=int, default=0, choices=[0, 1], - help='Specify apply_tdmse to support 2DMSE training, default 0') -parser.add_argument('--apply_tdmse_kl', type=int, default=1, choices=[0, 1], - help='Specify apply_tdmse_kl to support 2DMSE training with KL Divergence, default 1') -parser.add_argument('--tdmse_kl_temperature', type=float, default=1.0, - help='Specify KL temperature for tdmse, default 1.0') -parser.add_argument('--tdmse_teacher_lambda', type=float, default=1.0, - help='Specify teacher lambda for tdmse, default 1.0') -parser.add_argument('--tdmse_student_lambda', type=float, default=1.0, - help='Specify student lambda for tdmse, default 1.0') + help='Specify hub_model_id, default None, format like organization/model_id') +# configure LLM +parser.add_argument('--is_llm', type=int, default=0, choices=[0, 1], + help='Specify is_llm, choices [0, 1], defaut 0') +parser.add_argument('--apply_billm', type=int, default=0, choices=[0, 1], + help='Specify apply_billm, choices [0, 1], defaut 0') +parser.add_argument('--billm_model_class', type=str, default=None, + help='Specify billm model class name, default None') +# configure 
ESE +parser.add_argument('--apply_ese', type=int, default=0, choices=[0, 1], + help='Specify apply_ese to support Espresso Sentence Embedding training, default 0') +parser.add_argument('--ese_kl_temperature', type=float, default=1.0, + help='Specify KL temperature for ese, default 1.0') +parser.add_argument('--ese_compression_size', type=int, default=128, + help='Specify compression size for ese, default 128') # configure teacher alignment -parser.add_argument('--fixed_teacher_name_or_path', type=str, default=None, +parser.add_argument('--teacher_name_or_path', type=str, default=None, help='Specify model_name_or_path for teacher alignment, default None') +parser.add_argument('--teacher_pooling_strategy', type=str, default='cls', + help='Specify pooling strategy for teacher from [`cls`, `last`, `avg`, `cls_avg`, `max`], default `cls`') # NOQA # configure wandb parser.add_argument('--wandb_project', type=str, default=None, help='Specify WANDB_PROJECT, default None') parser.add_argument('--wandb_log_model', type=str, default=None, help='Specify WANDB_LOG_MODEL, default None') @@ -129,6 +138,20 @@ args.torch_dtype = torch.float32 elif args.torch_dtype == 'float16': args.torch_dtype = torch.float16 +elif args.torch_dtype == 'bfloat16': + args.torch_dtype = torch.bfloat16 + +apply_bfloat16 = None +if args.torch_dtype == torch.bfloat16: + apply_bfloat16 = True + +lora_config = { + 'r': args.lora_r, + 'lora_alpha': args.lora_alpha, + 'lora_dropout': args.lora_dropout, +} +if args.lora_target_modules is not None: + lora_config['target_modules'] = [v.strip() for v in args.lora_target_modules.split(',') if v.strip()] def main(): @@ -138,30 +161,34 @@ def main(): pretrained_lora_path=args.pretrained_lora_path, pooling_strategy=args.pooling_strategy, train_mode=True, - is_llm=args.is_llm, apply_lora=args.apply_lora, - lora_config_kwargs={ - 'r': args.lora_r, - 'lora_alpha': args.lora_alpha, - 'lora_dropout': args.lora_dropout, - }, + lora_config_kwargs=lora_config, load_kbit=args.load_kbit, - torch_dtype=args.torch_dtype) - - if args.start_bilayer_index is not None: - model.backbone.set_start_bilayer_index(args.start_bilayer_index) + torch_dtype=args.torch_dtype, + apply_bfloat16=apply_bfloat16, + tokenizer_padding_side=args.tokenizer_padding_side, + is_llm=args.is_llm, + apply_billm=args.apply_billm, + billm_model_class=args.billm_model_class) if os.path.exists(args.train_name_or_path): - ds = load_dataset('json', data_files=[args.train_name_or_path], streaming=args.streaming) + ds = load_dataset('json', + data_files=[args.train_name_or_path], + num_proc=args.workers, + streaming=args.streaming) else: - ds = load_dataset(args.train_name_or_path, args.train_subset_name, streaming=args.streaming) + ds = load_dataset(args.train_name_or_path, + args.train_subset_name, + num_proc=args.workers, + streaming=args.streaming) logger.info('Dataset overview:') print(ds) logger.info('Processing train...') if args.streaming: train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map( - AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template)) + AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + num_proc=args.workers) else: train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map( AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), @@ -171,23 +198,21 @@ def main(): if valid_ds is None and args.valid_name_or_path is not None: logger.info('Validation detected, processing 
validation...') if os.path.exists(args.valid_name_or_path): - valid_ds = load_dataset('json', data_files=[args.valid_name_or_path]) + valid_ds = load_dataset('json', data_files=[args.valid_name_or_path], num_proc=args.workers) else: - valid_ds = load_dataset(args.valid_name_or_path, args.valid_subset_name) - - if args.streaming: - valid_ds = valid_ds[args.valid_subset_name or 'train'].map( - AngleDataTokenizer(model.tokenizer, model.max_length, - prompt_template=args.prompt_template)) - else: - valid_ds = valid_ds[args.valid_subset_name or 'train'].map( - AngleDataTokenizer(model.tokenizer, model.max_length, - prompt_template=args.prompt_template), num_proc=args.workers) + if args.valid_subset_name is not None: + valid_ds = load_dataset(args.valid_name_or_path, args.valid_subset_name, num_proc=args.workers) + else: + valid_ds = load_dataset(args.valid_name_or_path, num_proc=args.workers) + valid_ds = valid_ds[args.valid_split_name or 'train'].map( + AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + num_proc=args.workers) argument_kwargs = {} if args.push_to_hub: assert args.hub_model_id is not None, 'Please specify hub_mode_id via --hub_model_id xxx' - argument_kwargs['push_to_hub'] = True, + argument_kwargs['push_to_hub'] = True + argument_kwargs['hub_private_repo'] = bool(args.hub_private_repo) argument_kwargs['hub_model_id'] = args.hub_model_id if args.wandb_project is not None: argument_kwargs['report_to'] = 'wandb' @@ -195,17 +220,16 @@ def main(): argument_kwargs['max_steps'] = args.max_steps trainer_kwargs = None - if args.fixed_teacher_name_or_path is not None: + if args.teacher_name_or_path is not None: trainer_kwargs = { - 'fixed_teacher_name_or_path': args.fixed_teacher_name_or_path + 'teacher_name_or_path': args.teacher_name_or_path, + 'teacher_pooling_strategy': args.teacher_pooling_strategy, } - if args.apply_tdmse: + if args.apply_ese: trainer_kwargs = trainer_kwargs or {} trainer_kwargs = dict(trainer_kwargs, **{ - 'apply_tdmse_kl': args.apply_tdmse_kl, - 'tdmse_kl_temperature': args.tdmse_kl_temperature, - 'tdmse_teacher_lambda': args.tdmse_teacher_lambda, - 'tdmse_student_lambda': args.tdmse_student_lambda, + 'ese_kl_temperature': args.ese_kl_temperature, + 'ese_compression_size': args.ese_compression_size, }) model.fit( @@ -220,16 +244,16 @@ def main(): logging_steps=args.logging_steps, gradient_accumulation_steps=args.gradient_accumulation_steps, loss_kwargs={ - 'w1': args.w1, - 'w2': args.w2, - 'w3': args.w3, + 'cosine_w': args.cosine_w, + 'ibn_w': args.ibn_w, + 'angle_w': args.angle_w, 'cosine_tau': args.cosine_tau, 'ibn_tau': args.ibn_tau, 'angle_tau': args.angle_tau, }, fp16=args.fp16, argument_kwargs=argument_kwargs, - apply_tdmse=args.apply_tdmse, + apply_ese=args.apply_ese, trainer_kwargs=trainer_kwargs, ) diff --git a/angle_emb/utils.py b/angle_emb/utils.py index 38e4b77..f4a202a 100644 --- a/angle_emb/utils.py +++ b/angle_emb/utils.py @@ -1,7 +1,20 @@ # -*- coding: utf-8 -*- import logging +from typing import List + +from scipy import spatial logging.basicConfig(level=logging.INFO) logger = logging.getLogger('AnglE') + + +def cosine_similarity(vec1: List[int], vec2: List[int]): + """ Calculate cosine similarity between two vectors. + + :param vec1: a list of integers + :param vec2: a list of integers + :return: a float value between 0 and 1, indicating the similarity between the two vectors. 
+ """ + return 1 - spatial.distance.cosine(vec1, vec2) diff --git a/docs/notes/installation.rst b/docs/notes/installation.rst index b94f794..3f48a5f 100644 --- a/docs/notes/installation.rst +++ b/docs/notes/installation.rst @@ -8,7 +8,7 @@ You can install or upgrade AnglE from pip: .. code-block:: bash - python -m pip install -U angle_emb + python -m pip install -U angle-emb diff --git a/docs/notes/quick_start.rst b/docs/notes/quick_start.rst index 0f53ab5..5eceeab 100644 --- a/docs/notes/quick_start.rst +++ b/docs/notes/quick_start.rst @@ -7,7 +7,7 @@ A few lines of code to generate sentence embeddings using AnglE. .. code-block:: bash - python -m pip install -U angle_emb + python -m pip install -U angle-emb Other installation methods, please refer to the `Installation` section. @@ -19,4 +19,4 @@ Other installation methods, please refer to the `Installation` section. from angle_emb import AnglE angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').cuda() - vec = angle.encode("I'll take a break.") \ No newline at end of file + vec = angle.encode("I'll take a break.") diff --git a/scripts/convert_to_sentence_transformer.py b/scripts/convert_to_sentence_transformer.py new file mode 100644 index 0000000..02c0772 --- /dev/null +++ b/scripts/convert_to_sentence_transformer.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- + +""" This is a script to convert a pre-trained AnglE model to a SentenceTransformer model. +""" + +import argparse + +from sentence_transformers import models +from sentence_transformers import SentenceTransformer + + +parser = argparse.ArgumentParser() +parser.add_argument('--model_name_or_path', type=str, required=True, + help='Specify model_name_or_path to set transformer backbone, default roberta-large') +parser.add_argument('--pooling_strategy', type=str, required=True, + help='Specify pooling strategy') +parser.add_argument('--max_length', type=int, default=512, + help='Specify max length') +parser.add_argument('--push_to_hub', type=int, default=0, choices=[0, 1], help='Specify push_to_hub, default 0') +parser.add_argument('--hub_private_repo', type=int, default=1, choices=[0, 1], + help='Specify hub_private_repo, default 1') +parser.add_argument('--hub_model_id', type=str, default=None, + help='Specify push_to_hub_model_id, default None, format like organization/model_id') + +args = parser.parse_args() + +word_embedding_model = models.Transformer(args.model_name_or_path, max_seq_length=args.max_length) +pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.pooling_strategy) +model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) + +if args.push_to_hub: + model.push_to_hub(args.hub_model_id, private=args.hub_private_repo, exist_ok=True) From 4476bfd2a185ccd2f17f22fcf0987d53a78d3039 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 21 May 2024 16:16:58 +0800 Subject: [PATCH 02/11] refactor and support espresso --- angle_emb/angle.py | 109 +++++++++++-------- angle_emb/{train_cli.py => angle_trainer.py} | 31 +++--- 2 files changed, 78 insertions(+), 62 deletions(-) rename angle_emb/{train_cli.py => angle_trainer.py} (92%) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index e2ccada..6053a7c 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -700,23 +700,42 @@ def __init__(self, def __call__(self, inputs: Dict, layer_index: int = -1, - embedding_start: int = 0, + embedding_start: Optional[int] = None, embedding_size: Optional[int] = None, - return_all_layer_outputs: bool = False) -> 
torch.Tensor: + return_all_layer_outputs: bool = False, + pooling_strategy: Optional[Union[int, str]] = None,) -> torch.Tensor: """ Get sentence embeddings. :param inputs: Dict. Model inputs. - :param layer_index: int. Get embeddings from specific layer. - :param embedding_size: int. Set embedding size for sentence embeddings for Espresso models. + :param layer_index: Optional[int]. Get embeddings from specific layer. + :param embedding_start: Optional[int]. Start index of embeddings. + :param embedding_size: int. Set embedding size for sentence embeddings. + :param return_all_layer_outputs: bool. Return all layer outputs or not. Default False. + :param pooling_strategy: Optional[str]. + Currently support [`cls`, `last`, `avg`, `cls_avg`, `max`]. Default None. """ all_layer_outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states if return_all_layer_outputs: return all_layer_outputs outputs = all_layer_outputs[layer_index] - outputs = get_pooling(outputs, inputs, self.pooling_strategy, padding_strategy=self.padding_strategy) - outputs = outputs[:, embedding_start:] + outputs = get_pooling(outputs, inputs, + pooling_strategy or self.pooling_strategy, + padding_strategy=self.padding_strategy) + n_dim = len(outputs.shape) + if embedding_start is not None: + if n_dim == 2: + outputs = outputs[:, embedding_start:] + elif n_dim == 3: + outputs = outputs[:, :, embedding_start:] + else: + raise ValueError(f'Unsupported output shape: {outputs.shape}') if embedding_size is not None: # topk embedding size - return outputs[:, :embedding_size] + if n_dim == 2: + outputs = outputs[:, :embedding_size] + elif n_dim == 3: + outputs = outputs[:, :, :embedding_size] + else: + raise ValueError(f'Unsupported output shape: {outputs.shape}') return outputs @@ -745,37 +764,37 @@ def __init__(self, self.teacher_name_or_path = teacher_name_or_path self.teacher_pooling_strategy = teacher_pooling_strategy if teacher_name_or_path is not None: - logger.info('fixed teacher detected! ' - 'please ensure the fixed teacher has the same tokenizer as the backbone model!') + logger.info('Teacher detected! ' + 'please ensure the teacher has the same tokenizer as the backbone model!') assert not check_llm(teacher_name_or_path), ('Currently not support LLMs alignment,' f' teacher={teacher_name_or_path}') - assert self.pooler.pooling_strategy == 'all', ('teacher_name_or_path detected!' - ' please set --pooling_strategy all') - fixed_teacher_backbone = AutoModel.from_pretrained( + teacher_backbone = AutoModel.from_pretrained( teacher_name_or_path, trust_remote_code=True, - torch_dtype="auto") + torch_dtype=self.pooler.model.dtype).to(self.pooler.model.device) - fixed_teacher_backbone.config.use_cache = False - self.fixed_teacher_pooler = Pooler( - fixed_teacher_backbone, - pooling_strategy='all', + self.teacher_pooler = Pooler( + teacher_backbone, + pooling_strategy=self.teacher_pooling_strategy, padding_strategy=self.pooler.padding_strategy) - logger.info(f'Train with alignment, teacher={teacher_name_or_path}') + logger.info(f'Train with teacher={teacher_name_or_path}') def distillation_loss(self, inputs: torch.Tensor, targets: torch.Tensor, + mse_weight: float = 1.0, kl_temperature: float = 1.0) -> torch.Tensor: """ Compute distillation loss. :param inputs: torch.Tensor. Input tensor. :param targets: torch.Tensor. Target tensor. + :param mse_weight: float. MSE weight. Default 1.0. :param kl_temperature: float. KL temperature. Default 1.0. :return: torch.Tensor. Distillation loss. """ loss = 0. 
- loss += nn.MSELoss()(inputs, targets) + if mse_weight > 0: + loss += mse_weight * nn.MSELoss()(inputs, targets) if kl_temperature > 0: loss += nn.KLDivLoss(reduction='batchmean')( F.log_softmax(inputs / kl_temperature, dim=-1), @@ -793,16 +812,20 @@ def compute_loss(self, model, inputs, return_outputs=False): """ labels = inputs.pop("labels", None) if self.teacher_name_or_path is not None: - all_outputs = self.pooler(inputs) + all_outputs = self.pooler(inputs, layer_index=-1, return_all_layer_outputs=True)[-1] outputs = get_pooling(all_outputs, inputs, - self.teacher_pooling_strategy, + self.pooler.pooling_strategy, self.pooler.padding_strategy) loss = self.loss_fct(labels, outputs) with torch.no_grad(): - self.fixed_teacher_pooler.model = self.fixed_teacher_pooler.model.to(self.pooler.model.device) - all_fixed_outputs = self.fixed_teacher_pooler(inputs) - - alignment_loss = self.distillation_loss(all_outputs, all_fixed_outputs) + self.teacher_pooler.model = self.teacher_pooler.model.to(self.pooler.model.device) + align_outputs = self.teacher_pooler(inputs) + + alignment_loss = self.distillation_loss( + all_outputs if self.teacher_pooling_strategy == 'all' else outputs, + align_outputs, + mse_weight=0.0, + kl_temperature=1.0) loss += alignment_loss else: outputs = self.pooler(inputs) @@ -840,33 +863,21 @@ def __init__(self, self.ese_compression_size = ese_compression_size self.apply_ese_pca = apply_ese_pca self.n_layers = self.pooler.model.config.num_hidden_layers - logger.info('Train with Espresso v5!') + logger.info('Train with โ˜•๏ธ Espresso!') @torch.no_grad() def pca_compress(self, m: torch.Tensor, k: int) -> torch.Tensor: - """ Get topk feature via quasi-SVD. + """ Get topk feature via PCA. :param m: torch.Tensor. Input tensor. :param k: int. Top-k feature size. :return: torch.Tensor. Top-k feature. """ A = F.softmax(m.T @ m / m.shape[-1]**0.5, dim=-1) u, s, _ = torch.svd_lowrank(A, q=k) - # a = u @ torch.diag(F.softmax(s, dim=-1)) @ (v.T)[:, :k] # top-k principal components topk_deps = u @ torch.diag(s) return m @ topk_deps - @torch.no_grad() - def pca_compress_old(self, m: torch.Tensor, k: int) -> torch.Tensor: - """ Get topk feature via quasi-SVD. - :param m: torch.Tensor. Input tensor. - :param k: int. Top-k feature size. - :return: torch.Tensor. Top-k feature. 
- """ - u, s, _ = torch.svd_lowrank(m, q=k) - # top-k principal components - return u @ torch.diag(s) - def compute_student_loss(self, inputs: Dict, all_layer_outputs: torch.Tensor, @@ -902,13 +913,10 @@ def compute_loss(self, model, inputs, return_outputs=False): """ labels = inputs.pop("labels", None) # layer - pooling_strategy = (self.teacher_pooling_strategy - if self.pooler.pooling_strategy == 'all' - else self.pooler.pooling_strategy) all_layer_outputs = self.pooler(inputs, layer_index=-1, return_all_layer_outputs=True) all_teacher_outputs = all_layer_outputs[-1] teacher_outputs = get_pooling(all_teacher_outputs, inputs, - pooling_strategy, + self.pooler.pooling_strategy, self.pooler.padding_strategy) loss = self.loss_fct(labels, teacher_outputs) @@ -927,17 +935,20 @@ def compute_loss(self, model, inputs, return_outputs=False): inputs, all_layer_outputs, labels, - pooling_strategy, + self.pooler.pooling_strategy, self.pooler.padding_strategy, ) # alignment loss if self.teacher_name_or_path is not None: with torch.no_grad(): - self.fixed_teacher_pooler.model = self.fixed_teacher_pooler.model.to(self.pooler.model.device) - all_fixed_outputs = self.fixed_teacher_pooler(inputs) + self.teacher_pooler.model = self.teacher_pooler.model.to(self.pooler.model.device) + align_outputs = self.teacher_pooler(inputs) alignment_loss = self.distillation_loss( - all_teacher_outputs, all_fixed_outputs, + all_teacher_outputs if self.teacher_pooling_strategy == 'all' else teacher_outputs, + align_outputs, + mse_weight=0.0, + kl_temperature=1.0 ) loss += alignment_loss return (loss, teacher_outputs) if return_outputs else loss @@ -1245,6 +1256,7 @@ def __init__(self, target_modules = find_all_linear_names(model) lora_config['target_modules'] = target_modules logger.info(f'lora target modules={target_modules}') + if pretrained_lora_path is not None: print(f'Load lora weight from {pretrained_lora_path}') model = PeftModel.from_pretrained( @@ -1562,7 +1574,8 @@ def evaluate(self, data: Dataset, batch_size: int = 32, threshold: Optional[floa y_trues.extend(y[::2, 0].detach().cpu().numpy()) with torch.no_grad(): X.to(device or self.device) - x_vecs = self.pooler(X).detach().float().cpu().numpy() + x_vecs = self.pooler(X, + pooling_strategy=self.pooling_strategy).detach().float().cpu().numpy() x_vecs = l2_normalize(x_vecs) pred = (x_vecs[::2] * x_vecs[1::2]).sum(1) y_preds.extend(pred) diff --git a/angle_emb/train_cli.py b/angle_emb/angle_trainer.py similarity index 92% rename from angle_emb/train_cli.py rename to angle_emb/angle_trainer.py index 99d7611..7cabb52 100644 --- a/angle_emb/train_cli.py +++ b/angle_emb/angle_trainer.py @@ -14,7 +14,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--model_name_or_path', type=str, required=True, - help='Specify model_name_or_path to set transformer backbone, default roberta-large') + help='Specify model name or path to set transformer backbone, required') parser.add_argument('--pretrained_model_path', type=str, default=None, help='Specify pretrained model path to load pretrained model, default None') parser.add_argument('--pretrained_lora_path', type=str, default=None, @@ -22,21 +22,24 @@ parser.add_argument('--train_name_or_path', type=str, required=True, help='Specify huggingface datasets name or local file path for train set, required') parser.add_argument('--train_subset_name', type=str, default=None, - help='Specify huggingface datasets subset name for train set') + help='Specify huggingface datasets subset name for train set, default None') 
parser.add_argument('--train_split_name', type=str, default='train', - help='Specify huggingface datasets split name for train set, Default `train`') + help='Specify huggingface datasets split name for train set, default `train`') parser.add_argument('--valid_name_or_path', type=str, default=None, - help='Specify huggingface datasets name or local file path for valid set.') + help='Specify huggingface datasets name or local file path for valid set, default None.') parser.add_argument('--valid_subset_name', type=str, default=None, - help='Specify huggingface datasets subset name for valid set') + help='Specify huggingface datasets subset name for valid set, default None') parser.add_argument('--valid_split_name', type=str, default='train', - help='Specify huggingface datasets split name for valid set') + help='Specify huggingface datasets split name for valid set, default `train`') parser.add_argument('--prompt_template', type=str, default=None, - help='Specify prompt_template like "Instruct: xxx\nInput: {text}", default None') + help='Specify prompt_template like "xxx: {text}", default None.' + 'This prompt will be applied for all text columns.' + 'If you want to specify different prompts for different text columns,' + 'please handle it in the preprocessing step.') parser.add_argument('--save_dir', type=str, default=None, help='Specify save dir, default None') -parser.add_argument('--seed', type=int, default=42, - help='Specify random seed, default 42') +parser.add_argument('--seed', type=int, default=-1, + help='Specify random seed, default -1') parser.add_argument('--dataset_seed', type=int, default=None, help='Specify dataset random seed, default None') parser.add_argument('--workers', type=int, default=2, @@ -54,7 +57,7 @@ parser.add_argument('--ibn_tau', type=float, default=20.0, help='Specify ibn_tau, defaut 20.0') parser.add_argument('--apply_lora', type=int, default=0, choices=[0, 1], - help='Specify apply_lora, choices [0, 1], defaut 0') + help='Specify lora flag, choices [0, 1], default 0') parser.add_argument('--load_kbit', type=int, default=None, choices=[4, 8, 16], help='Specify kbit training, choices [4, 8, 16], default None') parser.add_argument('--lora_r', type=int, default=32, @@ -64,7 +67,7 @@ parser.add_argument('--lora_dropout', type=float, default=0.1, help='Specify lora_dropout, defaut 0.1') parser.add_argument('--lora_target_modules', type=str, default=None, - help='Specify lora_target_modules. comma serves as the splitter, such as W,b. Defaut None') + help='Specify lora_target_modules. comma serves as the splitter, such as `W,b`. 
Defaut None') parser.add_argument('--learning_rate', type=float, default=1e-5, help='Specify learning_rate, defaut 1e-5') parser.add_argument('--warmup_steps', type=int, default=100, @@ -75,14 +78,14 @@ help='Specify pooling_strategy from [`cls`, `last`, `avg`, `cls_avg`, `max`], default `cls`') parser.add_argument('--tokenizer_padding_side', type=str, default=None, choices=['left', 'right'], help='specify tokenizer padding side from [`left`, `right`], default None') -parser.add_argument('--epochs', type=int, default=20, help='Specify epochs, default 20') +parser.add_argument('--epochs', type=int, default=10, help='Specify epochs, default 10') parser.add_argument('--max_steps', type=int, default=-1, help='Specify max steps, default -1 (Automatically calculated from epochs)') parser.add_argument('--save_steps', type=int, default=100, help='Specify save_steps, default 1000') parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32') parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512') parser.add_argument('--streaming', action='store_true', default=False, - help='Flag to enable streaming mode (default: False)') + help='Flag to enable streaming mode, default False') parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='Specify gradient_accumulation_steps, default 1') parser.add_argument('--torch_dtype', type=str, default=None, choices=['auto', 'float32', 'float16', 'bfloat16'], @@ -206,7 +209,7 @@ def main(): valid_ds = load_dataset(args.valid_name_or_path, num_proc=args.workers) valid_ds = valid_ds[args.valid_split_name or 'train'].map( AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), - num_proc=args.workers) + num_proc=args.workers) argument_kwargs = {} if args.push_to_hub: From 515ed6286ece0f09a3e96614280e9d32d9743c97 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 21 May 2024 16:17:26 +0800 Subject: [PATCH 03/11] add docs --- docs/index.rst | 73 +++++++++++-- docs/notes/citation.rst | 38 +++++++ docs/notes/pretrained_models.rst | 36 ++++++ docs/notes/quick_start.rst | 78 ++++++++++++- docs/notes/training.rst | 181 +++++++++++++++++++++++++++++++ 5 files changed, 391 insertions(+), 15 deletions(-) create mode 100644 docs/notes/citation.rst diff --git a/docs/index.rst b/docs/index.rst index 33d9cdc..48f16a0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,18 +3,73 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to AnglE's documentation! +AnglE ๐Ÿ“ ================================= -.. toctree:: - :maxdepth: 2 - :caption: Contents: +.. image:: https://img.shields.io/badge/Arxiv-2309.12871-yellow.svg?style=flat-square + :target: https://arxiv.org/abs/2309.12871 + +.. image:: https://img.shields.io/pypi/v/angle_emb?style=flat-square + :alt: PyPI version + :target: https://pypi.org/project/angle_emb/ + +.. image:: https://img.shields.io/pypi/dm/angle_emb?style=flat-square + :alt: PyPI version + :target: https://pypi.org/project/angle_emb/ + +.. image:: https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sick-r-1 + :target: https://paperswithcode.com/sota/semantic-textual-similarity-on-sick-r-1?p=angle-optimized-text-embeddings + +.. 
image:: https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts16 + :target: https://paperswithcode.com/sota/semantic-textual-similarity-on-sts16?p=angle-optimized-text-embeddings + +.. image:: https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts15 + :target: https://paperswithcode.com/sota/semantic-textual-similarity-on-sts15?p=angle-optimized-text-embeddings + +.. image:: https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts14 + :target: https://paperswithcode.com/sota/semantic-textual-similarity-on-sts14?p=angle-optimized-text-embeddings + +.. image:: https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts13 + :target: https://paperswithcode.com/sota/semantic-textual-similarity-on-sts13?p=angle-optimized-text-embeddings + +.. image:: https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts12 + :target: https://paperswithcode.com/sota/semantic-textual-similarity-on-sts12?p=angle-optimized-text-embeddings + +.. image:: https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts-benchmark + :target: https://paperswithcode.com/sota/semantic-textual-similarity-on-sts-benchmark?p=angle-optimized-text-embeddings + +๐Ÿ“ข **Train/Infer Powerful Sentence Embeddings with AnglE.** +This library is from the paper `Angle-optimized Text Embeddings `_ . +It allows you to train state-of-the-art BERT/LLM-based sentence embeddings with just a few lines of code. +AnglE is also a general sentence embedding inference framework, allowing for infering a variety of transformer-based sentence embeddings. -Indices and tables -================== -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` +โœจ Features +-------------- + + +**Loss**: + +1) ๐Ÿ“ AnglE loss +2) โš– Contrastive loss +3) ๐Ÿ“ CoSENT loss +4) โ˜•๏ธ Espresso loss (previously known as 2DMSE) + +**Backbones**: + +1) BERT-based models (BERT, RoBERTa, ELECTRA, ALBERT, etc.) +2) LLM-based models (LLaMA, Mistral, Qwen, etc.) +3) Bi-directional LLM-based models (LLaMA, Mistral, Qwen, OpenELMo, etc.. refer to: https://github.com/WhereIsAI/BiLLM) + +**Training**: + +1) Single-GPU training +2) Multi-GPU training + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + diff --git a/docs/notes/citation.rst b/docs/notes/citation.rst new file mode 100644 index 0000000..f05b661 --- /dev/null +++ b/docs/notes/citation.rst @@ -0,0 +1,38 @@ +๐Ÿซก Citation +=================== + +You are welcome to use our code and pre-trained models. If you use our code and pre-trained models, please support us by citing our work as follows: + +.. code-block:: bibtex + + @article{li2023angle, + title={AnglE-optimized Text Embeddings}, + author={Li, Xianming and Li, Jing}, + journal={arXiv preprint arXiv:2309.12871}, + year={2023} + } + + +If you train with Espresso technique, please also cite the following paper: + +.. 
code-block:: bibtex + + @article{li20242d, + title={ESE: Espresso Sentence Embeddings}, + author={Xianming Li and Zongxi Li and Jing Li and Haoran Xie and Qing Li}, + journal={arXiv preprint arXiv:2402.14776}, + year={2024} + } + + +If you train with Bi-directional LLMs, please also cite the following paper: + +.. code-block:: bibtex + + @inproceedings{li2024bellm, + title = "BeLLM: Backward Dependency Enhanced Large Language Model for Sentence Embeddings", + author = "Li, Xianming and Li, Jing", + booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics", + year = "2024", + publisher = "Association for Computational Linguistics" + } \ No newline at end of file diff --git a/docs/notes/pretrained_models.rst b/docs/notes/pretrained_models.rst index e69de29..0e0dc36 100644 --- a/docs/notes/pretrained_models.rst +++ b/docs/notes/pretrained_models.rst @@ -0,0 +1,36 @@ +๐Ÿ›๏ธ Official Pretrained Models +=========================== + + + +BERT-based models: +------------------ + ++------------------------------------+-------------+-------------------+--------------------------+ +| ๐Ÿค— HF | Max Tokens | Pooling Strategy | Scenario | ++====================================+=============+===================+==========================+ +| `WhereIsAI/UAE-Large-V1`_ | 512 | cls | English, General purpose | ++------------------------------------+-------------+-------------------+--------------------------+ +| `WhereIsAI/UAE-Code-Large-V1`_ | 512 | cls | Code Similarity | ++------------------------------------+-------------+-------------------+--------------------------+ + +.. _WhereIsAI/UAE-Large-V1: https://huggingface.co/WhereIsAI/UAE-Large-V1 +.. _WhereIsAI/UAE-Code-Large-V1: https://huggingface.co/WhereIsAI/UAE-Code-Large-V1 + + +LLM-based models: +----------------- + ++------------------------------------+-----------------------------+------------------+--------------------------+------------------+---------------------------------+ +| ๐Ÿค— HF (lora weight) | Backbone | Max Tokens | Prompts | Pooling Strategy | Scenario | ++====================================+=============================+==================+==========================+==================+=================================+ +| `SeanLee97/angle-llama-13b-nli`_ | NousResearch/Llama-2-13b-hf | 4096 | ``Prompts.A`` | last token | English, Similarity Measurement | ++------------------------------------+-----------------------------+------------------+--------------------------+------------------+---------------------------------+ +| `SeanLee97/angle-llama-7b-nli-v2`_ | NousResearch/Llama-2-7b-hf | 4096 | ``Prompts.A`` | last token | English, Similarity Measurement | ++------------------------------------+-----------------------------+------------------+--------------------------+------------------+---------------------------------+ + +.. _SeanLee97/angle-llama-13b-nli: https://huggingface.co/SeanLee97/angle-llama-13b-nli +.. _SeanLee97/angle-llama-7b-nli-v2: https://huggingface.co/SeanLee97/angle-llama-7b-nli-v2 + + +๐Ÿ“ข More pretrained models are coming soon! \ No newline at end of file diff --git a/docs/notes/quick_start.rst b/docs/notes/quick_start.rst index 5eceeab..53f9e5c 100644 --- a/docs/notes/quick_start.rst +++ b/docs/notes/quick_start.rst @@ -1,9 +1,11 @@ ๐Ÿš€ Quick Start ================================ -A few lines of code to generate sentence embeddings using AnglE. 
+A few steps steps to get started with AnglE: -1) Install the latest AnglE as follows: + +โฌ‡๏ธ Installation +------------------------------------ .. code-block:: bash @@ -12,11 +14,75 @@ A few lines of code to generate sentence embeddings using AnglE. Other installation methods, please refer to the `Installation` section. -2) Load pretrained models and encode text. +โŒ› Load BERT-based Model +------------------------------------ + +1) **With Prompts**: You can specify a prompt with `prompt=YOUR_PROMPT` in `encode` method. +If set a prompt, the inputs should be a list of dict or a single dict with key `text`, where `text` is the placeholder in the prompt for the input text. +You can use other placeholder names. We provide a set of predefined prompts in `Prompts` class, you can check them via `Prompts.list_prompts()`. + + +.. code-block:: python + + from angle_emb import AnglE, Prompts + from angle_emb.utils import cosine_similarity + + + angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda() + # For retrieval tasks, we use `Prompts.C` as the prompt for the query when using UAE-Large-V1 (no need to specify prompt for documents). + # When specify prompt, the inputs should be a list of dict with key 'text' + qv = angle.encode({'text': 'what is the weather?'}, to_numpy=True, prompt=Prompts.C) + doc_vecs = angle.encode([ + 'The weather is great!', + 'it is rainy today.', + 'i am going to bed' + ], to_numpy=True) + + for dv in doc_vecs: + print(cosine_similarity(qv[0], dv)) + + +2) **Without Prompts**: no need to specify a prompt. Just input a list of strings or a single string. + .. code-block:: python from angle_emb import AnglE - - angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1').cuda() - vec = angle.encode("I'll take a break.") + from angle_emb.utils import cosine_similarity + + + angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda() + # for non-retrieval tasks, we don't need to specify prompt when using UAE-Large-V1. + doc_vecs = angle.encode([ + 'The weather is great!', + 'The weather is very good!', + 'i am going to bed' + ]) + + for i, dv1 in enumerate(doc_vecs): + for dv2 in doc_vecs[i+1:]: + print(cosine_similarity(dv1, dv2)) + + + +โŒ› Load LLM-based Models +------------------------------------ + +If the pretrained weight is a LoRA-based model, you need to specify the backbone via `model_name_or_path` and specify the LoRA path via the `pretrained_lora_path` in `from_pretrained` method. + +.. code-block:: python + + from angle_emb import AnglE, Prompts + + angle = AnglE.from_pretrained('NousResearch/Llama-2-7b-hf', + pretrained_lora_path='SeanLee97/angle-llama-7b-nli-v2', + pooling_strategy='last', + is_llm=True, + torch_dtype='float16') + + print('All predefined prompts:', Prompts.list_prompts()) + vec = angle.encode({'text': 'hello world'}, to_numpy=True, prompt=Prompts.A) + print(vec) + vecs = angle.encode([{'text': 'hello world1'}, {'text': 'hello world2'}], to_numpy=True, prompt=Prompts.A) + print(vecs) + diff --git a/docs/notes/training.rst b/docs/notes/training.rst index e69de29..3f56c41 100644 --- a/docs/notes/training.rst +++ b/docs/notes/training.rst @@ -0,0 +1,181 @@ +๐Ÿš‚ Training and Finetuning +============================ + +There are two types of training methods: + +1. use the `angle-trainer` cli to train a model in cli mode. +2. custom training scripts using the `angle` library. + + +๐Ÿ—‚๏ธ Data Prepration +---------------------------------- + +We currently support three dataset formats: + +1. 
`DatasetFormats.A`: it is a pair format with three columns: `text1`, `text2`, and `label` (0/1). e.g. `{"text1": "A plane is taking off.", "text2": "An air plane is taking off.", "label": 1}` + +2. `DatasetFormats.B`: it is a triple format with three columns: `text`, `positive`, and `negative`. `positive` and `negative` are the positive and negative samples of `text`. e.g. `{"text": "A person on a horse jumps over a broken down airplane.", "positive": "A person is outdoors, on a horse.", "negative": "A person is at a diner, ordering an omelette."}` + +3. `DatasetFormats.C`: it is a pair format with two columns: `text`, `positive`. `positive` is the positive sample of `text`. e.g. `{"text": "Two blond women are hugging one another.", "positive": "There are women showing affection."}` + +It is required toprepare your data into huggingface `datasets.Dataset` in one of the above formats. + + +โญ angle-trainer [recommended] +---------------------------------- + +You can train a powerful sentence embedding model using the `angle-trainer` cli via a few lines of code. + +1. Single gpu training: + + Usage: + + .. code-block:: bash + + CUDA_VISIBLE_DEVICES=0 angle-trainer --help + +2. Multi-gpu training: + + Usage: + + .. code-block:: bash + + CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --master_port=1234 -m angle_emb.angle_trainer --help + + +3. Examples: + + a. BERT-based + + .. code-block:: bash + + BiLLM_START_INDEX=0 WANDB_MODE=disabled CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=2345 -m angle_emb.angle_trainer \ + --train_name_or_path SeanLee97/all_nli_angle_format_b \ + --save_dir ckpts/billm-uae-large-nli \ + --model_name WhereIsAI/UAE-Large-V1 \ + --pooling_strategy cls \ + --maxlen 75 \ + --ibn_w 20.0 \ + --cosine_w 0.0 \ + --angle_w 1.0 \ + --learning_rate 1e-6 \ + --push_to_hub 1 --hub_model_id SeanLee97/test-uae-large-nli --hub_private_repo 1 \ + --logging_steps 5 \ + --save_steps 50 \ + --warmup_steps 50 \ + --batch_size 64 \ + --seed 42 \ + --gradient_accumulation_steps 4 \ + --epochs 1 \ + --fp16 1 + + + + b. LLaMA-based + + .. code-block:: bash + + BiLLM_START_INDEX=0 WANDB_MODE=disabled CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=2345 -m angle_emb.angle_trainer \ + --train_name_or_path SeanLee97/all_nli_angle_format_b \ + --save_dir ckpts/billm-llama7b-nli \ + --model_name NousResearch/Llama-2-7b-chat-hf \ + --pooling_strategy avg \ + --maxlen 60 \ + --ibn_w 20.0 \ + --cosine_w 0.0 \ + --angle_w 1.0 \ + --learning_rate 2e-4 \ + --apply_lora 1 --lora_r 64 --lora_alpha 128 --lora_dropout 0.1 \ + --load_kbit 4 \ + --is_llm 1 \ + --apply_billm 1 \ + --billm_model_class LlamaForCausalLM \ + --push_to_hub 1 --hub_model_id SeanLee97/test-billm-llama7b-nli --hub_private_repo 1 \ + --logging_steps 5 \ + --save_steps 50 \ + --warmup_steps 50 \ + --batch_size 120 \ + --gradient_accumulation_steps 32 \ + --epochs 2 \ + --fp16 1 + + +๐Ÿš‚ Custom Train +---------------------------------- + +You can also train a sentence embedding model using the `angle_emb` library. Here is an example: + +.. code-block:: python + + from datasets import load_dataset + from angle_emb import AnglE, AngleDataTokenizer + + + # 1. load pretrained model + angle = AnglE.from_pretrained('SeanLee97/angle-bert-base-uncased-nli-en-v1', max_length=128, pooling_strategy='cls').cuda() + + # 2. load dataset + # `text1`, `text2`, and `label` are three required columns. 
+ ds = load_dataset('mteb/stsbenchmark-sts') + ds = ds.map(lambda obj: {"text1": str(obj["sentence1"]), "text2": str(obj['sentence2']), "label": obj['score']}) + ds = ds.select_columns(["text1", "text2", "label"]) + + # 3. transform data + train_ds = ds['train'].shuffle().map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8) + valid_ds = ds['validation'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8) + test_ds = ds['test'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8) + + # 4. fit + angle.fit( + train_ds=train_ds, + valid_ds=valid_ds, + output_dir='ckpts/sts-b', + batch_size=32, + epochs=5, + learning_rate=2e-5, + save_steps=100, + eval_steps=1000, + warmup_steps=0, + gradient_accumulation_steps=1, + loss_kwargs={ + 'cosine_w': 1.0, + 'ibn_w': 1.0, + 'angle_w': 1.0, + 'cosine_tau': 20, + 'ibn_tau': 20, + 'angle_tau': 20 + }, + fp16=True, + logging_steps=100 + ) + + # 5. evaluate + corrcoef, accuracy = angle.evaluate(test_ds, device=angle.device) + print('corrcoef:', corrcoef) + + +.. image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://colab.research.google.com/drive/1h28jHvv_x-0fZ0tItIMjf8rJGp3GcO5V?usp=sharing + :alt: Open In Colab + + +๐Ÿ’ก 4. Fine-tuning Tips +------------------------- + +1. If your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`. + +2. If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0, and increase the weight for `ibn_w` such as 10 and 20. The `angle_tau` is recommended to set to 20.0. + +3. If your dataset format is `DatasetFormats.C`, only `ibn_w` and `ibn_tau` are effective. You don't need to tune other parameters. + +4. To alleviate information forgetting in fine-tuning, it is better to specify the `teacher_name_or_path`. If the `teacher_name_or_path` equals `model_name_or_path`, it will conduct self-distillation. **Note that** `teacher_name_or_path` has to have the same tokenizer as `model_name_or_path`. Or it will lead to unexpected results. + + + +๐Ÿ’ก Others +------------------------- + +1. To enable `llm` training, please specify `--is_llm 1` and configure appropriate LoRA hyperparameters. +2. To enable `billm` training, please specify `--apply_billm 1` and configure appropriate `billm_model_class` such as `LLamaForCausalLM` (refer to: https://github.com/WhereIsAI/BiLLM?tab=readme-ov-file#usage). +3. To enable espresso sentence embeddings (ESE), please specify `--apply_ese 1` and configure appropriate ESE hyperparameters via `--ese_kl_temperature float` and `--ese_compression_size integer`. +4. To convert the trained AnglE models to `sentence-transformers`, please run `python scripts/convert_to_sentence_transformers.py --help` for more details. 
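
The conversion mentioned in point 4 essentially wraps the trained backbone with a `sentence-transformers` pooling module. Below is a minimal Python sketch of the same idea; `ckpts/sts-b` is a placeholder for whatever checkpoint directory you trained into, and the pooling mode must match the `pooling_strategy` used during training.

.. code-block:: python

    from sentence_transformers import models, SentenceTransformer

    # placeholder: the checkpoint directory produced by training (e.g. --save_dir / output_dir)
    ckpt_dir = 'ckpts/sts-b'

    # wrap the trained AnglE backbone as a sentence-transformers model
    word_embedding_model = models.Transformer(ckpt_dir, max_seq_length=512)
    # keep the pooling mode consistent with the pooling_strategy used for training (e.g. `cls`)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls')
    st_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    # save (or push to the hub) the converted model
    st_model.save('ckpts/sts-b-sentence-transformers')
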
From 00819f2fb6f5d2b07b9260cdcdae30a43476a6f1 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 21 May 2024 16:20:44 +0800 Subject: [PATCH 04/11] overhaul README --- README.md | 180 +++++++++++++++++++++++------------------ README_2DMSE.md | 4 +- README_ESE.md | 49 +++++++++++ README_Espresso.md | 0 examples/NLI/README.md | 14 ++++ 5 files changed, 166 insertions(+), 81 deletions(-) create mode 100644 README_ESE.md delete mode 100644 README_Espresso.md diff --git a/README.md b/README.md index 134132a..0e0ab3a 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,45 @@ EN | [็ฎ€ไฝ“ไธญๆ–‡](README_zh.md) -# [AnglE๐Ÿ“: Angle-optimized Text Embeddings](https://arxiv.org/abs/2309.12871) +# AnglE๐Ÿ“ > It is Angle ๐Ÿ“, not Angel ๐Ÿ‘ผ. -๐Ÿ“ข **Train/Infer Powerful Sentence Embedding Models with AnglE.** -AnglE enables you to train state-of-the-art BERT-based or LLM-based sentence embeddings with just a few lines of code. -AnglE is also a general inference framework for sentence embedding, allowing you to infer a variety of transformer-based sentence embeddings. +๐Ÿ“ข **Train/Infer Powerful Sentence Embeddings with AnglE.** +This library is from the paper: [AnglE: Angle-optimized Text Embeddings](https://arxiv.org/abs/2309.12871)
+ https://arxiv.org/abs/2309.12871 +. It allows you to train state-of-the-art BERT/LLM-based sentence embeddings with just a few lines of code. AnglE is also a general sentence embedding inference framework, allowing you to infer a variety of transformer-based sentence embeddings. -## ๐Ÿ† Achievements +## โœจ Features - - https://arxiv.org/abs/2309.12871 +**Loss**: +- ๐Ÿ“ AnglE loss +- โš– Contrastive loss +- ๐Ÿ“ CoSENT loss +- โ˜•๏ธ Espresso loss (previously known as 2DMSE, detail: [README_ESE](README_ESE.md)) + +**Backbones**: +- BERT-based models (BERT, RoBERTa, ELECTRA, ALBERT, etc.) +- LLM-based models (LLaMA, Mistral, Qwen, etc.) +- Bi-directional LLM-based models (LLaMA, Mistral, Qwen, OpenELMo, etc.. refer to: https://github.com/WhereIsAI/BiLLM) + +**Training**: +- Single-GPU training +- Multi-GPU training + + +> More features will be added in the future. + http://makeapullrequest.com + +## ๐Ÿ† Achievements + PyPI version PyPI Downloads - - http://makeapullrequest.com - [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sick-r-1)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sick-r-1?p=angle-optimized-text-embeddings) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts16)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts16?p=angle-optimized-text-embeddings) @@ -33,58 +50,48 @@ AnglE is also a general inference framework for sentence embedding, allowing you [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/angle-optimized-text-embeddings/semantic-textual-similarity-on-sts-benchmark)](https://paperswithcode.com/sota/semantic-textual-similarity-on-sts-benchmark?p=angle-optimized-text-embeddings) -๐Ÿ“… Mar 13, 2024 | Paper "[BeLLM: Backward Dependency Enhanced Large Language Model for Sentence Embeddings](https://arxiv.org/abs/2311.05296)" accepted by NAACL 2024 Main Conference. +๐Ÿ“… May 16, 2024 | Paper "[AnglE: Angle-optimized Text Embeddings](https://arxiv.org/abs/2309.12871)" is accepted by ACL 2024 Main Conference. +๐Ÿ“… Mar 13, 2024 | Paper "[BeLLM: Backward Dependency Enhanced Large Language Model for Sentence Embeddings](https://arxiv.org/abs/2311.05296)" is accepted by NAACL 2024 Main Conference. -๐Ÿ“… Mar 8, 2024 | ๐Ÿž [mixedbread's embedding](https://www.mixedbread.ai/blog/mxbai-embed-large-v1) ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)) achieves SOTA on the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) with an average score of **64.68**! The model is trained using AnglE. + +๐Ÿ“… Mar 8, 2024 | ๐Ÿž [mixedbread's embedding](https://www.mixedbread.ai/blog/mxbai-embed-large-v1) ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)) achieves SOTA on the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) with an average score of **64.68**! The model is trained using AnglE. Congrats mixedbread! ๐Ÿ“… Dec 4, 2023 | Our universal sentence embedding [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) achieves SOTA on the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) with an average score of **64.64**! The model is trained using AnglE. -๐Ÿ“… Dec, 2023 | **A New SOTA** for Semantic Textual Similarity! 
+๐Ÿ“… Dec, 2023 | AnglE achieves SOTA performance on the STS Bechmark Semantic Textual Similarity! ## ๐Ÿค— Official Pretrained Models -| ๐Ÿค— HF | LoRA Weight | Dependent Backbone | LLM | Language | Prompt | Pooling Strategy | Examples | -|----|------|------|------|------|------|------|------| -| [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) | N | N | N | EN | `Prompts.C` for retrieval purposes, `None` for others | cls | [![Seach Demo](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WOYD6f8gb_wpkUm_57K8pEDgjlGJd6oB?usp=drive_link) | -| [SeanLee97/angle-llama-13b-nli](https://huggingface.co/SeanLee97/angle-llama-13b-nli) | Y | NousResearch/Llama-2-13b-hf | Y | EN | `Prompts.A` | last token | / | -| [SeanLee97/angle-llama-7b-nli-v2](https://huggingface.co/SeanLee97/angle-llama-7b-nli-v2) | Y | NousResearch/Llama-2-7b-hf | Y | EN | `Prompts.A` | last token | / | -| [SeanLee97/angle-bert-base-uncased-nli-en-v1](https://huggingface.co/SeanLee97/angle-bert-base-uncased-nli-en-v1) | N | N | N | EN | N | `cls_avg` | / | - -
๐Ÿ’ก Tips -๐Ÿ’ก If the selected model is a LoRA weight, it must specify the corresponding dependent backbone. - -For our STS Experiment, please refer to https://github.com/SeanLee97/AnglE/tree/main/examples/NLI -
+BERT-based models: -## Results +| ๐Ÿค— HF | Max Tokens | Pooling Strategy | Scenario | +|----|------|------|------| +| [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) | 512 | cls | English, General-purpose | +| [WhereIsAI/UAE-Code-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) | 512 | cls | Code Similarity | -### English STS Results +LLM-based models: -| Model | STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg. | -| ------- |-------|-------|-------|-------|-------|--------------|-----------------|-------| -| [SeanLee97/angle-llama-7b-nli-20231027](https://huggingface.co/SeanLee97/angle-llama-7b-nli-20231027) | 78.68 | 90.58 | 85.49 | 89.56 | 86.91 | 88.92 | 81.18 | 85.90 | -| [SeanLee97/angle-llama-7b-nli-v2](https://huggingface.co/SeanLee97/angle-llama-7b-nli-v2) | 79.00 | 90.56 | 85.79 | 89.43 | 87.00 | 88.97 | 80.94 | 85.96 | -| [SeanLee97/angle-llama-13b-nli](https://huggingface.co/SeanLee97/angle-llama-13b-nli) | 79.33 | 90.65 | 86.89 | 90.45 | 87.32 | 89.69 | 81.32 | **86.52** | -| [SeanLee97/angle-bert-base-uncased-nli-en-v1](https://huggingface.co/SeanLee97/angle-bert-base-uncased-nli-en-v1) | 75.09 | 85.56 | 80.66 | 86.44 | 82.47 | 85.16 | 81.23 | 82.37 | +| ๐Ÿค— HF (lora weight) | Backbone | Max Tokens | Prompts | Pooling Strategy | Scenario | +|----|------|------|------|------|------| +| [SeanLee97/angle-llama-13b-nli](https://huggingface.co/SeanLee97/angle-llama-13b-nli) | NousResearch/Llama-2-13b-hf | 4096 | `Prompts.A` | last token | English, Similarity Measurement | +| [SeanLee97/angle-llama-7b-nli-v2](https://huggingface.co/SeanLee97/angle-llama-7b-nli-v2) | NousResearch/Llama-2-7b-hf | 4096 | `Prompts.A` | last token | English, Similarity Measurement | ## ๐Ÿš€ Quick Start -AnglE supports two APIs, one is the `transformers` API, the other is the `AnglE` API. If you want to use the `AnglE` API, please install AnglE first: +### โฌ‡๏ธ Installation ```bash python -m pip install -U angle-emb ``` -### 1. Load BERT-based Models +### โŒ› Load BERT-based Model -1) For Retrieval Purposes - -For retrieval purposes, please use the prompt `Prompts.C` for query (not document). +1) **With Prompts**: You can specify a prompt with `prompt=YOUR_PROMPT` in `encode` method. If set a prompt, the inputs should be a list of dict or a single dict with key `text`, where `text` is the placeholder in the prompt for the input text. You can use other placeholder names. We provide a set of predefined prompts in `Prompts` class, you can check them via `Prompts.list_prompts()`. ```python from angle_emb import AnglE, Prompts @@ -92,7 +99,8 @@ from angle_emb.utils import cosine_similarity angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda() -# when specify prompt, the inputs should be a list of dict with key 'text' +# For retrieval tasks, we use `Prompts.C` as the prompt for the query when using UAE-Large-V1 (no need to specify prompt for documents). +# When specify prompt, the inputs should be a list of dict with key 'text' qv = angle.encode({'text': 'what is the weather?'}, to_numpy=True, prompt=Prompts.C) doc_vecs = angle.encode([ 'The weather is great!', @@ -104,7 +112,7 @@ for dv in doc_vecs: print(cosine_similarity(qv[0], dv)) ``` -2) For non-Retrieval Purposes +2) **Without Prompts**: no need to specify a prompt. Just input a list of strings or a single string. 
```python from angle_emb import AnglE @@ -112,6 +120,7 @@ from angle_emb.utils import cosine_similarity angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda() +# for non-retrieval tasks, we don't need to specify prompt when using UAE-Large-V1. doc_vecs = angle.encode([ 'The weather is great!', 'The weather is very good!', @@ -123,37 +132,19 @@ for i, dv1 in enumerate(doc_vecs): print(cosine_similarity(dv1, dv2)) ``` -
-Difference between retrieval and non-retrieval sentence embeddings. [click to expand] - -In UAE, we use different approaches for retrieval and non-retrieval tasks, each serving a different purpose. - -**Retrieval tasks aim to find relevant documents, and as a result, the related documents may not have strict semantic similarities to each other.** - -For instance, when querying "How about ChatGPT?", the related documents are those that contain information related to "ChatGPT," such as "ChatGPT is amazing..." or "ChatGPT is bad....". - -Conversely, **non-retrieval tasks, such as semantic textual similarity, require sentences that are semantically similar.** -For example, a sentence semantically similar to "How about ChatGPT?" could be "What is your opinion about ChatGPT?". +### โŒ› Load LLM-based Models -To distinguish between these two types of tasks, we use different prompts. +If the pretrained weight is a LoRA-based model, you need to specify the backbone via `model_name_or_path` and specify the LoRA path via the `pretrained_lora_path` in `from_pretrained` method. -For retrieval tasks, we use the prompt "Represent this sentence for searching relevant passages: {text}" (Prompts.C in angle_emb) for the query (**no need to apply it for the documents**). - -For non-retrieval tasks, we set the prompt to empty, i.e., just input your text without specifying a prompt. - -So, if your scenario is retrieval-related, it is highly recommended to set the prompt with angle.set_prompt(prompt=Prompts.C). If not, leave the prompt empty or use angle.set_prompt(prompt=None). -
- -### 2. Load LoRA-based Models - -1) AnglE ```python from angle_emb import AnglE, Prompts angle = AnglE.from_pretrained('NousResearch/Llama-2-7b-hf', pretrained_lora_path='SeanLee97/angle-llama-7b-nli-v2', - pooling_strategy='last') + pooling_strategy='last', + is_llm=True, + torch_dtype='float16') print('All predefined prompts:', Prompts.list_prompts()) vec = angle.encode({'text': 'hello world'}, to_numpy=True, prompt=Prompts.A) @@ -162,7 +153,7 @@ vecs = angle.encode([{'text': 'hello world1'}, {'text': 'hello world2'}], to_num print(vecs) ``` -### 3. Load Third-party Models w/ angle_emb +### โŒ› Load Third-party Models You can load any transformer-based third-party models such as `mixedbread-ai/mxbai-embed-large-v1`, `sentence-transformers/all-MiniLM-L6-v2`, and `BAAI/bge-large-en-v1.5` using `angle_emb`. @@ -177,11 +168,11 @@ print(vec) ``` -## Custom Train +## ๐Ÿ•ธ๏ธ Custom Train -### 1. Data Prepation +### ๐Ÿ—‚๏ธ 1. Data Prepation -We support two dataset formats: +We currently support three dataset formats: 1) `DatasetFormats.A`: it is a pair format with three columns: `text1`, `text2`, and `label` (0/1). @@ -191,12 +182,27 @@ We support two dataset formats: You need to prepare your data into huggingface `datasets.Dataset` in one of the formats in terms of your supervised data. -### 2. Train +### ๐Ÿš‚ 2. Train with CLI + +Use `angle-trainer` to train your AnglE model in cli mode. + +1) Single gpu training: + +Usage: + +```bash +CUDA_VISIBLE_DEVICES=0 angle-trainer --help +``` + +2) Multi-gpu training: -Use `angle-trainer` to train your AnglE model in cli mode. Usage: `CUDA_VISIBLE_DEVICES=0 angle-trainer --help` +Usage: +```bash +CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --master_port=1234 -m angle_emb.angle_trainer --help +``` -### 3. Example +### ๐Ÿš‚ 3. Custom Train [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1h28jHvv_x-0fZ0tItIMjf8rJGp3GcO5V?usp=sharing) @@ -238,7 +244,7 @@ angle.fit( 'angle_w': 1.0, 'cosine_tau': 20, 'ibn_tau': 20, - 'angle_tau': 20.0 + 'angle_tau': 20 }, fp16=True, logging_steps=100 @@ -249,18 +255,26 @@ corrcoef, accuracy = angle.evaluate(test_ds, device=angle.device) print('corrcoef:', corrcoef) ``` -### 4. Fine-tuning Tips ๐Ÿ’ก +### ๐Ÿ’ก Others + +- To enable `llm` training, please specify `--is_llm 1` and configure appropriate LoRA hyperparameters. +- To enable `billm` training, please specify `--apply_billm 1` and configure appropriate `billm_model_class` such as `LLamaForCausalLM` (refer to: https://github.com/WhereIsAI/BiLLM?tab=readme-ov-file#usage). +- To enable espresso sentence embeddings (ESE), please specify `--apply_ese 1` and configure appropriate ESE hyperparameters via `--ese_kl_temperature float` and `--ese_compression_size integer`. +- To convert the trained AnglE models to `sentence-transformers`, please run `python scripts/convert_to_sentence_transformers.py --help` for more details. + -1) if your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`. +## ๐Ÿ’ก 4. Fine-tuning Tips -2) if your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0, and increase the weight for `ibn_w` such as 10 and 20. The `angle_tau` can be set to 20.0. +1๏ธโƒฃ If your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`. 
-3) if your dataset format is `DatasetFormats.C`, only `ibn_w` and `ibn_tau` are effective. You don't need to tune other parameters. +2๏ธโƒฃ If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0, and increase the weight for `ibn_w` such as 10 and 20. The `angle_tau` is recommended to set to 20.0. -4) To alleviate information forgetting in fine-tuning, it is better to specify the `teacher_name_or_path`. If the `teacher_name_or_path` equals `model_name_or_path`, it will conduct self-distillation. **It is worth to note that** `teacher_name_or_path` has to have the same tokenizer as `model_name_or_path`. Or it will lead to unexpected results. +3๏ธโƒฃ If your dataset format is `DatasetFormats.C`, only `ibn_w` and `ibn_tau` are effective. You don't need to tune other parameters. +4๏ธโƒฃ To alleviate information forgetting in fine-tuning, it is better to specify the `teacher_name_or_path`. If the `teacher_name_or_path` equals `model_name_or_path`, it will conduct self-distillation. **It is worth to note that** `teacher_name_or_path` has to have the same tokenizer as `model_name_or_path`. Or it will lead to unexpected results. -# Citation + +# ๐Ÿซก Citation You are welcome to use our code and pre-trained models. If you use our code and pre-trained models, please support us by citing our work as follows: @@ -273,12 +287,20 @@ You are welcome to use our code and pre-trained models. If you use our code and } ``` -# ChangeLogs +# ๐Ÿ“œ ChangeLogs | ๐Ÿ“… | Description | |----|------| | 2024 Feb 7 | support training with only positive pairs (`DatasetFormats.C`) | -| 2024 Jan 11 | refactor to support `angle-trainer` and BeLLM | | 2023 Dec 4 | Release a universal English sentence embedding model: [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) | | 2023 Nov 2 | Release an English pretrained model: `SeanLee97/angle-llama-13b-nli` | | 2023 Oct 28 | Release two chinese pretrained models: `SeanLee97/angle-roberta-wwm-base-zhnli-v1` and `SeanLee97/angle-llama-7b-zhnli-v1`; Add chinese README.md | + +# ๐Ÿ“ง Contact + +If you have any questions or suggestions, please feel free to contact us via email: xmlee97@gmail.com + +# ยฉ License + +This project is licensed under the MIT License. +For the pretrained models, please refer to the corresponding license of the models. \ No newline at end of file diff --git a/README_2DMSE.md b/README_2DMSE.md index 4d7ccb8..722da0b 100644 --- a/README_2DMSE.md +++ b/README_2DMSE.md @@ -2,7 +2,7 @@ > Paper: https://arxiv.org/abs/2402.14776 -"๐Ÿช† 2D Matryoshka Sentence Embeddings" has been renamed to "โ˜•๏ธ Espresso Sentence Embeddings". +"๐Ÿช† 2D Matryoshka Sentence Embeddings" has been renamed to โ˜•๏ธ "ESE: Espresso Sentence Embeddings". 
Please find the document in [โ˜•๏ธ Espresso](README_Espresso.md) @@ -11,7 +11,7 @@ Please find the document in [โ˜•๏ธ Espresso](README_Espresso.md) ```bibtex @article{li20242d, - title={2D Matryoshka Sentence Embeddings}, + title={ESE: Espresso Sentence Embeddings}, author={Xianming Li and Zongxi Li and Jing Li and Haoran Xie and Qing Li}, journal={arXiv preprint arXiv:2402.14776}, year={2024} diff --git a/README_ESE.md b/README_ESE.md new file mode 100644 index 0000000..6cdfb6a --- /dev/null +++ b/README_ESE.md @@ -0,0 +1,49 @@ +# Espresso Sentence Embeddings (previously known as 2DMSE) + +> Paper: https://arxiv.org/abs/2402.14776 + +## Abstract + +High-quality sentence embeddings are fundamental in many natural language processing (NLP) tasks, such as semantic textual similarity (STS) and retrieval-augmented generation (RAG). +Nevertheless, most existing methods leverage fixed-length embeddings from full-layer language models, which lack the scalability to accommodate the diverse available resources across various applications. +Viewing this gap, we propose a novel sentence embedding model $\mathrm{Espresso}$ $\mathrm{Sentence}$ $\mathrm{Embeddings}$ (ESE) with two learning processes. +First, the **learn-to-express** process encodes more salient representations to lower layers. +Second, the **learn-to-compress** process compacts essential features into the initial dimensions using Principal Component Analysis (PCA). +This way, ESE can scale model depth via the former process and embedding size via the latter. +Extensive experiments on STS and RAG suggest that ESE can effectively produce high-quality embeddings with less model depth and embedding size, enhancing embedding inference efficiency. + +## How to train + +To enable espresso sentence embeddings (ESE), please specify `--apply_ese 1` and configure appropriate ESE hyperparameters via `--ese_kl_temperature float` and `--ese_compression_size integer`. + +Here is an training example: + +```bash +WANDB_MODE=disabled CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=1234 -m angle_emb.angle_trainer \ +--model_name_or_path WhereIsAI/UAE-Large-V1 \ +--train_name_or_path SeanLee97/nli_for_simcse --save_dir ckpts/UAE-Large-Espresso \ +--ibn_w 10.0 --cosine_w 0. --angle_w 1.0 --angle_tau 20.0 --learning_rate 1e-6 --maxlen 75 \ +--workers 16 \ +--pooling_strategy cls \ +--epochs 1 \ +--batch_size 128 \ +--logging_steps 100 \ +--warmup_steps 200 \ +--save_steps 1000 \ +--fp16 1 \ +--gradient_accumulation_steps 4 \ +--apply_ese 1 \ +--ese_compression_size 128 \ +--ese_kl_temperature 1.0 +``` + +# Citation + +```bibtex +@article{li20242d, + title={ESE: Espresso Sentence Embeddings}, + author={Xianming Li and Zongxi Li and Jing Li and Haoran Xie and Qing Li}, + journal={arXiv preprint arXiv:2402.14776}, + year={2024} +} +``` \ No newline at end of file diff --git a/README_Espresso.md b/README_Espresso.md deleted file mode 100644 index e69de29..0000000 diff --git a/examples/NLI/README.md b/examples/NLI/README.md index 3a6f04c..caaa4c9 100644 --- a/examples/NLI/README.md +++ b/examples/NLI/README.md @@ -1,3 +1,17 @@ +# ๐Ÿค— HF Pretrained Models + +[AnglE NLI Sentence Embedding](https://huggingface.co/collections/SeanLee97/angle-nli-sentence-embeddings-6646de386099d0472c5e21c0) + +# English STS Results + +| Model | STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg. 
| +| ------- |-------|-------|-------|-------|-------|--------------|-----------------|-------| +| [SeanLee97/angle-llama-7b-nli-20231027](https://huggingface.co/SeanLee97/angle-llama-7b-nli-20231027) | 78.68 | 90.58 | 85.49 | 89.56 | 86.91 | 88.92 | 81.18 | 85.90 | +| [SeanLee97/angle-llama-7b-nli-v2](https://huggingface.co/SeanLee97/angle-llama-7b-nli-v2) | 79.00 | 90.56 | 85.79 | 89.43 | 87.00 | 88.97 | 80.94 | 85.96 | +| [SeanLee97/angle-llama-13b-nli](https://huggingface.co/SeanLee97/angle-llama-13b-nli) | 79.33 | 90.65 | 86.89 | 90.45 | 87.32 | 89.69 | 81.32 | **86.52** | +| [SeanLee97/angle-bert-base-uncased-nli-en-v1](https://huggingface.co/SeanLee97/angle-bert-base-uncased-nli-en-v1) | 75.09 | 85.56 | 80.66 | 86.44 | 82.47 | 85.16 | 81.23 | 82.37 | + + # Train NLI for STS Benchmark ## 1. Prepare your gpu environment From dbb819632c2f037dee3e8b4b6523ee5c0a61dfaa Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 21 May 2024 16:21:04 +0800 Subject: [PATCH 05/11] add script: convert angle to sentence-transformers --- scripts/convert_to_sentence_transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/convert_to_sentence_transformer.py b/scripts/convert_to_sentence_transformer.py index 02c0772..d5ebde3 100644 --- a/scripts/convert_to_sentence_transformer.py +++ b/scripts/convert_to_sentence_transformer.py @@ -17,6 +17,7 @@ parser.add_argument('--max_length', type=int, default=512, help='Specify max length') parser.add_argument('--push_to_hub', type=int, default=0, choices=[0, 1], help='Specify push_to_hub, default 0') +parser.add_argument('--half', type=int, default=0, choices=[0, 1], help='Specify half precision, default 0') parser.add_argument('--hub_private_repo', type=int, default=1, choices=[0, 1], help='Specify hub_private_repo, default 1') parser.add_argument('--hub_model_id', type=str, default=None, @@ -27,6 +28,8 @@ word_embedding_model = models.Transformer(args.model_name_or_path, max_seq_length=args.max_length) pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.pooling_strategy) model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) +if args.half: + model = model.half() if args.push_to_hub: - model.push_to_hub(args.hub_model_id, private=args.hub_private_repo, exist_ok=True) + model.push_to_hub(args.hub_model_id, private=args.hub_private_repo == 1, exist_ok=True) From be23a328854d726564649522508819e2884742ac Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 21 May 2024 16:21:27 +0800 Subject: [PATCH 06/11] update angle-trainer path --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e19874f..2bbad51 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ tests_require=test_requirements, entry_points={ 'console_scripts': [ - 'angle-trainer = angle_emb.train_cli:main', + 'angle-trainer = angle_emb.angle_trainer:main', ], }, ) From 0c1644fa0b2e9a21f6fcf46c764f10431370865a Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 21 May 2024 16:27:09 +0800 Subject: [PATCH 07/11] update changelogs --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0e0ab3a..1c52c60 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,7 @@ ๐Ÿ“ข **Train/Infer Powerful Sentence Embeddings with AnglE.** -This library is from the paper: [AnglE: Angle-optimized Text Embeddings](https://arxiv.org/abs/2309.12871) - https://arxiv.org/abs/2309.12871 -. 
It allows you to train state-of-the-art BERT/LLM-based sentence embeddings with just a few lines of code. AnglE is also a general sentence embedding inference framework, allowing you to infer a variety of transformer-based sentence embeddings.
+This library is from the paper: [AnglE: Angle-optimized Text Embeddings](https://arxiv.org/abs/2309.12871). It allows for training state-of-the-art BERT/LLM-based sentence embeddings with just a few lines of code. AnglE is also a general sentence embedding inference framework, allowing for inferring a variety of transformer-based sentence embeddings.
 
 ## ✨ Features
 
@@ -34,6 +32,9 @@ This library is from the paper: [AnglE: Angle-optimized Text Embeddings](https:/
 
 ## 🏆 Achievements
 
+
+  https://arxiv.org/abs/2309.12871
+
 
     PyPI version
 
@@ -291,6 +292,7 @@ You are welcome to use our code and pre-trained models. If you use our code and 
 
 | 📅 | Description |
 |----|------|
+| 2024 May 21 | support Espresso Sentence Embeddings |
 | 2024 Feb 7 | support training with only positive pairs (`DatasetFormats.C`) |
 | 2023 Dec 4 | Release a universal English sentence embedding model: [WhereIsAI/UAE-Large-V1](https://huggingface.co/WhereIsAI/UAE-Large-V1) |
 | 2023 Nov 2 | Release an English pretrained model: `SeanLee97/angle-llama-13b-nli` |

From 5f5818e90591b84038dc6ea208a6deef012df3eb Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Tue, 21 May 2024 16:27:49 +0800
Subject: [PATCH 08/11] fix test case

---
 tests/test_loadding.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_loadding.py b/tests/test_loadding.py
index 3d55fdb..a2ee9d8 100644
--- a/tests/test_loadding.py
+++ b/tests/test_loadding.py
@@ -10,8 +10,7 @@ def test_loadding():
     vecs = angle.encode(['hello world', 'hi there👋'])
     assert isinstance(vecs, np.ndarray)
     # test prompt
-    angle.set_prompt(prompt=Prompts.C)
-    vecs = angle.encode({'text': 'hello world'})
+    vecs = angle.encode({'text': 'hello world'}, prompt=Prompts.C)
     assert isinstance(vecs, np.ndarray)
     vecs = angle.encode([{'text': 'hello world', 'text': 'hi there👋'}])
     assert isinstance(vecs, np.ndarray)

From 88a592c17e476a9bc34822cb06ac8b20e6269ed6 Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Tue, 21 May 2024 16:32:16 +0800
Subject: [PATCH 09/11] fix test cases

---
 tests/test_loadding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_loadding.py b/tests/test_loadding.py
index a2ee9d8..46e31c8 100644
--- a/tests/test_loadding.py
+++ b/tests/test_loadding.py
@@ -12,7 +12,7 @@ def test_loadding():
     # test prompt
     vecs = angle.encode({'text': 'hello world'}, prompt=Prompts.C)
     assert isinstance(vecs, np.ndarray)
-    vecs = angle.encode([{'text': 'hello world', 'text': 'hi there👋'}])
+    vecs = angle.encode([{'text': 'hello world', 'text': 'hi there👋'}], prompt=Prompts.C)
     assert isinstance(vecs, np.ndarray)
 
 

From fba2d274ff259660b0c59b6ad792d220a2cfd19a Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Tue, 21 May 2024 16:32:59 +0800
Subject: [PATCH 10/11] test python3.12

---
 .github/workflows/pytest.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 980dad4..cdd1a78 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -7,7 +7,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
      matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
     steps:
     - uses: actions/checkout@v3
     - name: Set up Python ${{ matrix.python-version }}

From 
7ad543d3e4d105349bf6f615c4ea2b0ac35ca09d Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 21 May 2024 16:52:22 +0800 Subject: [PATCH 11/11] upgrade docs --- docs/conf.py | 19 +++++++++++++++++-- docs/index.rst | 10 ++++++++++ docs/notes/installation.rst | 2 +- docs/notes/pretrained_models.rst | 6 +++--- .../notes/{quick_start.rst => quickstart.rst} | 0 docs/requirements.txt | 4 ++++ 6 files changed, 35 insertions(+), 6 deletions(-) rename docs/notes/{quick_start.rst => quickstart.rst} (100%) create mode 100644 docs/requirements.txt diff --git a/docs/conf.py b/docs/conf.py index 3c3957e..582b45d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,7 +13,20 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = [] +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx.ext.githubpages', + 'sphinx_copybutton', + 'autoapi.extension', +] + +autosummary_generate = True templates_path = ['_templates'] exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] @@ -23,5 +36,7 @@ # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'alabaster' +html_theme = 'sphinx_rtd_theme' html_static_path = ['_static'] + +autoapi_dirs = ['../angle_emb'] diff --git a/docs/index.rst b/docs/index.rst index 48f16a0..80b5dfd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -73,3 +73,13 @@ AnglE is also a general sentence embedding inference framework, allowing for inf :maxdepth: 2 :caption: Contents: + notes/quickstart.rst + notes/installation.rst + notes/pretrained_models.rst + notes/training.rst + notes/citation.rst + + +.. toctree:: + :maxdepth: 1 + :caption: APIs: diff --git a/docs/notes/installation.rst b/docs/notes/installation.rst index 3f48a5f..8923983 100644 --- a/docs/notes/installation.rst +++ b/docs/notes/installation.rst @@ -1,4 +1,4 @@ -Installation +โฌ‡๏ธ Installation ================================ diff --git a/docs/notes/pretrained_models.rst b/docs/notes/pretrained_models.rst index 0e0dc36..a4dd39c 100644 --- a/docs/notes/pretrained_models.rst +++ b/docs/notes/pretrained_models.rst @@ -1,10 +1,10 @@ ๐Ÿ›๏ธ Official Pretrained Models -=========================== +================================ BERT-based models: ------------------- +------------------------------ +------------------------------------+-------------+-------------------+--------------------------+ | ๐Ÿค— HF | Max Tokens | Pooling Strategy | Scenario | @@ -19,7 +19,7 @@ BERT-based models: LLM-based models: ------------------ +------------------------------ +------------------------------------+-----------------------------+------------------+--------------------------+------------------+---------------------------------+ | ๐Ÿค— HF (lora weight) | Backbone | Max Tokens | Prompts | Pooling Strategy | Scenario | diff --git a/docs/notes/quick_start.rst b/docs/notes/quickstart.rst similarity index 100% rename from docs/notes/quick_start.rst rename to docs/notes/quickstart.rst diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..dbcfaae --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +sphinx +sphinx-rtd-theme +sphinx-copybutton +sphinx-autoapi \ No newline at end of file
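
A minimal inference sketch for the ESE recipe in README_ESE.md, assuming the `angle_emb` `AnglE` API used in the tests and an illustrative checkpoint path matching `--save_dir ckpts/UAE-Large-Espresso`; the 128-dimension cut mirrors `--ese_compression_size 128`:

```python
# Sketch only: the checkpoint path and the truncation size are assumptions
# mirroring the training command in README_ESE.md.
import numpy as np
from angle_emb import AnglE

angle = AnglE.from_pretrained('ckpts/UAE-Large-Espresso', pooling_strategy='cls')
vecs = angle.encode(['hello world', 'hi there👋'])  # np.ndarray of full-size embeddings

# learn-to-compress packs salient features into the leading dimensions,
# so a compact embedding can be taken from the first 128 dimensions.
small = vecs[:, :128]
small = small / np.linalg.norm(small, axis=1, keepdims=True)  # re-normalize for cosine similarity
```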
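
A usage sketch of the module stack that `scripts/convert_to_sentence_transformer.py` assembles; the model name, pooling mode, and max length are illustrative, and `--half 1` would additionally call `model.half()` before pushing to the Hub:

```python
# Sketch only: values below are illustrative, not taken from a specific run.
from sentence_transformers import SentenceTransformer, models

word_embedding_model = models.Transformer('WhereIsAI/UAE-Large-V1', max_seq_length=512)
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode='cls',
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# The converted model behaves like a regular sentence-transformers model.
embeddings = model.encode(['hello world', 'hi there👋'])
```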
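
A usage sketch of the prompt-at-encode-time pattern exercised by the updated `tests/test_loadding.py`, with an illustrative checkpoint name:

```python
# Sketch only: the checkpoint name is illustrative; the call signature
# follows the updated tests, where the prompt is passed per encode call.
from angle_emb import AnglE, Prompts

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls')

doc_vecs = angle.encode(['hello world', 'hi there👋'])               # no prompt
query_vec = angle.encode({'text': 'hello world'}, prompt=Prompts.C)  # prompt passed per call
```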