From a6484c0b199c1a3c8ab4e577e869f8ae2071d10a Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 29 Nov 2024 01:45:52 +0100 Subject: [PATCH 01/24] capture init parameters in training_args --- src/transformers/training_args.py | 53 ++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 0653c8a2cb7..43d2d3a1c50 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -218,7 +218,6 @@ def _convert_str_dict(passed_value: dict): return passed_value -# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 @dataclass class TrainingArguments: """ @@ -1541,6 +1540,19 @@ class TrainingArguments: }, ) + def __new__(self, *args, **kwargs): + # catch and save only the parameters that the user passed + self.__training_args_params__ = {} + param_names = list(self.__dataclass_fields__.keys()) + + for i in range(len(args)): + self.__training_args_params__[param_names[i]] = serialize_parameter(param_names[i], args[i]) + + for k, v in kwargs.items(): + self.__training_args_params__[k] = serialize_parameter(k, v) + + return super().__new__(self) + def __post_init__(self): # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: @@ -1562,6 +1574,8 @@ def __post_init__(self): self.logging_dir = os.path.join(self.output_dir, default_logdir()) if self.logging_dir is not None: self.logging_dir = os.path.expanduser(self.logging_dir) + # set logging_dir in __training_args_params__ + self.__training_args_params__["logging_dir"] = self.logging_dir if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN @@ -2524,15 +2538,7 @@ def to_dict(self): d = {field.name: getattr(self, field.name) for field in fields(self) if field.init} for k, v in d.items(): - if isinstance(v, Enum): - d[k] = v.value - if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): - d[k] = [x.value for x in v] - if k.endswith("_token"): - d[k] = f"<{k.upper()}>" - # Handle the accelerator_config if passed - if is_accelerate_available() and isinstance(v, AcceleratorConfig): - d[k] = v.to_dict() + d[k] = serialize_parameter(k, v) self._dict_torch_dtype_to_str(d) return d @@ -2541,7 +2547,14 @@ def to_json_string(self): """ Serializes this instance to a JSON string. """ - return json.dumps(self.to_dict(), indent=2) + return json.dumps(self.__training_args_dict__, indent=2) + + def to_json_file(self, json_file_path: str): + """ + Save this instance's parameters to a json file. 
+ """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) def to_sanitized_dict(self) -> Dict[str, Any]: """ @@ -3099,3 +3112,21 @@ class ParallelMode(Enum): SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel" SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel" TPU = "tpu" + + +def serialize_parameter(k, v): + if k == "torch_dtype" and not isinstance(v, str): + return str(k).split(".")[1] + if isinstance(v, dict): + return {key: serialize_parameter(key, value) for key, value in v.items()} + if isinstance(v, Enum): + return v.value + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): + l = [x.value for x in v] + return l + if k.endswith("_token"): + return f"<{k.upper()}>" + # Handle the accelerator_config if passed + if is_accelerate_available() and isinstance(v, AcceleratorConfig): + return v.to_dict() + return v From 780d3e7bb40760ef563544ee8c5b080be76f4292 Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 29 Nov 2024 02:08:48 +0100 Subject: [PATCH 02/24] update relevant attributes --- docs/source/en/deepspeed.md | 2 +- docs/source/ja/main_classes/deepspeed.md | 2 +- docs/source/ko/deepspeed.md | 2 +- docs/source/zh/main_classes/deepspeed.md | 2 +- examples/legacy/question-answering/run_squad.py | 4 ++-- examples/legacy/run_swag.py | 4 ++-- .../bert-loses-patience/run_glue_with_pabee.py | 4 ++-- examples/research_projects/deebert/run_glue_deebert.py | 4 ++-- .../distillation/run_squad_w_distillation.py | 4 ++-- examples/research_projects/mm-imdb/run_mmimdb.py | 4 ++-- .../research_projects/movement-pruning/masked_run_glue.py | 4 ++-- .../movement-pruning/masked_run_squad.py | 4 ++-- .../research_projects/wav2vec2/FINE_TUNE_XLSR_WAV2VEC2.md | 2 +- src/transformers/trainer.py | 8 ++++---- tests/deepspeed/test_deepspeed.py | 2 +- tests/trainer/test_trainer.py | 2 +- 16 files changed, 27 insertions(+), 27 deletions(-) diff --git a/docs/source/en/deepspeed.md b/docs/source/en/deepspeed.md index 7f7995c4664..926f7f53a89 100644 --- a/docs/source/en/deepspeed.md +++ b/docs/source/en/deepspeed.md @@ -895,7 +895,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/ -rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model -rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json -rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json --rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin +-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json -rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py* ``` diff --git a/docs/source/ja/main_classes/deepspeed.md b/docs/source/ja/main_classes/deepspeed.md index 4406ce4a34e..fc7b7f89112 100644 --- a/docs/source/ja/main_classes/deepspeed.md +++ b/docs/source/ja/main_classes/deepspeed.md @@ -1705,7 +1705,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/ -rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model -rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json -rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json --rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin +-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json -rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py* ``` diff --git a/docs/source/ko/deepspeed.md b/docs/source/ko/deepspeed.md index 9945e298b77..db881f80d46 100644 --- a/docs/source/ko/deepspeed.md +++ b/docs/source/ko/deepspeed.md @@ -894,7 +894,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/ -rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model -rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json 
-rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json --rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin +-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json -rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py* ``` diff --git a/docs/source/zh/main_classes/deepspeed.md b/docs/source/zh/main_classes/deepspeed.md index 75a0a13df75..06660a03ca7 100644 --- a/docs/source/zh/main_classes/deepspeed.md +++ b/docs/source/zh/main_classes/deepspeed.md @@ -1589,7 +1589,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/ -rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model -rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json -rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json --rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin +-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json -rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py* ``` diff --git a/examples/legacy/question-answering/run_squad.py b/examples/legacy/question-answering/run_squad.py index f5a827c15ac..4d1d360c7f1 100644 --- a/examples/legacy/question-answering/run_squad.py +++ b/examples/legacy/question-answering/run_squad.py @@ -245,7 +245,7 @@ def train(args, train_dataset, model, tokenizer): model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -792,7 +792,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True) diff --git a/examples/legacy/run_swag.py b/examples/legacy/run_swag.py index dbf712a71ff..c5584770419 100755 --- a/examples/legacy/run_swag.py +++ b/examples/legacy/run_swag.py @@ -396,7 +396,7 @@ def train(args, train_dataset, model, tokenizer): ) # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) tokenizer.save_vocabulary(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: @@ -678,7 +678,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = AutoModelForMultipleChoice.from_pretrained(args.output_dir) diff --git a/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py b/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py index d1ee5ddde3c..a370400fe7d 100755 --- a/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py +++ b/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py @@ -240,7 +240,7 @@ def train(args, train_dataset, model, tokenizer): model_to_save.save_pretrained(output_dir) 
tokenizer.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -712,7 +712,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = model_class.from_pretrained(args.output_dir) diff --git a/examples/research_projects/deebert/run_glue_deebert.py b/examples/research_projects/deebert/run_glue_deebert.py index 6ca28ab5bc0..c3760067725 100644 --- a/examples/research_projects/deebert/run_glue_deebert.py +++ b/examples/research_projects/deebert/run_glue_deebert.py @@ -221,7 +221,7 @@ def train(args, train_dataset, model, tokenizer, train_highway=False): model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: @@ -672,7 +672,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = model_class.from_pretrained(args.output_dir) diff --git a/examples/research_projects/distillation/run_squad_w_distillation.py b/examples/research_projects/distillation/run_squad_w_distillation.py index a1150f6b437..f950eb5ded7 100644 --- a/examples/research_projects/distillation/run_squad_w_distillation.py +++ b/examples/research_projects/distillation/run_squad_w_distillation.py @@ -285,7 +285,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None): model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -836,7 +836,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = model_class.from_pretrained(args.output_dir) diff --git a/examples/research_projects/mm-imdb/run_mmimdb.py b/examples/research_projects/mm-imdb/run_mmimdb.py index 686691e0b9c..0ba2e3ec308 100644 --- a/examples/research_projects/mm-imdb/run_mmimdb.py +++ b/examples/research_projects/mm-imdb/run_mmimdb.py @@ -203,7 +203,7 @@ def train(args, train_dataset, model, tokenizer, criterion): model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training 
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME)) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) if args.max_steps > 0 and global_step > args.max_steps: @@ -540,7 +540,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = MMBTForClassification(config, transformer, img_encoder) diff --git a/examples/research_projects/movement-pruning/masked_run_glue.py b/examples/research_projects/movement-pruning/masked_run_glue.py index 4ddb4248357..86743570a97 100644 --- a/examples/research_projects/movement-pruning/masked_run_glue.py +++ b/examples/research_projects/movement-pruning/masked_run_glue.py @@ -392,7 +392,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None): model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -927,7 +927,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = model_class.from_pretrained(args.output_dir) diff --git a/examples/research_projects/movement-pruning/masked_run_squad.py b/examples/research_projects/movement-pruning/masked_run_squad.py index 7b1c2b32209..9499fa8ea3e 100644 --- a/examples/research_projects/movement-pruning/masked_run_squad.py +++ b/examples/research_projects/movement-pruning/masked_run_squad.py @@ -413,7 +413,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None): model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) - torch.save(args, os.path.join(output_dir, "training_args.bin")) + args.to_json_file(os.path.join(output_dir, "training_args.json")) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -1094,7 +1094,7 @@ def main(): tokenizer.save_pretrained(args.output_dir) # Good practice: save your training arguments together with the trained model - torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + args.to_json_file(os.path.join(args.output_dir, "training_args.json")) # Load a trained model and vocabulary that you have fine-tuned model = model_class.from_pretrained(args.output_dir) # , force_download=True) diff --git a/examples/research_projects/wav2vec2/FINE_TUNE_XLSR_WAV2VEC2.md b/examples/research_projects/wav2vec2/FINE_TUNE_XLSR_WAV2VEC2.md index 7a580a36132..c8ec0cc9401 100644 --- a/examples/research_projects/wav2vec2/FINE_TUNE_XLSR_WAV2VEC2.md +++ b/examples/research_projects/wav2vec2/FINE_TUNE_XLSR_WAV2VEC2.md @@ -122,7 +122,7 @@ Having finished the training you should find the following files/folders under t - `special_tokens_map.json` - 
the special token map of the tokenizer - `tokenizer_config.json` - the parameters of the tokenizer - `vocab.json` - the vocabulary of the tokenizer -- `checkpoint-{...}/` - the saved checkpoints saved during training. Each checkpoint should contain the files: `config.json`, `optimizer.pt`, `pytorch_model.bin`, `scheduler.pt`, `training_args.bin`. The files `config.json` and `pytorch_model.bin` define your model. +- `checkpoint-{...}/` - the saved checkpoints saved during training. Each checkpoint should contain the files: `config.json`, `optimizer.pt`, `pytorch_model.bin`, `scheduler.pt`, `training_args.json`. The files `config.json` and `pytorch_model.bin` define your model. If you are happy with your training results it is time to upload your model! Download the following files to your local computer: **`preprocessor_config.json`, `special_tokens_map.json`, `tokenizer_config.json`, `vocab.json`, `config.json`, `pytorch_model.bin`**. Those files fully define a XLSR-Wav2Vec2 model checkpoint. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ed45624983a..937ebc76913 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -301,7 +301,7 @@ def safe_globals(): # Name of the files used for checkpointing -TRAINING_ARGS_NAME = "training_args.bin" +TRAINING_ARGS_NAME = "training_args.json" TRAINER_STATE_NAME = "trainer_state.json" OPTIMIZER_NAME = "optimizer.pt" OPTIMIZER_NAME_BIN = "optimizer.bin" @@ -3814,7 +3814,7 @@ def _save_tpu(self, output_dir: Optional[str] = None): if xm.is_master_ordinal(local=False): os.makedirs(output_dir, exist_ok=True) - torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) # Save a trained model and configuration using `save_pretrained()`. 
# They can then be reloaded using `from_pretrained()` @@ -3911,7 +3911,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): self.processing_class.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model - torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) def store_flos(self): # Storing the number of floating-point operations that went into the model @@ -4621,7 +4621,7 @@ def _push_from_checkpoint(self, checkpoint_folder): if self.processing_class is not None: self.processing_class.save_pretrained(output_dir) # Same for the training arguments - torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) if self.args.save_strategy == SaveStrategy.STEPS: commit_message = f"Training in progress, step {self.state.global_step}" diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 8eaa00bc768..59d3d505643 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -740,7 +740,7 @@ def test_gradient_accumulation(self, stage, dtype): def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype): # adapted from TrainerIntegrationCommon.check_saved_checkpoints - file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"] + file_list = [SAFE_WEIGHTS_NAME, "training_args.json", "trainer_state.json", "config.json"] if stage == ZERO2: ds_file_list = ["mp_rank_00_model_states.pt"] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5658372fa71..104dde4e68f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -572,7 +572,7 @@ def get_regression_trainer( class TrainerIntegrationCommon: def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True): weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME - file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"] + file_list = [weights_file, "training_args.json", "optimizer.pt", "scheduler.pt", "trainer_state.json"] if is_pretrained: file_list.append("config.json") for step in range(freq, total, freq): From 66c655b8c6996a252b67dae12eef7641a81cf50d Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 29 Nov 2024 02:45:55 +0100 Subject: [PATCH 03/24] attribute calling for training args --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 43d2d3a1c50..5a4ba5cbc36 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2547,7 +2547,7 @@ def to_json_string(self): """ Serializes this instance to a JSON string. 
""" - return json.dumps(self.__training_args_dict__, indent=2) + return json.dumps(self.__training_args_params__, indent=2) def to_json_file(self, json_file_path: str): """ From 6d3a186ed279e17422eab4b697f47a7d25ca73e5 Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 29 Nov 2024 04:09:06 +0100 Subject: [PATCH 04/24] add class attribute to load the training_args from a local file --- src/transformers/training_args.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5a4ba5cbc36..e163d4966c9 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -22,7 +22,7 @@ from datetime import timedelta from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, TypeVar, Union from huggingface_hub import get_full_repo_name from packaging import version @@ -67,6 +67,8 @@ log_levels = logging.get_log_levels_dict().copy() trainer_log_levels = dict(**log_levels, passive=-1) +T = TypeVar("T", bound="TrainingArguments") + if is_torch_available(): import torch import torch.distributed as dist @@ -2556,6 +2558,15 @@ def to_json_file(self, json_file_path: str): with open(json_file_path, "w", encoding="utf-8") as writer: writer.write(self.to_json_string()) + @classmethod + def from_json_file(cls: Type[T], json_file_path: str) -> T: + """ + Constructs an instance from a json file + """ + with open(json_file_path, "r", encoding="utf-8") as reader: + params = json.load(reader) + return cls(**params) + def to_sanitized_dict(self) -> Dict[str, Any]: """ Sanitized serialization to use with TensorBoard’s hparams From 5acc885a38859bf257bfa9333e57b593a5d6d1be Mon Sep 17 00:00:00 2001 From: Hafedh Hichri <70411813+not-lain@users.noreply.github.com> Date: Fri, 29 Nov 2024 18:07:35 +0100 Subject: [PATCH 05/24] Update src/transformers/training_args.py --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b87d4073ab3..78f03ac5005 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -3132,7 +3132,7 @@ class ParallelMode(Enum): def serialize_parameter(k, v): if k == "torch_dtype" and not isinstance(v, str): - return str(k).split(".")[1] + return str(v).split(".")[1] if isinstance(v, dict): return {key: serialize_parameter(key, value) for key, value in v.items()} if isinstance(v, Enum): From 79d2a3a4e4379559020482981f361471e786dbfc Mon Sep 17 00:00:00 2001 From: Lain Date: Wed, 11 Dec 2024 22:54:51 +0100 Subject: [PATCH 06/24] ensure backward compatibility --- src/transformers/trainer.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e27eacfae35..097be54355a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -302,6 +302,7 @@ def safe_globals(): # Name of the files used for checkpointing TRAINING_ARGS_NAME = "training_args.json" +DEPRECATED_ARGS_NAME = "trainer_state.json" TRAINER_STATE_NAME = "trainer_state.json" OPTIMIZER_NAME = "optimizer.pt" OPTIMIZER_NAME_BIN = "optimizer.bin" @@ -309,6 +310,12 @@ def safe_globals(): SCALER_NAME = "scaler.pt" FSDP_MODEL_NAME = "pytorch_model_fsdp" +# Safe serialization check +safe_serialize = os.environ.get("TRAINER_SAFE_SERIALIZE") + +if safe_serialize or 
__version__.split(".")[0] >= 5: + safe_serialize = True + class Trainer: """ @@ -3830,7 +3837,14 @@ def _save_tpu(self, output_dir: Optional[str] = None): if xm.is_master_ordinal(local=False): os.makedirs(output_dir, exist_ok=True) - self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + if safe_serialize: + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + else: + logger.warning( + f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " + "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true' " + ) + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` @@ -3927,7 +3941,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): self.processing_class.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model - self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + if safe_serialize: + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + else: + logger.warning( + f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " + "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'" + ) + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) def store_flos(self): # Storing the number of floating-point operations that went into the model @@ -4637,7 +4658,14 @@ def _push_from_checkpoint(self, checkpoint_folder): if self.processing_class is not None: self.processing_class.save_pretrained(output_dir) # Same for the training arguments - self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + if safe_serialize: + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + else: + logger.warning( + f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " + "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'" + ) + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) if self.args.save_strategy == SaveStrategy.STEPS: commit_message = f"Training in progress, step {self.state.global_step}" From c7567f707b94d2940a0483d41cd97ea94fa493af Mon Sep 17 00:00:00 2001 From: Lain Date: Wed, 11 Dec 2024 23:01:19 +0100 Subject: [PATCH 07/24] fix version parsing --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 097be54355a..d3e19aecc38 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -313,7 +313,7 @@ def safe_globals(): # Safe serialization check safe_serialize = os.environ.get("TRAINER_SAFE_SERIALIZE") -if safe_serialize or __version__.split(".")[0] >= 5: +if safe_serialize or int(__version__.split(".")[0]) >= 5: safe_serialize = True From 68089ed1af05c695d91b5ae5351f150617e6e065 Mon Sep 17 00:00:00 2001 From: Lain Date: Thu, 12 Dec 2024 00:07:26 +0100 Subject: [PATCH 08/24] fix deprecated filename --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d3e19aecc38..e5dae058c04 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -302,7 +302,7 @@ def safe_globals(): # Name of the files used for checkpointing TRAINING_ARGS_NAME = "training_args.json" 
-DEPRECATED_ARGS_NAME = "trainer_state.json" +DEPRECATED_ARGS_NAME = "trainer_state.bin" TRAINER_STATE_NAME = "trainer_state.json" OPTIMIZER_NAME = "optimizer.pt" OPTIMIZER_NAME_BIN = "optimizer.bin" From 2a6ee0d57ca4f25028d598d6c48aaae26ad076a4 Mon Sep 17 00:00:00 2001 From: Lain Date: Thu, 12 Dec 2024 01:34:34 +0100 Subject: [PATCH 09/24] update test --- tests/trainer/test_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0769cad1e60..14c8e4352de 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -572,7 +572,8 @@ def get_regression_trainer( class TrainerIntegrationCommon: def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True): weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME - file_list = [weights_file, "training_args.json", "optimizer.pt", "scheduler.pt", "trainer_state.json"] + file_list = [weights_file, "optimizer.pt", "scheduler.pt", "trainer_state.json"] + safe_serialized = ["training_args.bin", "training_args.json"] # default to json in version 5.x.x if is_pretrained: file_list.append("config.json") for step in range(freq, total, freq): @@ -580,6 +581,10 @@ def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, s self.assertTrue(os.path.isdir(checkpoint)) for filename in file_list: self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename))) + self.assertTrue( + (os.path.isfile(os.path.join(checkpoint, safe_serialized[0]))) + or (os.path.isfile(os.path.join(checkpoint, safe_serialized[1]))) + ) def check_best_model_has_been_loaded( self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True From 8228d0f12e79f8bf576b9a2243f79d3dda77bf70 Mon Sep 17 00:00:00 2001 From: Lain Date: Thu, 12 Dec 2024 01:37:21 +0100 Subject: [PATCH 10/24] format with ruff --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 14c8e4352de..5f5c72fe671 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -573,7 +573,7 @@ class TrainerIntegrationCommon: def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True): weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME file_list = [weights_file, "optimizer.pt", "scheduler.pt", "trainer_state.json"] - safe_serialized = ["training_args.bin", "training_args.json"] # default to json in version 5.x.x + safe_serialized = ["training_args.bin", "training_args.json"] # default to json in version 5.x.x if is_pretrained: file_list.append("config.json") for step in range(freq, total, freq): From ba373dc272295e10499c5b907c833bdbe5f29975 Mon Sep 17 00:00:00 2001 From: Lain Date: Thu, 12 Dec 2024 02:12:35 +0100 Subject: [PATCH 11/24] fix trainer logger tests --- src/transformers/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e5dae058c04..a1a8e413264 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3840,7 +3840,7 @@ def _save_tpu(self, output_dir: Optional[str] = None): if safe_serialize: self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) else: - logger.warning( + logger.info( f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe 
serialization method, " "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true' " ) @@ -3944,7 +3944,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): if safe_serialize: self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) else: - logger.warning( + logger.info( f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'" ) @@ -4661,7 +4661,7 @@ def _push_from_checkpoint(self, checkpoint_folder): if safe_serialize: self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) else: - logger.warning( + logger.info( f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'" ) From 10b05ddfa4ce3a7b16d9555a173b36aa669f1c54 Mon Sep 17 00:00:00 2001 From: Lain Date: Thu, 12 Dec 2024 02:55:49 +0100 Subject: [PATCH 12/24] serialize all parameters --- src/transformers/training_args.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 78f03ac5005..8b33b005e92 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1547,19 +1547,6 @@ class TrainingArguments: }, ) - def __new__(self, *args, **kwargs): - # catch and save only the parameters that the user passed - self.__training_args_params__ = {} - param_names = list(self.__dataclass_fields__.keys()) - - for i in range(len(args)): - self.__training_args_params__[param_names[i]] = serialize_parameter(param_names[i], args[i]) - - for k, v in kwargs.items(): - self.__training_args_params__[k] = serialize_parameter(k, v) - - return super().__new__(self) - def __post_init__(self): # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: @@ -1581,8 +1568,6 @@ def __post_init__(self): self.logging_dir = os.path.join(self.output_dir, default_logdir()) if self.logging_dir is not None: self.logging_dir = os.path.expanduser(self.logging_dir) - # set logging_dir in __training_args_params__ - self.__training_args_params__["logging_dir"] = self.logging_dir if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN @@ -2552,9 +2537,9 @@ def to_dict(self): def to_json_string(self): """ - Serializes this instance to a JSON string. + Serializes the TrainingArguments into a JSON string. """ - return json.dumps(self.__training_args_params__, indent=2) + return json.dumps(self.to_dict(), indent=2) def to_json_file(self, json_file_path: str): """ @@ -2566,7 +2551,7 @@ def to_json_file(self, json_file_path: str): @classmethod def from_json_file(cls: Type[T], json_file_path: str) -> T: """ - Constructs an instance from a json file + Loads and initializes the TrainingArguments from a json file. 
""" with open(json_file_path, "r", encoding="utf-8") as reader: params = json.load(reader) From 7633c80365481114467387d9f251f1e3a3ca476e Mon Sep 17 00:00:00 2001 From: Hafedh Hichri <70411813+not-lain@users.noreply.github.com> Date: Tue, 24 Dec 2024 16:55:25 +0100 Subject: [PATCH 13/24] Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a1a8e413264..26fb638af1c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -302,7 +302,7 @@ def safe_globals(): # Name of the files used for checkpointing TRAINING_ARGS_NAME = "training_args.json" -DEPRECATED_ARGS_NAME = "trainer_state.bin" +DEPRECATED_ARGS_NAME = "training_args.bin" TRAINER_STATE_NAME = "trainer_state.json" OPTIMIZER_NAME = "optimizer.pt" OPTIMIZER_NAME_BIN = "optimizer.bin" From e3c9e57afc2439aa10551b40775e29962183430f Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 27 Dec 2024 19:59:15 +0100 Subject: [PATCH 14/24] Refactor serialization logic in Trainer to use binary serialization flag --- src/transformers/trainer.py | 39 ++++++++++++++----------------- src/transformers/training_args.py | 3 ++- tests/trainer/test_trainer.py | 7 +----- 3 files changed, 21 insertions(+), 28 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 26fb638af1c..966705668fc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -310,11 +310,8 @@ def safe_globals(): SCALER_NAME = "scaler.pt" FSDP_MODEL_NAME = "pytorch_model_fsdp" -# Safe serialization check -safe_serialize = os.environ.get("TRAINER_SAFE_SERIALIZE") - -if safe_serialize or int(__version__.split(".")[0]) >= 5: - safe_serialize = True +# binary serialization check +binary_serializiation = os.environ.get("TRAINER_BINARY_SERIALIZATION", "0") == "1" class Trainer: @@ -3837,13 +3834,13 @@ def _save_tpu(self, output_dir: Optional[str] = None): if xm.is_master_ordinal(local=False): os.makedirs(output_dir, exist_ok=True) - if safe_serialize: + if not binary_serializiation: self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) - else: logger.info( - f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " - "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true' " + "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " ) + else: torch.save(self.args, os.path.join(output_dir, "training_args.bin")) # Save a trained model and configuration using `save_pretrained()`. 
@@ -3941,13 +3938,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): self.processing_class.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model - if safe_serialize: - self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + if not binary_serializiation: + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + logger.info( + "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " + ) else: - logger.info( - f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " - "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'" - ) torch.save(self.args, os.path.join(output_dir, "training_args.bin")) def store_flos(self): @@ -4658,13 +4655,13 @@ def _push_from_checkpoint(self, checkpoint_folder): if self.processing_class is not None: self.processing_class.save_pretrained(output_dir) # Same for the training arguments - if safe_serialize: - self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + if not binary_serializiation: + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + logger.info( + "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " + ) else: - logger.info( - f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, " - "you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'" - ) torch.save(self.args, os.path.join(output_dir, "training_args.bin")) if self.args.save_strategy == SaveStrategy.STEPS: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 8b33b005e92..f2d5d23d4e5 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2530,8 +2530,9 @@ def to_dict(self): d = {field.name: getattr(self, field.name) for field in fields(self) if field.init} for k, v in d.items(): + # serialize parameters in json compatible format, example : + # converted torch.dtype to string (e.g. 
torch.float32 -> "float32") d[k] = serialize_parameter(k, v) - self._dict_torch_dtype_to_str(d) return d diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5f5c72fe671..cff2f7e0904 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -572,8 +572,7 @@ def get_regression_trainer( class TrainerIntegrationCommon: def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True): weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME - file_list = [weights_file, "optimizer.pt", "scheduler.pt", "trainer_state.json"] - safe_serialized = ["training_args.bin", "training_args.json"] # default to json in version 5.x.x + file_list = [weights_file, "optimizer.pt", "scheduler.pt", "trainer_state.json","training_args.json"] if is_pretrained: file_list.append("config.json") for step in range(freq, total, freq): @@ -581,10 +580,6 @@ def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, s self.assertTrue(os.path.isdir(checkpoint)) for filename in file_list: self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename))) - self.assertTrue( - (os.path.isfile(os.path.join(checkpoint, safe_serialized[0]))) - or (os.path.isfile(os.path.join(checkpoint, safe_serialized[1]))) - ) def check_best_model_has_been_loaded( self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True From 689348de120e45f88b54c08dac6c325af896f82c Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 27 Dec 2024 20:02:39 +0100 Subject: [PATCH 15/24] format with ruff --- src/transformers/trainer.py | 20 ++++++++++---------- tests/trainer/test_trainer.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 966705668fc..d9eeae13282 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3939,11 +3939,11 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): # Good practice: save your training arguments together with the trained model if not binary_serializiation: - self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) - logger.info( - "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" - "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " - ) + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + logger.info( + "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " + ) else: torch.save(self.args, os.path.join(output_dir, "training_args.bin")) @@ -4656,11 +4656,11 @@ def _push_from_checkpoint(self, checkpoint_folder): self.processing_class.save_pretrained(output_dir) # Same for the training arguments if not binary_serializiation: - self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) - logger.info( - "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" - "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " - ) + self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) + logger.info( + "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to 
the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " + ) else: torch.save(self.args, os.path.join(output_dir, "training_args.bin")) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index cff2f7e0904..cc2f198e808 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -572,7 +572,7 @@ def get_regression_trainer( class TrainerIntegrationCommon: def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True): weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME - file_list = [weights_file, "optimizer.pt", "scheduler.pt", "trainer_state.json","training_args.json"] + file_list = [weights_file, "optimizer.pt", "scheduler.pt", "trainer_state.json", "training_args.json"] if is_pretrained: file_list.append("config.json") for step in range(freq, total, freq): From 2c2198ce3bf1ce5bf769fa43961ad228fefdb64f Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 27 Dec 2024 20:47:23 +0100 Subject: [PATCH 16/24] switch to dynamic serialization --- src/transformers/training_args.py | 48 ++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 550e6395604..a7ff23fc43f 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -13,6 +13,7 @@ # limitations under the License. import contextlib +import inspect import io import json import math @@ -21,6 +22,7 @@ from dataclasses import asdict, dataclass, field, fields from datetime import timedelta from enum import Enum +from functools import wraps from pathlib import Path from typing import Any, Dict, List, Optional, Type, TypeVar, Union @@ -69,6 +71,28 @@ T = TypeVar("T", bound="TrainingArguments") + +def serialize(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + bound_args = inspect.signature(func).bind(self, *args, **kwargs) + bound_args.apply_defaults() + + for name, value in bound_args.arguments.items(): + if name not in ["self", "args"] and value is not None: + self.__training_args_params__[name] = serialize_parameter(name, value) + # Handle extra positional arguments + # extra_args = bound_args.arguments.get("args", []) + # if extra_args and extra_args != []: + # print("Extra args: ", extra_args) + # print("Extra args type: ", type(extra_args)) + # self.__training_args_params__["args"] = list(extra_args) + + return func(self, *args, **kwargs) + + return wrapper + + if is_torch_available(): import torch import torch.distributed as dist @@ -1545,6 +1569,19 @@ class TrainingArguments: }, ) + def __new__(self, *args, **kwargs): + # catch and save only the parameters that the user passed + self.__training_args_params__ = {} + param_names = list(self.__dataclass_fields__.keys()) + + for i in range(len(args)): + self.__training_args_params__[param_names[i]] = serialize_parameter(param_names[i], args[i]) + + for k, v in kwargs.items(): + self.__training_args_params__[k] = serialize_parameter(k, v) + + return super().__new__(self) + def __post_init__(self): # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: @@ -2526,7 +2563,7 @@ def to_json_string(self): """ Serializes the TrainingArguments into a JSON string. 
""" - return json.dumps(self.to_dict(), indent=2) + return json.dumps(self.__training_args_params__, indent=2) def to_json_file(self, json_file_path: str): """ @@ -2558,6 +2595,7 @@ def to_sanitized_dict(self) -> Dict[str, Any]: return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} # The following methods are there to simplify the instantiation of `TrainingArguments` + @serialize def set_training( self, learning_rate: float = 5e-5, @@ -2633,6 +2671,7 @@ def set_training( self.gradient_checkpointing = gradient_checkpointing return self + @serialize def set_evaluate( self, strategy: Union[str, IntervalStrategy] = "no", @@ -2694,6 +2733,7 @@ def set_evaluate( self.jit_mode_eval = jit_mode return self + @serialize def set_testing( self, batch_size: int = 8, @@ -2734,6 +2774,7 @@ def set_testing( self.jit_mode_eval = jit_mode return self + @serialize def set_save( self, strategy: Union[str, IntervalStrategy] = "steps", @@ -2783,6 +2824,7 @@ def set_save( self.save_on_each_node = on_each_node return self + @serialize def set_logging( self, strategy: Union[str, IntervalStrategy] = "steps", @@ -2858,6 +2900,7 @@ def set_logging( self.log_level_replica = replica_level return self + @serialize def set_push_to_hub( self, model_id: str, @@ -2928,6 +2971,7 @@ def set_push_to_hub( self.hub_always_push = always_push return self + @serialize def set_optimizer( self, name: Union[str, OptimizerNames] = "adamw_torch", @@ -2979,6 +3023,7 @@ def set_optimizer( self.optim_args = args return self + @serialize def set_lr_scheduler( self, name: Union[str, SchedulerType] = "linear", @@ -3024,6 +3069,7 @@ def set_lr_scheduler( self.warmup_steps = warmup_steps return self + @serialize def set_dataloader( self, train_batch_size: int = 8, From 9741c8f3e555ba3db8644c414b7f3e3480f988cd Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 27 Dec 2024 21:27:19 +0100 Subject: [PATCH 17/24] Refactor training_args serialization to only include defined dataclass fields --- src/transformers/training_args.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a7ff23fc43f..b6d67856f47 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -80,13 +80,8 @@ def wrapper(self, *args, **kwargs): for name, value in bound_args.arguments.items(): if name not in ["self", "args"] and value is not None: - self.__training_args_params__[name] = serialize_parameter(name, value) - # Handle extra positional arguments - # extra_args = bound_args.arguments.get("args", []) - # if extra_args and extra_args != []: - # print("Extra args: ", extra_args) - # print("Extra args type: ", type(extra_args)) - # self.__training_args_params__["args"] = list(extra_args) + if name in self.__dataclass_fields__.keys(): + self.__training_args_params__[name] = serialize_parameter(name, value) return func(self, *args, **kwargs) From ef160161132fed60e67e73b39d7de4d8b100a810 Mon Sep 17 00:00:00 2001 From: Lain Date: Fri, 27 Dec 2024 21:34:40 +0100 Subject: [PATCH 18/24] ensure consistant traininglogs to ensure that tensorboard logs continue from where they left off --- src/transformers/training_args.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b6d67856f47..b14da1e6c96 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1598,6 +1598,8 @@ def __post_init__(self): self.logging_dir = 
os.path.join(self.output_dir, default_logdir()) if self.logging_dir is not None: self.logging_dir = os.path.expanduser(self.logging_dir) + # set logging_dir in __training_args_params__ + self.__training_args_params__["logging_dir"] = self.logging_dir if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN From d6d24790ab00f1cf52b1b39ad0dea562e035dfba Mon Sep 17 00:00:00 2001 From: Hafedh Hichri <70411813+not-lain@users.noreply.github.com> Date: Tue, 21 Jan 2025 18:31:30 +0100 Subject: [PATCH 19/24] Update src/transformers/training_args.py --- src/transformers/training_args.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b14da1e6c96..cdea54b8b9f 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -3146,6 +3146,15 @@ class ParallelMode(Enum): def serialize_parameter(k, v): + """ + Serializes a parameter based on its type and key. + This function takes a key and value, and serializes the value depending on its type. + Args: + k: The key associated with the parameter. + v: The value to be serialized, which can be of various types. + Returns:: + The serialized value, which depends on its original type and key. + """ if k == "torch_dtype" and not isinstance(v, str): return str(v).split(".")[1] if isinstance(v, dict): From d053a88f1a7e15d5af1f44a22503feb8b148fdd4 Mon Sep 17 00:00:00 2001 From: Hafedh Hichri <70411813+not-lain@users.noreply.github.com> Date: Tue, 21 Jan 2025 18:31:43 +0100 Subject: [PATCH 20/24] Update src/transformers/training_args.py --- src/transformers/training_args.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index cdea54b8b9f..b2aa71e14eb 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -73,6 +73,13 @@ def serialize(func): + """ + A decorator that captures and serializes the parameters of any method that was called from the original class and stores all valid parameters in the `__training_args_params__` attribute. + Args: + func: The function to be decorated. + Returns: + The wrapped function with serialized parameters. + """ @wraps(func) def wrapper(self, *args, **kwargs): bound_args = inspect.signature(func).bind(self, *args, **kwargs) From a0d9320f4ed5a96d178538d1310b9103d96253f8 Mon Sep 17 00:00:00 2001 From: Lain Date: Tue, 21 Jan 2025 18:35:24 +0100 Subject: [PATCH 21/24] update with ruff --- src/transformers/training_args.py | 33 ++++++++++++++++--------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index fcdd6150f4b..c23482df44f 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -73,13 +73,14 @@ def serialize(func): - """ - A decorator that captures and serializes the parameters of any method that was called from the original class and stores all valid parameters in the `__training_args_params__` attribute. - Args: - func: The function to be decorated. - Returns: - The wrapped function with serialized parameters. - """ + """ + A decorator that captures and serializes the parameters of any method that was called from the original class and stores all valid parameters in the `__training_args_params__` attribute. + Args: + func: The function to be decorated. + Returns: + The wrapped function with serialized parameters. 
+ """ + @wraps(func) def wrapper(self, *args, **kwargs): bound_args = inspect.signature(func).bind(self, *args, **kwargs) @@ -3156,15 +3157,15 @@ class ParallelMode(Enum): def serialize_parameter(k, v): - """ - Serializes a parameter based on its type and key. - This function takes a key and value, and serializes the value depending on its type. - Args: - k: The key associated with the parameter. - v: The value to be serialized, which can be of various types. - Returns:: - The serialized value, which depends on its original type and key. - """ + """ + Serializes a parameter based on its type and key. + This function takes a key and value, and serializes the value depending on its type. + Args: + k: The key associated with the parameter. + v: The value to be serialized, which can be of various types. + Returns: + The serialized value, which depends on its original type and key. + """ if k == "torch_dtype" and not isinstance(v, str): return str(v).split(".")[1] if isinstance(v, dict): From bdcbfc0084cdf995dea00f3b2a0f46f52dc204c2 Mon Sep 17 00:00:00 2001 From: Lain Date: Tue, 21 Jan 2025 19:05:23 +0100 Subject: [PATCH 22/24] switch from logging info to logging warning --- src/transformers/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 904e2683bfc..766d11dc262 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3963,9 +3963,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): # Good practice: save your training arguments together with the trained model if not binary_serializiation: self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) - logger.info( - "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" - "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " + logger.warn( + "Safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1', this feature will be removed in version 5.0.0" ) else: torch.save(self.args, os.path.join(output_dir, "training_args.bin")) @@ -4680,9 +4680,9 @@ def _push_from_checkpoint(self, checkpoint_folder): # Same for the training arguments if not binary_serializiation: self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) - logger.info( - "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" - "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " + logger.warn( + "Safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1', this feature will be removed in version 5.0.0" ) else: torch.save(self.args, os.path.join(output_dir, "training_args.bin")) From 330cded009534345eaf74c72192e9f5b14cf2a4f Mon Sep 17 00:00:00 2001 From: Lain Date: Tue, 21 Jan 2025 19:32:23 +0100 Subject: [PATCH 23/24] Change logger from info to warning for safe serialization message and update test assertion for warning capture --- src/transformers/trainer.py | 6 +++--- tests/trainer/test_trainer_callback.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py 
b/src/transformers/trainer.py index 766d11dc262..816712446e5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3859,9 +3859,9 @@ def _save_tpu(self, output_dir: Optional[str] = None): os.makedirs(output_dir, exist_ok=True) if not binary_serializiation: self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME)) - logger.info( - "safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" - "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1' " + logger.warn( + "Safe serialization has been enabled by default and the training args will be stored in a {TRAINING_ARGS_NAME} \n" + "to fall back to the {DEPRECATED_ARGS_NAME} you can set os.environ['binary_serializiation']= '1', this feature will be removed in version 5.0.0" ) else: torch.save(self.args, os.path.join(output_dir, "training_args.bin")) diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index 0d1e6645f9a..c58a5c02c8c 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -410,7 +410,7 @@ def test_missing_stateful_callback(self): # warning should be emitted for not-present callbacks with patch("transformers.trainer.logger.warning") as warn_mock: trainer.train(resume_from_checkpoint=checkpoint) - assert "EarlyStoppingCallback" in warn_mock.call_args[0][0] + assert "EarlyStoppingCallback" in [warn_mock.call_args[0][0] or warn_mock.call_args[0][1]] def test_stateful_control(self): trainer = self.get_trainer( From 45db5d6d68fc81fccb084ebb0fdf6f5751e5be85 Mon Sep 17 00:00:00 2001 From: Lain Date: Tue, 21 Jan 2025 19:57:47 +0100 Subject: [PATCH 24/24] debugging stateful callback --- tests/trainer/test_trainer_callback.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index c58a5c02c8c..097745ac67b 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -410,6 +410,7 @@ def test_missing_stateful_callback(self): # warning should be emitted for not-present callbacks with patch("transformers.trainer.logger.warning") as warn_mock: trainer.train(resume_from_checkpoint=checkpoint) + print("warn_mock.call_args = ", warn_mock.call_args) assert "EarlyStoppingCallback" in [warn_mock.call_args[0][0] or warn_mock.call_args[0][1]] def test_stateful_control(self):
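
Usage sketch (not part of the patches above): a minimal example of the JSON round-trip that this series adds to `TrainingArguments` (`to_json_file`, `from_json_file`) and of the `TRAINER_BINARY_SERIALIZATION` escape hatch read in `trainer.py`. It assumes a transformers build that contains this branch — these methods do not exist in earlier releases — and the output directory and hyperparameter values are illustrative only.

```python
import os

from transformers import TrainingArguments

# To keep writing the legacy pickle-based `training_args.bin` instead of
# `training_args.json`, the flag must be set before `transformers.trainer`
# is imported, because the patches read it once at import time:
#     os.environ["TRAINER_BINARY_SERIALIZATION"] = "1"

os.makedirs("out", exist_ok=True)  # TrainingArguments does not create output_dir itself

args = TrainingArguments(
    output_dir="out",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
)

# Human-readable JSON: `serialize_parameter` converts enums, `torch.dtype`
# values and `AcceleratorConfig` objects into JSON-compatible types.
args.to_json_file("out/training_args.json")

# Rebuild an equivalent instance from the saved file without unpickling anything.
reloaded = TrainingArguments.from_json_file("out/training_args.json")
assert reloaded.learning_rate == args.learning_rate
```

Compared with `torch.save(args, "training_args.bin")`, the JSON file can be inspected by hand and reloaded without executing pickled objects, which appears to be the motivation behind the "safe serialization" wording used throughout these commits.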