From 348dec9e87a7e0f42171756b574b3bfefba2acb7 Mon Sep 17 00:00:00 2001 From: Dongfu Date: Fri, 24 Nov 2023 02:19:48 -0500 Subject: [PATCH] Save ranker config as config.json and load checkpoints via RankerConfig.from_json_file --- README.md | 1 - llm_blender/pair_ranker/trainer.py | 4 ++-- train_ranker.py | 14 +++++++------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 6fc245d..602a9d1 100755 --- a/README.md +++ b/README.md @@ -27,7 +27,6 @@ ## 🔥News - [11/10] Glad to announce that our pairwise reward-model, 🤗[PairRM](https://huggingface.co/llm-blender/PairRM), has released. It's trained on high-quality and large-scale human reference dataset and approaches GPT-4's alignment with human preference with a extremly small model size (0.4B). -- [10/28] We release a newly trained PairRanker used for reward model at 🤗 [llm-blender/pair-reward-model](https://huggingface.co/llm-blender/pair-reward-model) - [10/24] Pre-trained PairRanker is able to be loaded directly from 🤗 Hugging face Models [llm-blender/PairRM](https://huggingface.co/llm-blender/PairRM) within 3 lines of code. See Guidance for [Rank & Fusion](#rank-and-fusion) for details. 
diff --git a/llm_blender/pair_ranker/trainer.py b/llm_blender/pair_ranker/trainer.py index fa62beb..082b011 100755 --- a/llm_blender/pair_ranker/trainer.py +++ b/llm_blender/pair_ranker/trainer.py @@ -28,8 +28,8 @@ def save_model(self, output_dir: Optional[str] = None, **kwargs): if self.is_world_process_zero(): super().save_model(output_dir, **kwargs) model = self.model.module if hasattr(self.model, "module") else self.model - json.dump(asdict(model.args), open(os.path.join(output_dir, "ranker_config.json"), "w"), indent=4) - + json.dump(asdict(model.args), open(os.path.join(output_dir, "config.json"), "w"), indent=4) + class FiDTrainer(Seq2SeqTrainer): def compute_loss(self, model, inputs, return_outputs=False): """ diff --git a/train_ranker.py b/train_ranker.py index b2c722f..1f0b4ea 100755 --- a/train_ranker.py +++ b/train_ranker.py @@ -84,13 +84,13 @@ def main(args): args.n_tasks = predict_dataset.n_tasks # set up model - config = RankerConfig() - for k in args.__dict__: - if k in config.__dict__: - print(k, getattr(args, k)) - setattr(config, k, getattr(args, k)) + if args.load_checkpoint: - # config = torch.load(os.path.join(args.load_checkpoint, "config.bin")) + config = RankerConfig.from_json_file(os.path.join(args.load_checkpoint, "config.json")) + for k in args.__dict__: + if k in config.__dict__: + print(k, getattr(args, k)) + setattr(config, k, getattr(args, k)) model = build_ranker( args.ranker_type, args.model_type, @@ -106,6 +106,7 @@ def main(args): else: logging.info(f"Successfully loaded checkpoint from '{args.load_checkpoint}'") else: + config = RankerConfig() model = build_ranker( args.ranker_type, args.model_type, @@ -201,7 +202,6 @@ def main(args): logging.info("Saving model") best_checkpoint_folder = os.path.join(args.output_dir, "checkpoint-best") trainer.save_model(best_checkpoint_folder) - torch.save(model.args, os.path.join(best_checkpoint_folder, "config.bin")) if args.do_predict: logging.info("Start predicting")