-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2f5278d
commit 41782a2
Showing
9 changed files
with
136 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
## Setup | ||
|
||
* Initialize the submodules | ||
```bash | ||
git submodule init | ||
git submodule update | ||
``` | ||
|
||
* Install the requirements | ||
```bash | ||
conda create -n venv | ||
conda activate venv | ||
pip install -r requirements.txt | ||
``` | ||
|
||
* Fine-tune with QLoRA (quantized LoRA) | ||
```bash | ||
python llama2_qlora.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
accelerate @ git+https://github.com/huggingface/accelerate.git@cdb001ca5f9be875034ddd0aa86a542c182782fe | ||
bitsandbytes==0.41.1 | ||
black==23.9.1 | ||
datasets==2.14.5 | ||
einops==0.6.1 | ||
evaluate==0.4.0 | ||
fairscale==0.4.13 | ||
fire==0.5.0 | ||
jsonlines==4.0.0 | ||
llama @ git+https://github.com/facebookresearch/llama@9f0e393991b45d320f5b4a287eaaeb8a7d2e6f8e | ||
openai==0.28.0 | ||
pandas==2.1.0 | ||
peft @ git+https://github.com/huggingface/peft.git@0fa63fb4a21bf88777b2469892b76a6e096753e8 | ||
torch==2.0.1 | ||
transformers @ git+https://github.com/huggingface/transformers.git@95b374952dc27d8511541d6f5a4e22c9ec11fb24 | ||
wandb==0.15.10 | ||
scipy==1.11.2 | ||
tiktoken==0.5.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from transformers import Trainer, TrainingArguments | ||
from transformers.utils import logging | ||
from transformers.trainer_utils import EvalLoopOutput | ||
|
||
from evaluator.agieval import wrapper as wrapper_agieval | ||
from eval_args import EvaluationArguments | ||
# from evaluator.agentbench import wrapper as wrapper_agentbench | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
class MandrillTrainer(Trainer):
    """Trainer whose evaluation step runs external benchmark wrappers.

    Instead of iterating over the evaluation dataloader, ``evaluation_loop``
    dispatches to the benchmark wrappers named in ``eval_args.tasks_list``.

    ``compute_loss`` is overridden to read the loss directly from the model
    output, to avoid the label being set to None in the stock implementation:
    https://github.com/huggingface/transformers/blob/5a4f340df74b42b594aedf60199eea95cdb9bed0/src/transformers/trainer.py#L2703C26-L2703C26
    """

    def __init__(self, *args, **kwargs):
        """Store the evaluation settings, then initialise the base ``Trainer``.

        Required keyword arguments (a missing one raises ``KeyError``):
            model_id: forwarded to the benchmark wrappers as ``model_id``.
            eval_args: object providing ``tasks_list``, ``system_prompt``,
                ``temperature``, ``max_new_tokens`` and ``top_p``.
            hf_api_token: forwarded to the benchmark wrappers.

        All remaining positional and keyword arguments go to ``Trainer``.
        """
        # These keys are removed first so they are not forwarded to
        # Trainer.__init__ along with the remaining kwargs.
        self.model_id = kwargs.pop('model_id')
        self.eval_args = kwargs.pop('eval_args')
        self.hf_api_token = kwargs.pop('hf_api_token')
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run a forward pass and return the loss reported by the model.

        Args:
            model: callable invoked as ``model(**inputs)``; its result must
                expose a ``loss`` attribute.
            inputs: mapping of keyword arguments for the forward pass.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                just the loss, matching the base ``Trainer.compute_loss``
                signature.
        """
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=False, **kwargs) -> EvalLoopOutput:
        """Run the configured benchmarks and return a placeholder result.

        The wrappers' return values are not collected here; the returned
        ``EvalLoopOutput`` always carries the fixed ``fake_metric`` entry.
        https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/trainer_utils.py#L147
        """
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        model.eval()

        tasks = self.eval_args.tasks_list
        # Both wrappers are called with the same keyword arguments.
        eval_kwargs = dict(
            model=model,
            model_id=self.model_id,
            hf_api_token=self.hf_api_token,
            system_prompt=self.eval_args.system_prompt,
            temperature=self.eval_args.temperature,
            max_new_tokens=self.eval_args.max_new_tokens,
            top_p=self.eval_args.top_p,
            batch_size=self.args.per_device_eval_batch_size,
        )

        if 'agieval' in tasks:
            logger.info("***** Runnning Evaluation on AGIEval *****")
            wrapper_agieval.evaluate(**eval_kwargs)
        if 'agentbench' in tasks:
            logger.info("***** Runnning Evaluation on AgentBench *****")
            # The module-level import of this wrapper is commented out, so it
            # is imported here, only when the task is actually requested.
            from evaluator.agentbench import wrapper as wrapper_agentbench
            wrapper_agentbench.evaluate(**eval_kwargs)

        return EvalLoopOutput(predictions=None, label_ids=None, metrics={'fake_metric': 0.0}, num_samples=0)
This file was deleted.
Oops, something went wrong.