From 1d21aa6b0ac0e1de832b5d57c82da34220346046 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 Nov 2023 09:55:19 -0500 Subject: [PATCH] ensure merged model matches the training dtype (#902) * ensure merged model matches the training dtype * Update src/axolotl/cli/__init__.py * Update src/axolotl/cli/__init__.py --- src/axolotl/cli/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 7ce4f19488..8ca4f7fe55 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -72,7 +72,7 @@ def do_merge_lora( LOG.info("running merge of LoRA with base model") model = model.merge_and_unload() - model.to(dtype=torch.float16) + model.to(dtype=cfg.torch_dtype) if cfg.local_rank == 0: LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")