Skip to content

Commit

Permalink
add format hook
Browse files Browse the repository at this point in the history
  • Loading branch information
xiyang-aads-lilly committed May 15, 2024
1 parent a5f0909 commit 23e5d1b
Show file tree
Hide file tree
Showing 17 changed files with 385 additions and 115 deletions.
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.3.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
# we do not set python version so it will use default

- id: black-jupyter
# # It is recommended to specify the latest version of Python
# # supported by your project here, or alternatively use
# # pre-commit's default_language_version, see
# # https://pre-commit.com/#top_level-default_language_version
# language_version: python3.11
19 changes: 14 additions & 5 deletions scripts/run_cpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,27 +100,36 @@ def main():

if train_dataset is None:
raise ValueError(
"Training set must be included (so make sure that your dataset has a split with" " 'train' in the name)."
"Training set must be included (so make sure that your dataset has a split with"
" 'train' in the name)."
)

if training_args.do_eval and eval_dataset is None:
raise ValueError("'--do_eval' enabled so make sure that your dataset has a split with 'test' in the name.")
raise ValueError(
"'--do_eval' enabled so make sure that your dataset has a split with 'test' in the name."
)

################
# Load tokenizer
################
tokenizer = get_tokenizer(model_args, data_args, auto_set_chat_template=False)

with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
with training_args.main_process_first(
desc="Log a few random samples from the processed training set"
):
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
logger.info(
f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}"
)

#######################
# Load pretrained model
#######################
logger.info("*** Load pretrained model ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)

Expand Down
65 changes: 50 additions & 15 deletions scripts/run_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sys
from pathlib import Path


p = Path(__file__).parent.parent / "src"
sys.path.append(p.as_posix())

Expand All @@ -30,6 +31,7 @@
DPOConfig,
H4ArgumentParser,
ModelArguments,
ProfCallback,
apply_chat_template,
decontaminate_humaneval,
get_checkpoint,
Expand All @@ -39,7 +41,6 @@
get_quantization_config,
get_tokenizer,
is_adapter_model,
ProfCallback,
)
from peft import PeftConfig, PeftModel
from trl import DPOTrainer
Expand Down Expand Up @@ -86,7 +87,14 @@ def main():
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
columns_to_keep=[
"messages",
"chosen",
"rejected",
"prompt",
"completion",
"label",
],
)
logger.info(
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
Expand All @@ -96,7 +104,9 @@ def main():
#####################################
# Load tokenizer and process datasets
#####################################
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
data_args.truncation_side = (
"left" # Truncate from left to ensure we don't lose labels in final turn
)
tokenizer = get_tokenizer(model_args, data_args)

#####################
Expand Down Expand Up @@ -134,17 +144,29 @@ def main():
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
for split in ["train", "test"]:
raw_datasets[split] = raw_datasets[split].rename_columns(
{"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
{
"text_prompt": "prompt",
"text_chosen": "chosen",
"text_rejected": "rejected",
}
)

# Log a few random samples from the training set:
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
logger.info(
f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}"
)
logger.info(
f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}"
)
logger.info(
f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}"
)

torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)

Expand All @@ -164,14 +186,18 @@ def main():
# Note: to run QLoRA, you will need to merge the base model separately as the merged model in 16bit
logger.info(f"Merging PEFT adapters for {model_args.model_name_or_path}")

peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
peft_config = PeftConfig.from_pretrained(
model_args.model_name_or_path, revision=model_args.model_revision
)
model_kwargs = dict(
revision=model_args.base_model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
device_map=(
get_kbit_device_map() if quantization_config is not None else None
),
quantization_config=quantization_config,
)
base_model = AutoModelForCausalLM.from_pretrained(
Expand All @@ -196,9 +222,16 @@ def main():
# PYTORCH profiler
##################
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(skip_first=3, wait=1, warmup=1, active=2, repeat=2),
on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name=training_args.logging_dir),
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
skip_first=3, wait=1, warmup=1, active=2, repeat=2
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
dir_name=training_args.logging_dir
),
profile_memory=True,
with_stack=True,
record_shapes=True,
Expand All @@ -223,7 +256,7 @@ def main():
max_prompt_length=training_args.max_prompt_length,
peft_config=get_peft_config(model_args),
loss_type=training_args.loss_type,
callbacks=[], #[ProfCallback(prof)],
callbacks=[], # [ProfCallback(prof)],
)

###############
Expand Down Expand Up @@ -277,7 +310,9 @@ def main():
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)

torch.cuda.memory._dump_snapshot(Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle")
torch.cuda.memory._dump_snapshot(
Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle"
)
logger.info("*** Training complete! ***")


Expand Down
20 changes: 15 additions & 5 deletions scripts/run_orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,15 @@ def main():
#####################################
# Load tokenizer and process datasets
#####################################
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
data_args.truncation_side = (
"left" # Truncate from left to ensure we don't lose labels in final turn
)
tokenizer = get_tokenizer(model_args, data_args)

torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)

Expand Down Expand Up @@ -195,9 +199,15 @@ def filter_fn(sample: Dict[str, Any]) -> Dict[str, Any]:

# Log a few random samples from the training set:
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
logger.info(
f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}"
)
logger.info(
f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}"
)
logger.info(
f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}"
)

##########################
# Instantiate ORPO trainer
Expand Down
21 changes: 17 additions & 4 deletions scripts/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,14 @@ def main():
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
columns_to_keep=[
"messages",
"chosen",
"rejected",
"prompt",
"completion",
"label",
],
)
logger.info(
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
Expand Down Expand Up @@ -124,7 +131,9 @@ def main():
quantization_config=quantization_config,
)
logger.info("*** Model loaded! ***")
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, **model_kwargs
)

################
# Load tokenizer
Expand Down Expand Up @@ -158,7 +167,9 @@ def main():
"<|im_start|>" in tokenizer.chat_template
and "gemma-tokenizer-chatml" not in tokenizer.name_or_path
):
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, **model_kwargs
)
model, tokenizer = setup_chat_format(model, tokenizer)
model_kwargs = None

Expand Down Expand Up @@ -277,7 +288,9 @@ def main():
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)

torch.cuda.memory._dump_snapshot(Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle")
torch.cuda.memory._dump_snapshot(
Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle"
)
# prof.close()
logger.info("*** Training complete ***")

Expand Down
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
deps = {
b: a
for a, b in (
re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0]
for x in _deps
)
}


def deps_list(*pkgs):
Expand Down
15 changes: 13 additions & 2 deletions src/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
__version__ = "0.3.0"

from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
from .configs import (
DataArguments,
DPOConfig,
H4ArgumentParser,
ModelArguments,
SFTConfig,
)
from .data import apply_chat_template, get_datasets
from .model_utils import (
get_checkpoint,
Expand All @@ -11,4 +17,9 @@
is_adapter_model,
tokenizer_and_embedding_resize,
)
from .utils import GpuUtilPrintCallBack, ProfCallback, print_gpu_utilization, print_summary
from .utils import (
GpuUtilPrintCallBack,
ProfCallback,
print_gpu_utilization,
print_summary,
)
Loading

0 comments on commit 23e5d1b

Please sign in to comment.