From dde02fcb94e16206ab9ec6d3754ddce8d6fcf0d8 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Fri, 3 May 2024 20:35:28 +0530 Subject: [PATCH] Pass weakref to model in the SIGINT handler to free up model post train function (#1581) * Pass weakref to model in the SIGINT handler to free up model post train() * Fix lint issues * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/train.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 01e07640f9..ebd020061b 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -3,6 +3,7 @@ import os import signal import sys +import weakref from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple, Union @@ -127,14 +128,20 @@ def train( # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model if cfg.local_rank == 0: - def terminate_handler(_, __, model): - if cfg.flash_optimum and BetterTransformer: - model = BetterTransformer.reverse(model) - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + def terminate_handler(_, __, model_weakref): + if model_weakref() is not None: + _model = model_weakref() + if cfg.flash_optimum and BetterTransformer: + _model = BetterTransformer.reverse(_model) + _model.save_pretrained( + cfg.output_dir, safe_serialization=safe_serialization + ) sys.exit(0) + _model_weakref = weakref.ref(model) signal.signal( - signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) + signal.SIGINT, + lambda signum, frame: terminate_handler(signum, frame, _model_weakref), ) badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)"""