From 6dc510feab41ac40fa39536f8ffc4da767d03ae0 Mon Sep 17 00:00:00 2001
From: xiyang-AADS
Date: Thu, 30 May 2024 08:49:00 -0400
Subject: [PATCH] remove dup code; add unsloth support

---
 experiments/demo_dgx2.sh                      |   3 +-
 experiments/demo_dgx2_launch.sh               |   2 +-
 .../accelerate_configs/deepspeed_zero2.yaml   |  16 ++-
 recipes/accelerate_configs/deepspeed_zs2.json |   9 +-
 recipes/accelerate_configs/readme.md          |  38 ++++++
 recipes/llama3-8b/sft/config_full.yaml        |  12 +-
 requirements.txt                              |   3 +
 scripts/run_sft.py                            |  85 ++++++-------
 src/alignment/configs.py                      |   4 +
 src/alignment/decontaminate.py                |  98 --------------
 src/alignment/release.py                      | 120 ------------------
 src/alignment/unsloth.py                      | 107 ++++++++++++++++
 12 files changed, 219 insertions(+), 278 deletions(-)
 create mode 100644 recipes/accelerate_configs/readme.md
 delete mode 100644 src/alignment/decontaminate.py
 delete mode 100644 src/alignment/release.py
 create mode 100644 src/alignment/unsloth.py

diff --git a/experiments/demo_dgx2.sh b/experiments/demo_dgx2.sh
index 87c9bee9..a0a8aa1e 100644
--- a/experiments/demo_dgx2.sh
+++ b/experiments/demo_dgx2.sh
@@ -16,7 +16,8 @@ export HF_DATASETS_CACHE="${ROOT}/project/.cache/dataset"
 export HF_HOME="${ROOT}/project/.cache/"
 
 # Wandb
-export WANDB_API_KEY=""
+export WANDB_API_KEY="<your_wandb_api_key>"
+# export WANDB_API_KEY=""
 export WANDB_USERNAME="xi-yang5"
 export WANDB_PROJECT="demo_dgx2"
 export WANDB_LOG_MODEL="false"
diff --git a/experiments/demo_dgx2_launch.sh b/experiments/demo_dgx2_launch.sh
index 5f5250b4..637d4bdd 100644
--- a/experiments/demo_dgx2_launch.sh
+++ b/experiments/demo_dgx2_launch.sh
@@ -6,7 +6,7 @@ ROOT=$(realpath ~)
 CONTAINER=${ROOT}/project/singularity_containers/py2402.sig
 
 # CUDA
-export CUDA_VISIBLE_DEVICES=0,1,2
+export CUDA_VISIBLE_DEVICES=0,1
 
 # PATH
 DEMO_PATH=${ROOT}/project/alignment_handbook/experiments
diff --git a/recipes/accelerate_configs/deepspeed_zero2.yaml b/recipes/accelerate_configs/deepspeed_zero2.yaml
index fc0c80a1..d6c76abf 100644
--- a/recipes/accelerate_configs/deepspeed_zero2.yaml
+++ b/recipes/accelerate_configs/deepspeed_zero2.yaml
@@ -1,13 +1,15 @@
 compute_environment: LOCAL_MACHINE
 debug: true
 deepspeed_config:
-  deepspeed_multinode_launcher: standard
-  offload_optimizer_device: none
-  offload_param_device: none
-  zero3_init_flag: false
-  zero3_save_16bit_model: false
-  zero_stage: 2
-  mixed_precision: bf16
+  deepspeed_config_file: /home/l069561/project/alignment-handbook/recipes/accelerate_configs/deepspeed_zs2.json
+  zero3_init_flag: true
+  # deepspeed_multinode_launcher: standard
+  # offload_optimizer_device: none
+  # offload_param_device: none
+  # zero3_init_flag: true
+  # zero3_save_16bit_model: false
+  # zero_stage: 2
+  # mixed_precision: bf16
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
 machine_rank: 0
diff --git a/recipes/accelerate_configs/deepspeed_zs2.json b/recipes/accelerate_configs/deepspeed_zs2.json
index 9597f848..dfa80708 100644
--- a/recipes/accelerate_configs/deepspeed_zs2.json
+++ b/recipes/accelerate_configs/deepspeed_zs2.json
@@ -1,17 +1,24 @@
 {
     "fp16": {
-        "enabled": true,
+        "enabled": false,
         "loss_scale": 0,
+        "auto_cast": false,
         "loss_scale_window": 1000,
         "initial_scale_power": 16,
         "hysteresis": 2,
+        "consecutive_hysteresis": false,
         "min_loss_scale": 1
     },
+    "bf16": {
+        "enabled": true
+    },
     "optimizer": {
         "type": "AdamW",
         "params": {
             "lr": "auto",
             "weight_decay": "auto",
+            "betas": "auto",
+            "eps": "auto",
             "torch_adam": true,
             "adam_w_mode": true
         }
diff --git a/recipes/accelerate_configs/readme.md 
b/recipes/accelerate_configs/readme.md new file mode 100644 index 00000000..3134d397 --- /dev/null +++ b/recipes/accelerate_configs/readme.md @@ -0,0 +1,38 @@ +## deepspeed optimizers +- DeepSpeed natively supports Adam, AdamW, OneBitAdam, Lamb, OneBitLamb, FusedLamb, FusedAdam +- see for details on how to config https://deepspeed.readthedocs.io/en/latest/optimizers.html +- +```json +{ + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 1e-3, + "weight_decay": 0.01, + "bias_correction": false, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 1000, + "cuda_aware": false, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 4.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, +} + +{ + "optimizer": { + "type": "Lamb", + "params": { + "lr": 1e-3, + "weight_decay": 0.01, + "bias_correction": false, + "max_coeff": 0.3, + "min_coeff": 0.01 + } + }, +} +``` \ No newline at end of file diff --git a/recipes/llama3-8b/sft/config_full.yaml b/recipes/llama3-8b/sft/config_full.yaml index 806b759b..1cafb813 100644 --- a/recipes/llama3-8b/sft/config_full.yaml +++ b/recipes/llama3-8b/sft/config_full.yaml @@ -17,21 +17,22 @@ preprocessing_num_workers: 8 bf16: true do_eval: true evaluation_strategy: epoch +max_grad_norm: 1.0 gradient_accumulation_steps: 1 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: False -hub_model_id: null -hub_strategy: every_save -learning_rate: 2.0e-05 log_level: info logging_steps: 5 logging_strategy: steps +learning_rate: 2.0e-05 +optim: galore_adamw # adamw_torch paged_adamw_32bit galore_adamw lion_32bit +weight_decay: 0.01 lr_scheduler_type: cosine -max_seq_length: 2048 +max_seq_length: 4096 max_steps: -1 num_train_epochs: 1 -output_dir: /home/l069561/project/models/fine-tuned/demo-llama-3-full-ultrachat +output_dir: /home/l069561/project/alignment_handbook/experiments/models/llama-3-full-ultrachat overwrite_output_dir: true per_device_eval_batch_size: 8 per_device_train_batch_size: 16 @@ -39,6 +40,7 @@ push_to_hub: false remove_unused_columns: true report_to: - tensorboard +- wandb save_strategy: "steps" save_steps: 100 save_total_limit: 1 diff --git a/requirements.txt b/requirements.txt index 70ed9b5e..f66031f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,6 @@ jinja2>=3.0.0 tqdm>=4.64.1 flash-attn>=2.1.0 pynvml>=11.4.0 + +# optional +galore-torch \ No newline at end of file diff --git a/scripts/run_sft.py b/scripts/run_sft.py index 1cd458e7..d7bf4b4d 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -45,7 +45,6 @@ get_peft_config, get_quantization_config, get_tokenizer, - tokenizer_and_embedding_resize, ) from trl import SFTTrainer, setup_chat_format @@ -110,31 +109,6 @@ def main(): ) column_names = list(raw_datasets["train"].features) - ####################### - # Load pretrained model - ####################### - logger.info("*** Load pretrained model ***") - torch_dtype = ( - model_args.torch_dtype - if model_args.torch_dtype in ["auto", None] - else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - - model_kwargs = dict( - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, # attn_implementation="flash_attention_2" - torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, 
- ) - logger.info("*** Model loaded! ***") - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, **model_kwargs - ) - ################ # Load tokenizer ################ @@ -162,6 +136,7 @@ def main(): ) model = model_args.model_name_or_path + # For ChatML we need to add special tokens and resize the embedding layer if ( "<|im_start|>" in tokenizer.chat_template @@ -173,11 +148,6 @@ def main(): model, tokenizer = setup_chat_format(model, tokenizer) model_kwargs = None - ############### - # update new tokens added to tokenizer - ############### - tokenizer_and_embedding_resize(data_args, tokenizer, model) - ##################### # Apply chat template ##################### @@ -222,30 +192,55 @@ def main(): ######################## # Initialize the Trainer ######################## - trainer = SFTTrainer( - model=model, - model_init_kwargs=model_kwargs, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - dataset_text_field="text", - max_seq_length=training_args.max_seq_length, - tokenizer=tokenizer, - packing=True, - peft_config=get_peft_config(model_args), - dataset_kwargs=training_args.dataset_kwargs, - callbacks=[GpuUtilPrintCallBack()], - ) + if model_args.use_unsloth: + from alignment.unsloth import get_unsloth_peft_model + + peft_config = get_peft_config(model_args) + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, **model_kwargs + ) + model, tokenizer = setup_chat_format(model, tokenizer) + model = get_unsloth_peft_model(model, training_args.max_seq_length, peft_config) + + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + dataset_text_field="text", + max_seq_length=training_args.max_seq_length, + tokenizer=tokenizer, + packing=True, + dataset_kwargs=training_args.dataset_kwargs, + callbacks=[GpuUtilPrintCallBack()], + ) + else: + trainer = SFTTrainer( + model=model, + model_init_kwargs=model_kwargs, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + dataset_text_field="text", + max_seq_length=training_args.max_seq_length, + tokenizer=tokenizer, + packing=True, + peft_config=get_peft_config(model_args), + dataset_kwargs=training_args.dataset_kwargs, + callbacks=[GpuUtilPrintCallBack()], + ) ############### # Training loop ############### logger.info("*** Train ***") + checkpoint = None if training_args.resume_from_checkpoint is not None: checkpoint = training_args.resume_from_checkpoint elif last_checkpoint is not None: checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) metrics = train_result.metrics metrics["train_samples"] = len(train_dataset) diff --git a/src/alignment/configs.py b/src/alignment/configs.py index 828b11da..466dc0e1 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -172,6 +172,10 @@ class ModelArguments: ) }, ) + use_unsloth: bool = field( + default=False, + metadata={"help": ("Whether to use unsloth to accelerate lora.")}, + ) use_peft: bool = field( default=False, metadata={"help": ("Whether to use PEFT or not for training.")}, diff --git a/src/alignment/decontaminate.py b/src/alignment/decontaminate.py deleted file mode 100644 index dfa2c38c..00000000 --- a/src/alignment/decontaminate.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, List - -from datasets import load_dataset - - -# HumanEval solutions that are considered simple/generic enough to be kept in the training dataset -HUMAN_EVAL_STRINGS_OK = [ - "return x + y", - "return len(string)", - "return n**2", - "return " ".join(strings)", -] - - -def extract_docstring(prompt: str) -> str: - if '"""' in prompt: - if prompt.count('"""') == 2: - return prompt.split('"""')[1].strip() - elif prompt.count('"""') == 4: - return prompt.split('"""')[3].strip() - else: - raise ValueError() - elif "'''" in prompt: - assert prompt.count("'''") == 2 - return prompt.split("'''")[1].strip() - else: - raise ValueError() - - -def human_eval_docstrings() -> List[str]: - ds = load_dataset("openai_humaneval", split="test") - docstrings = [extract_docstring(v["prompt"]) for v in ds] - return docstrings - - -def load_dataset_column(dataset: str, column: str, split: str, name=None) -> List[str]: - ds = load_dataset(dataset, split=split, name=name) - res = [sample[column].strip() for sample in ds] - # Only return non-empty strings - return [sample for sample in res if len(sample) > 0] - - -FILTER_OUT = { - "human_eval_docstrings": human_eval_docstrings(), - "human_eval_solutions": [ - s - for s in load_dataset_column("openai_humaneval", "canonical_solution", "test") - if s not in HUMAN_EVAL_STRINGS_OK - ], -} - - -def normalize_whitespace(text: str) -> str: - return " ".join(text.split()) - - -def decontaminate_humaneval( - samples: List[Dict[str, Any]], - text_column: str = "text", - filter_out: Dict[str, List[str]] = FILTER_OUT, -) -> List[Dict[str, Any]]: - """ - filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be - filtered-out. - Return a list where each element is True if the corresponding file should be included in the dataset. - Otherwise, the element is False. - """ - output = [] - - for content in samples[text_column]: - content = normalize_whitespace(content.lower()) - matched = False - for _, substrings in filter_out.items(): - for substring in substrings: - if normalize_whitespace(substring.lower()) in content: - matched = True - break - if matched: - break - # we keep files that are not matched - output.append(not matched) - - return output diff --git a/src/alignment/release.py b/src/alignment/release.py deleted file mode 100644 index 6f42454f..00000000 --- a/src/alignment/release.py +++ /dev/null @@ -1,120 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import re - -import packaging.version - - -REPLACE_PATTERNS = { - "init": ( - re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), - '__version__ = "VERSION"\n', - ), - "setup": ( - re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), - r'\1version="VERSION",', - ), -} -REPLACE_FILES = { - "init": "src/alignment/__init__.py", - "setup": "setup.py", -} -README_FILE = "README.md" - - -def update_version_in_file(fname, version, pattern): - """Update the version in one file using a specific pattern.""" - with open(fname, "r", encoding="utf-8", newline="\n") as f: - code = f.read() - re_pattern, replace = REPLACE_PATTERNS[pattern] - replace = replace.replace("VERSION", version) - code = re_pattern.sub(replace, code) - with open(fname, "w", encoding="utf-8", newline="\n") as f: - f.write(code) - - -def global_version_update(version, patch=False): - """Update the version in all needed files.""" - for pattern, fname in REPLACE_FILES.items(): - update_version_in_file(fname, version, pattern) - - -def get_version(): - """Reads the current version in the __init__.""" - with open(REPLACE_FILES["init"], "r") as f: - code = f.read() - default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] - return packaging.version.parse(default_version) - - -def pre_release_work(patch=False): - """Do all the necessary pre-release steps.""" - # First let's get the default version: base version if we are in dev, bump minor otherwise. - default_version = get_version() - if patch and default_version.is_devrelease: - raise ValueError( - "Can't create a patch version from the dev branch, checkout a released version!" - ) - if default_version.is_devrelease: - default_version = default_version.base_version - elif patch: - default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" - else: - default_version = f"{default_version.major}.{default_version.minor + 1}.0" - - # Now let's ask nicely if that's the right one. - version = input(f"Which version are you releasing? [{default_version}]") - if len(version) == 0: - version = default_version - - print(f"Updating version to {version}.") - global_version_update(version, patch=patch) - - -def post_release_work(): - """Do all the necessary post-release steps.""" - # First let's get the current version - current_version = get_version() - dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" - current_version = current_version.base_version - - # Check with the user we got that right. - version = input(f"Which version are we developing now? [{dev_version}]") - if len(version) == 0: - version = dev_version - - print(f"Updating version to {version}.") - global_version_update(version) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--post_release", - action="store_true", - help="Whether this is pre or post release.", - ) - parser.add_argument( - "--patch", action="store_true", help="Whether or not this is a patch release." 
- ) - args = parser.parse_args() - if not args.post_release: - pre_release_work(patch=args.patch) - elif args.patch: - print("Nothing to do after a patch :-)") - else: - post_release_work() diff --git a/src/alignment/unsloth.py b/src/alignment/unsloth.py new file mode 100644 index 00000000..5d1f7369 --- /dev/null +++ b/src/alignment/unsloth.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2024 The Eli Lilly AADS Team. All rights reserved. +# Author: Xi Yang (xi.yang5@lilly.com) + +# this is copied from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/model/utils/unsloth.py + +import logging +import os +from typing import TYPE_CHECKING, Any, Dict, Optional + +import torch + + +logger = logging.getLogger(__name__) + + +def get_current_device() -> torch.device: + r""" + Gets the current available device. + """ + if torch.cuda.is_available(): + # might cause problem if we have multi-GPU and export CUDA VISIBLE DEVICE as not 0 + device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) + else: + device = "cpu" + + return torch.device(device) + + +def _get_unsloth_kwargs(config, model_name_or_path: str, model_args) -> Dict[str, Any]: + return { + "model_name": model_name_or_path, + "max_seq_length": model_args.model_max_length or 4096, + "dtype": model_args.compute_dtype, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": {"": get_current_device()}, + "rope_scaling": getattr(config, "rope_scaling", None), + "fix_tokenizer": False, + "trust_remote_code": True, + "use_gradient_checkpointing": "unsloth", + } + + +def load_unsloth_pretrained_model(config, model_args): + r""" + Optionally loads pretrained model with unsloth. Used in training. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs( + config, model_args.model_name_or_path, model_args + ) + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + logger.warning( + "Unsloth does not support model type {}.".format( + getattr(config, "model_type", None) + ) + ) + model = None + model_args.use_unsloth = False + + return model + + +def get_unsloth_peft_model(model, max_seq_length, peft_kwargs: Dict[str, Any]): + r""" + Gets the peft model for the pretrained model with unsloth. Used in training. + """ + from unsloth import FastLanguageModel + + unsloth_peft_kwargs = { + "model": model, + "max_seq_length": max_seq_length, + "use_gradient_checkpointing": "unsloth", + } + return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + + +def load_unsloth_peft_model(config, model_args, is_trainable: bool): + r""" + Loads peft model with unsloth. Used in both training and inference. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs( + config, model_args.adapter_name_or_path[0], model_args + ) + try: + if not is_trainable: + unsloth_kwargs["use_gradient_checkpointing"] = False + + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + raise ValueError( + "Unsloth does not support model type {}.".format( + getattr(config, "model_type", None) + ) + ) + + if not is_trainable: + FastLanguageModel.for_inference(model) + + return model
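
With this patch applied, unsloth acceleration is switched on through the new `use_unsloth` flag added to `ModelArguments` in `src/alignment/configs.py`, combined with the existing PEFT options that `run_sft.py` turns into a LoRA config. Below is a minimal recipe sketch; the file name and all concrete values are illustrative assumptions, not part of this patch.

```yaml
# hypothetical recipe, e.g. recipes/llama3-8b/sft/config_unsloth.yaml (illustration only)
model_name_or_path: meta-llama/Meta-Llama-3-8B
torch_dtype: bfloat16

# PEFT + unsloth (use_unsloth is the flag introduced by this patch)
use_peft: true
use_unsloth: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj

# SFT settings (illustrative values)
max_seq_length: 4096
per_device_train_batch_size: 8
gradient_checkpointing: true
bf16: true
```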