diff --git a/.gitignore b/.gitignore
index 8c29b4eb..5ae78d58 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,4 @@ experiments/*
 !experiments/demo*
 !experiments/README.md
 !experiments/util.sh
+*/sang_project/*
diff --git a/experiments/demo_magtrain_llm_sft.sh b/experiments/demo_magtrain_llm_sft.sh
index 85208ab6..0b608326 100644
--- a/experiments/demo_magtrain_llm_sft.sh
+++ b/experiments/demo_magtrain_llm_sft.sh
@@ -38,8 +38,8 @@ echo $PRIMARY_PORT
 # manually set
 export WANDB_PROJECT="sang"
 
-# TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_1.yaml
-TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_2.yaml
+TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_1.yaml
+# TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_2.yaml
 
 DEEPSPEED_CONF=${ROOT}/recipes/accelerate_configs/deepspeed_zs2.json
 
diff --git a/experiments/demo_magtrain_slurm.sh b/experiments/demo_magtrain_slurm.sh
index 3fa035f8..269e4477 100644
--- a/experiments/demo_magtrain_slurm.sh
+++ b/experiments/demo_magtrain_slurm.sh
@@ -3,7 +3,7 @@
 #SBATCH --job-name=llm_sft
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=xi.yang5@lilly.com
-#SBATCH --nodes=4
+#SBATCH --nodes=2
 #SBATCH --ntasks-per-node=1
 #SBATCH --gpus-per-node=4
 #SBATCH --gpus-per-task=4
diff --git a/recipes/accelerate_configs/deepspeed_zs2.json b/recipes/accelerate_configs/deepspeed_zs2.json
index 215f8469..25b8761c 100644
--- a/recipes/accelerate_configs/deepspeed_zs2.json
+++ b/recipes/accelerate_configs/deepspeed_zs2.json
@@ -29,7 +29,7 @@
     "scheduler": {
         "type": "WarmupDecayLR",
         "params": {
-            "warmup_min_lr": "auto",
+            "warmup_min_lr": 1e-8,
             "warmup_max_lr": "auto",
             "warmup_num_steps": "auto",
             "total_num_steps": "auto"
diff --git a/recipes/accelerate_configs/deepspeed_zs3.json b/recipes/accelerate_configs/deepspeed_zs3.json
new file mode 100644
index 00000000..fe859211
--- /dev/null
+++ b/recipes/accelerate_configs/deepspeed_zs3.json
@@ -0,0 +1,79 @@
+{
+    "fp16": {
+        "enabled": false,
+        "loss_scale": 0,
+        "auto_cast": false,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "consecutive_hysteresis": false,
+        "min_loss_scale": 1
+    },
+
+    "bf16": {
+        "enabled": true
+    },
+
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "weight_decay": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 1e-8,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu"
+        },
+        "offload_param": {
+            "device": "cpu"
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 2e8,
+        "reduce_bucket_size": "auto",
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "contiguous_gradients": true,
+        "round_robin_gradients": true
+    },
+
+    "aio": {
+        "block_size": 262144,
+        "queue_depth": 32,
+        "thread_count": 1,
+        "single_submit": false,
+        "overlap_events": true
+    },
+
+    "activation_checkpointing":{
+        "partition_activations": false,
+        "cpu_checkpointing": false,
+        "contiguous_memory_optimization": true,
+        "number_checkpoints": null,
+        "synchronize_checkpoint_boundary": false,
+        "profile": true
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 20000000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
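Most values in the new ZeRO-3 config are deliberately left as "auto": the Hugging Face Trainer/DeepSpeed integration fills them in from the training arguments at launch, so only the explicitly pinned values (warmup_min_lr, the CPU offload devices, bucket sizes, etc.) are taken from this file verbatim. A minimal sketch, not part of the patch, for listing which fields are still "auto" in the file added above:

import json

def find_auto(node, path=()):
    # Recursively collect dotted paths of every value still set to the literal "auto".
    if isinstance(node, dict):
        for key, value in node.items():
            yield from find_auto(value, path + (key,))
    elif node == "auto":
        yield ".".join(path)

with open("recipes/accelerate_configs/deepspeed_zs3.json") as f:
    cfg = json.load(f)

# e.g. optimizer.params.lr, scheduler.params.warmup_num_steps, train_batch_size, ...
print(sorted(find_auto(cfg)))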
diff --git a/recipes/sang_project/config_full_1.yaml b/recipes/sang_project/config_full_1.yaml
index 717b5e84..0f1a107c 100644
--- a/recipes/sang_project/config_full_1.yaml
+++ b/recipes/sang_project/config_full_1.yaml
@@ -1,5 +1,5 @@
 # Model arguments
-model_name_or_path: /home/l069561/project/models/gemma-2-2b
+model_name_or_path: /home/l069561/project/models/Meta-Llama-3-8B #togethercomputer/StripedHyena-Hessian-7B
 model_revision: main
 torch_dtype: bfloat16
 attn_implementation: flash_attention_2
@@ -7,7 +7,9 @@ attn_implementation: flash_attention_2
 # Data training arguments
 chat_template: "{% if messages[0]['role'] == 'system' %}{% set system_message = '### System Instruction: ' + messages[0]['content'] | trim + '' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### Context: ' + message['content'] | trim + '' }}{% elif message['role'] == 'assistant' %}{{ '### Result: ' + message['content'] | trim + eos_token + '' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Result: ' }}{% endif %}"
 dataset_mixer:
-  /home/l069561/project/data/processed_data_open_sourced_xml_to_text/merged_open_sourced_xml_to_text_dataset: 1.0
+  /home/l069561/project/data/test_8k_nlp: 1.0
+  # HuggingFaceH4/ultrachat_200k: 1.0
+  # /home/l069561/project/data/processed_data_open_sourced_xml_to_text/merged_open_sourced_xml_to_text_dataset: 1.0
   # /home/l069561/project/data/sang_data_formatted: 1.0
 dataset_splits:
 - train_sft
@@ -15,8 +17,9 @@ dataset_splits:
 preprocessing_num_workers: 4
 
 # SFT trainer config
+trust_remote_code: true
 bf16: true
-do_eval: true
+do_eval: false
 # evaluation_strategy: epoch
 eval_strategy: epoch
 max_grad_norm: 1.0
@@ -36,8 +39,8 @@ max_seq_length: 8192
 packing: false
 dataset_num_proc: 16
 max_steps: -1
-num_train_epochs: 2
-output_dir: /home/l069561/project/alignment-handbook/experiments/models/sang_exp1_stage1_gemma-2-2b_full
+num_train_epochs: 100
+output_dir: /home/l069561/project/alignment-handbook/experiments/models/test_deepspeed
 overwrite_output_dir: true
 per_device_eval_batch_size: 1
 per_device_train_batch_size: 1 # this is per device, you need to manual calculate global batch by per device * gas * gpu * node
@@ -45,10 +48,10 @@ gradient_accumulation_steps: 4
 push_to_hub: false
 remove_unused_columns: true
 report_to:
-- tensorboard
 - wandb
-save_strategy: "steps"
-save_steps: 2000
-save_total_limit: 10
+# - tensorboard
+save_strategy: "no"
+save_steps: 2500
+save_total_limit: 1
 seed: 42
 warmup_ratio: 0.1
diff --git a/recipes/sang_project/config_full_2.yaml b/recipes/sang_project/config_full_2.yaml
index 67b604bd..094ff365 100644
--- a/recipes/sang_project/config_full_2.yaml
+++ b/recipes/sang_project/config_full_2.yaml
@@ -49,7 +49,7 @@ report_to:
 - tensorboard
 - wandb
 save_strategy: "steps"
-save_steps: 1500
+save_steps: 2500
 save_total_limit: 10
 seed: 42
 warmup_ratio: 0.1
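The per_device_train_batch_size comment in the recipe is the operative constraint here: the effective global batch is per-device batch x gradient accumulation x GPUs per node x nodes, so halving the SLURM allocation from 4 to 2 nodes also halves the global batch unless another knob is raised. A small sanity-check sketch, not part of the patch, using the values from config_full_1.yaml and the 2-node / 4-GPU job:

# Values mirror config_full_1.yaml and demo_magtrain_slurm.sh; adjust for other runs.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
gpus_per_node = 4
num_nodes = 2

global_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * gpus_per_node * num_nodes
)
print(global_batch_size)  # 32 with the values above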
diff --git a/scripts/run_sft.py b/scripts/run_sft.py
index 8c74f20d..91372c1a 100644
--- a/scripts/run_sft.py
+++ b/scripts/run_sft.py
@@ -36,7 +36,8 @@
     GpuUtilPrintCallBack,
     H4ArgumentParser,
     ModelArguments,
-    ProfCallback,
+    PLW_apply_chat_template,
+    PLWTrainer,
     SFTConfig,
     apply_chat_template,
     get_checkpoint,
@@ -107,7 +108,9 @@ def main():
     ################
     # Load tokenizer
     ################
-    tokenizer = get_tokenizer(model_args, data_args, training_args)
+    tokenizer = get_tokenizer(
+        model_args, data_args, training_args, auto_set_chat_template=True
+    )
 
     #######################
     # Load pretrained model
@@ -150,22 +153,35 @@ def main():
     # Apply chat template
     #####################
     logger.info("*** apply chat template ***")
-    raw_datasets = raw_datasets.map(
-        apply_chat_template,
-        fn_kwargs={
-            "tokenizer": tokenizer,
-            "task": "sft",
-            "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
-        },
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=column_names,
-        desc="Applying chat template",
-    )
+
+    if training_args.use_plw:
+        raw_datasets = raw_datasets.map(
+            PLW_apply_chat_template,
+            fn_kwargs={
+                "tokenizer": tokenizer,
+                "use_sample_template": training_args.use_plw_sample_template,
+            },
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            desc="Applying chat template",
+        )
+    else:
+        raw_datasets = raw_datasets.map(
+            apply_chat_template,
+            fn_kwargs={
+                "tokenizer": tokenizer,
+                "task": "sft",
+                "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
+            },
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            desc="Applying chat template",
+        )
 
     train_dataset = raw_datasets["train"]
     eval_dataset = raw_datasets["test"]
 
-    # this is hard coded
+    # this is hard coded - move to config.yaml
     training_args.dataset_text_field = "text"
 
     # # no need for logging samples
@@ -213,26 +229,50 @@ def main():
     ):
         model, tokenizer = setup_chat_format(model, tokenizer)
 
-        trainer = SFTTrainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            dataset_kwargs=training_args.dataset_kwargs,
-            callbacks=[GpuUtilPrintCallBack()],
-        )
+        if training_args.use_plw:
+            trainer = PLWTrainer(
+                prompt_loss_weight=training_args.prompt_loss_weight,
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )
+        else:
+            trainer = SFTTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )
     else:
-        trainer = SFTTrainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            peft_config=get_peft_config(model_args),
-            dataset_kwargs=training_args.dataset_kwargs,
-            callbacks=[GpuUtilPrintCallBack()],
-        )
+        if training_args.use_plw:
+            trainer = PLWTrainer(
+                prompt_loss_weight=training_args.prompt_loss_weight,
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )
+        else:
+            trainer = SFTTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                peft_config=get_peft_config(model_args),
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )
 
     ###############
     # Training loop
@@ -290,13 +330,11 @@ def main():
     #     logger.info("Pushing to hub...")
     #     trainer.push_to_hub(**kwargs)
 
-    torch.cuda.memory._dump_snapshot(
-        Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle"
-    )
+    # torch.cuda.memory._dump_snapshot(Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle")
     # prof.close()
 
     logger.info("*** Training complete ***")
 
 
 if __name__ == "__main__":
-    torch.cuda.memory._record_memory_history()
+    # torch.cuda.memory._record_memory_history()
     main()
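The patch comments out the CUDA memory-history profiling rather than deleting it. If that profiling is needed again, the usual pattern is the sketch below (not part of the patch; it requires a CUDA build of PyTorch, the output_dir value is only illustrative, and the resulting pickle can be inspected with PyTorch's memory_viz viewer):

from pathlib import Path

import torch

output_dir = "experiments/models/test_deepspeed"  # illustrative; use training_args.output_dir in run_sft.py

torch.cuda.memory._record_memory_history()  # start recording allocations before training
# ... run training here ...
torch.cuda.memory._dump_snapshot(Path(output_dir) / "GPU_RAM_PROFILE.pickle")  # write the snapshot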
diff --git a/src/alignment/__init__.py b/src/alignment/__init__.py
index 5c92315c..f08974cd 100644
--- a/src/alignment/__init__.py
+++ b/src/alignment/__init__.py
@@ -9,6 +9,7 @@
 )
 from .data import apply_chat_template, get_datasets
 from .model_utils import (
+    add_new_special_token,
     get_checkpoint,
     get_kbit_device_map,
     get_peft_config,
@@ -17,13 +18,9 @@
     is_adapter_model,
     tokenizer_and_embedding_resize,
 )
+from .plw_trainer import PLW_apply_chat_template, PLWTrainer
 from .simpo_trainer import SimPOTrainer
-from .utils import (
-    GpuUtilPrintCallBack,
-    ProfCallback,
-    print_gpu_utilization,
-    print_summary,
-)
+from .utils import GpuUtilPrintCallBack, ProfCallback
 
 
 __all__ = [
@@ -34,11 +31,13 @@
     "SFTConfig",
     "apply_chat_template",
     "get_datasets",
-    "decontaminate_humaneval",
     "get_checkpoint",
     "get_kbit_device_map",
     "get_peft_config",
     "get_quantization_config",
     "get_tokenizer",
     "is_adapter_model",
+    "PLW_apply_chat_template",
+    "PLWTrainer",
+    "SimPOTrainer",
 ]
diff --git a/src/alignment/configs.py b/src/alignment/configs.py
index a0e59270..1dcf5d1d 100644
--- a/src/alignment/configs.py
+++ b/src/alignment/configs.py
@@ -294,10 +294,9 @@ class SFTConfig(trl.SFTConfig):
             "help": ("Whether to log and evaluate the first global_step or not.")
         },
     )
-    # max_seq_length: Optional[int] = field(
-    #     default=None,
-    # )
-    # packing: Optional[bool] = field(default=False)
+    prompt_loss_weight: float = field(default=0.1)
+    use_plw: bool = field(default=False)
+    use_plw_sample_template: bool = field(default=False)
 
 
 @dataclass
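The three new SFTConfig fields are what run_sft.py keys off to switch between SFTTrainer and PLWTrainer. In the recipes they would normally be set in the YAML that H4ArgumentParser loads; the snippet below is only a sketch of setting them programmatically, and the output path is hypothetical:

from alignment import SFTConfig

args = SFTConfig(
    output_dir="experiments/models/plw_smoke_test",  # hypothetical output directory
    use_plw=True,                    # run_sft.py will then build PLWTrainer instead of SFTTrainer
    prompt_loss_weight=0.1,          # weight on prompt tokens; completion tokens keep weight 1.0
    use_plw_sample_template=False,   # keep the chat template supplied by the recipe/tokenizer
)
print(args.use_plw, args.prompt_loss_weight)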
diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py
index e91f383c..86fa7fc3 100644
--- a/src/alignment/model_utils.py
+++ b/src/alignment/model_utils.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from pathlib import Path
 
 import torch
 from transformers import (
@@ -27,12 +26,13 @@
 
 from accelerate import Accelerator
 from huggingface_hub import list_repo_files
-from huggingface_hub.utils._errors import RepositoryNotFoundError
+
+# from huggingface_hub.utils._errors import RepositoryNotFoundError
 from huggingface_hub.utils._validators import HFValidationError
 from peft import LoraConfig, PeftConfig
 
 from .configs import DataArguments, ModelArguments
-from .data import DEFAULT_CHAT_TEMPLATE, DEFAULT_PAD_TOKEN
+from .data import DEFAULT_CHAT_TEMPLATE
 
 
 def get_current_device() -> int:
@@ -116,6 +116,21 @@ def tokenizer_and_embedding_resize(
         output_embeddings_data[-num_new_tokens:] = output_embeddings_avg
 
 
+def add_new_special_token(new_special_token, tokenizer, model):
+    for k, v in new_special_token.items():
+        # get existing special token
+        stk = tokenizer.special_tokens_map.get(k, None)
+        if stk:
+            idx = tokenizer.convert_tokens_to_ids(stk)
+            tk_emb = model.get_input_embeddings().weight.data[idx]
+        else:
+            tk_emb = model.get_input_embeddings().weight.data.mean(dim=0, keepdim=False)
+
+        tokenizer.add_special_tokens({k: v})
+        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
+        model.get_input_embeddings().weight.data[-1] = tk_emb
+
+
 def get_tokenizer(
     model_args: ModelArguments,
     data_args: DataArguments,
@@ -132,8 +147,23 @@ def get_tokenizer(
         revision=model_args.model_revision,
         trust_remote_code=model_args.trust_remote_code,
     )
+
     if tokenizer.pad_token_id is None:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
+        if "llama" in tokenizer.name_or_path.lower():
+            llama_version = tokenizer.name_or_path.split("/")[-1].split("-")[-2]
+            if llama_version == "3.2":
+                pad_token = "<|finetune_right_pad_id|>"
+            elif llama_version == "3":
+                pad_token = "<|reserved_special_token_0|>"
+            else:
+                raise RuntimeError(
+                    f"check {tokenizer.name_or_path} to make sure we have a version like Meta-Llama-3-8B or Meta-Llama-3.2-3B"
+                )
+            tokenizer.pad_token = pad_token
+            tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.pad_token_id = tokenizer.eos_token_id
 
     if data_args.truncation_side is not None:
         tokenizer.truncation_side = data_args.truncation_side
@@ -143,16 +173,19 @@ def get_tokenizer(
 
     if train_args:
         tokenizer.model_max_length = train_args.max_seq_length
 
-    if tokenizer.model_max_length > 100_000:
-        tokenizer.model_max_length = 2048
+    if tokenizer.model_max_length > 128000:
+        tokenizer.model_max_length = 4096
 
     if data_args.chat_template is not None:
        tokenizer.chat_template = data_args.chat_template
-    elif auto_set_chat_template and tokenizer.get_chat_template() is None:
+    elif auto_set_chat_template and tokenizer.chat_template is None:
         tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
 
     tokenizer.pad_to_multiple_of = 8
 
+    # training is ok for right / left but for batch inference, we need to set padding side as left
+    # tokenizer.padding_side = "left"
+
     return tokenizer
 
@@ -196,9 +229,11 @@ def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
     try:
         # Try first if model on a Hub repo
        repo_files = list_repo_files(model_name_or_path, revision=revision)
-    except (HFValidationError, RepositoryNotFoundError):
-        # If not, check local repo
+    except:
         repo_files = os.listdir(model_name_or_path)
+    # except (HFValidationError, RepositoryNotFoundError):
+    #     # If not, check local repo
+    #     repo_files = os.listdir(model_name_or_path)
     return (
         "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
     )
diff --git a/src/alignment/plw_trainer.py b/src/alignment/plw_trainer.py
new file mode 100644
index 00000000..6c00678b
--- /dev/null
+++ b/src/alignment/plw_trainer.py
@@ -0,0 +1,174 @@
+import warnings
+
+import datasets
+import torch
+from torch.nn import CrossEntropyLoss
+
+from trl import SFTTrainer
+
+
+def PLW_sample_chat_template():
+    template = "{% if messages[0]['role'] == 'system' %}{% set system_message = bos_token + '### System Instruction: ' + messages[0]['content'] | trim + '' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ system_message }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Context: ' + message['content'] | trim + '\n' }}{% elif message['role'] == 'assistant' %}{{ '### Result: ' + message['content'] | trim + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Result: ' }}{% endif %}"
+
+    return template
+
+
+def PLW_apply_chat_template(example, tokenizer=None, use_sample_template=False):
+    messages = example["messages"]
+    prompts = list(filter(lambda x: x["role"] != "assistant", messages))
+    labels = list(filter(lambda x: x["role"] == "assistant", messages))
+
+    if use_sample_template:
+        tokenizer.chat_template = PLW_sample_chat_template()
+
+    example["prompt"] = tokenizer.apply_chat_template(
+        prompts, tokenize=False, add_generation_prompt=False
+    )
+    example["completion"] = tokenizer.apply_chat_template(
+        labels, tokenize=False, add_generation_prompt=False
+    )
+    return example
+
+
+class PLWTrainer(SFTTrainer):
+    def __init__(self, *args, prompt_loss_weight=0.1, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.plw = prompt_loss_weight
+        # need to add prompt_mask and completion_mask to dataset generation
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        # get outputs without computing loss (by not passing in labels)
+        outputs = model(
+            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
+        )
+        logits = outputs.get("logits")
+        labels = inputs.pop("labels")
+
+        weights = self.plw * inputs["prompt_mask"] + inputs["completion_mask"]
+
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        shift_weights = weights[..., 1:].contiguous()
+
+        shift_labels = shift_labels.to(shift_logits.device)
+        shift_weights = shift_weights.to(shift_logits.device)
+
+        # Compute per-token losses
+        loss_fct = CrossEntropyLoss(reduction="none")
+        token_losses = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+        )
+
+        # Compute weighted average of losses
+        loss = token_losses @ shift_weights.view(-1) / shift_weights.sum()
+        return (loss, outputs) if return_outputs else loss
+
+    def _prepare_dataset(
+        self,
+        dataset,
+        tokenizer,
+        packing,
+        dataset_text_field,
+        max_seq_length,
+        formatting_func,
+        num_of_sequences,
+        chars_per_token,
+        remove_unused_columns=True,
+        append_concat_token=True,
+        add_special_tokens=True,
+        skip_prepare_dataset=False,
+    ):
+        if dataset is None:
+            raise ValueError("The dataset should not be None")
+
+        if skip_prepare_dataset:
+            return dataset
+
+        # If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is
+        # a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset
+        column_names = (
+            dataset.column_names
+            if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset))
+            else None
+        )
+        if column_names and "input_ids" in column_names:
+            if formatting_func is not None:
+                warnings.warn(
+                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
+                )
+
+            return dataset
+
+        # check if torch dataset / dataloader and do nothing
+        # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
+        if isinstance(
+            dataset,
+            (torch.utils.data.IterableDataset, torch.utils.data.Dataset),
+        ) and not isinstance(dataset, datasets.IterableDataset):
+            return dataset
+
+        self._dataset_sanity_checked = False
+
+        def tokenize(element):
+            prompts = element["prompt"]
+            labels = element["completion"]
+
+            input_ids = []
+            attention_mask = []
+            prompt_masks = []
+            completion_masks = []
+
+            for prmp, lab in zip(prompts, labels):
+                p = tokenizer(
+                    prmp,
+                    add_special_tokens=False,
+                    truncation=False,
+                    padding=False,
+                    return_overflowing_tokens=False,
+                    return_length=True,
+                )
+
+                l = tokenizer(
+                    lab,
+                    add_special_tokens=False,
+                    truncation=False,
+                    padding=False,
+                    return_overflowing_tokens=False,
+                    return_length=True,
+                )
+
+                p_len = p["length"][0]
+                l_len = l["length"][0]
+                gap_len = p_len + l_len - tokenizer.model_max_length
+
+                new_input_ids = p["input_ids"] + l["input_ids"]
+                new_attn_mask = p["attention_mask"] + l["attention_mask"]
+                prompt_mask = p["attention_mask"] + [0] * l_len
+                completion_mask = [0] * p_len + l["attention_mask"]
+
+                # truncate from left side
+                if gap_len > 0:
+                    new_input_ids = new_input_ids[gap_len:]
+                    new_attn_mask = new_attn_mask[gap_len:]
+                    prompt_mask = prompt_mask[gap_len:]
+                    completion_mask = completion_mask[gap_len:]
+
+                input_ids.append(new_input_ids)
+                attention_mask.append(new_attn_mask)
+                prompt_masks.append(prompt_mask)
+                completion_masks.append(completion_mask)
+
+            return {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "prompt_mask": prompt_masks,
+                "completion_mask": completion_masks,
+            }
+
+        return dataset.map(
+            tokenize,
+            batched=True,
+            remove_columns=dataset.column_names,
+            num_proc=self.dataset_num_proc,
+            batch_size=self.dataset_batch_size,
+        )
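A toy numeric check, not part of the patch, of the weighting scheme in PLWTrainer.compute_loss: prompt tokens enter the loss with weight prompt_loss_weight, completion tokens with weight 1, and the result is the weighted mean over the (shifted) sequence:

import torch

plw = 0.1
prompt_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])
completion_mask = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0])
token_losses = torch.tensor([2.0, 2.0, 2.0, 1.0, 1.0])  # pretend per-token cross-entropy values

weights = plw * prompt_mask + completion_mask
loss = token_losses @ weights / weights.sum()
print(loss)  # (0.1*6 + 2) / 2.3 = ~1.13, so completion tokens dominate the objective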
diff --git a/src/alignment/utils.py b/src/alignment/utils.py
index 06a1eefa..ccd766b3 100644
--- a/src/alignment/utils.py
+++ b/src/alignment/utils.py
@@ -5,26 +5,33 @@
 from pynvml import *
 
 
-def print_gpu_utilization():
+def get_gpu_utilization():
     nvmlInit()
     handle = nvmlDeviceGetHandleByIndex(0)
     info = nvmlDeviceGetMemoryInfo(handle)
-    print(f"GPU memory occupied: {info.used//1024**2} MB.")
+    return f"GPU memory occupied: {info.used//1024**2} MB."
 
 
 def print_summary(result):
     print(f"Time: {result.metrics['train_runtime']:.2f}")
     print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
-    print_gpu_utilization()
 
 
 class GpuUtilPrintCallBack(TrainerCallback):
     def on_log(self, args, state, control, logs=None, **kwargs):
-        if state.is_local_process_zero:
-            print(datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S"))
-            print(logs)
-            print_gpu_utilization()
-            # print_summary(args)
+        if state.is_local_process_zero and (state.global_step + 1) % 100 == 0:
+            print(
+                "[", datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S"), "]\t", logs
+            )
+
+            if (state.global_step + 1) % 500 == 0:
+                print(get_gpu_utilization())
+
+    def on_train_begin(self, args, state, control, logs=None, **kwargs):
+        print(get_gpu_utilization())
+
+    def on_epoch_end(self, args, state, control, logs=None, **kwargs):
+        print(get_gpu_utilization())
 
 
 class ProfCallback(TrainerCallback):
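GpuUtilPrintCallBack is no longer wired into the trainers in run_sft.py (the callbacks argument is commented out above). If the periodic GPU-memory logging is wanted for a particular run, it can be re-attached after construction; a sketch, where trainer is whichever SFTTrainer/PLWTrainer instance was built:

from alignment import GpuUtilPrintCallBack

# Attach the callback without changing how the trainer itself is constructed.
trainer.add_callback(GpuUtilPrintCallBack())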