diff --git a/trialbot/__init__.py b/trialbot/__init__.py
index 63b8d7b..0d419b3 100644
--- a/trialbot/__init__.py
+++ b/trialbot/__init__.py
@@ -1,4 +1,4 @@
 __name__ = "TrialBot"
-__version__ = "0.4.0"
+__version__ = "0.5.0"
diff --git a/trialbot/training/trial_bot.py b/trialbot/training/trial_bot.py
index 2dbf55e..caadd12 100644
--- a/trialbot/training/trial_bot.py
+++ b/trialbot/training/trial_bot.py
@@ -30,9 +30,13 @@ class Events(Enum):
     EXCEPTION_RAISED = "exception_raised"
 
 
-class State(object):
+class State:
     """An object that is used to pass internal and user-defined state between event handlers."""
     def __init__(self, **kwargs):
+        self.epoch = 0
+        self.iteration = 0
+        self.output = None
+
         for k, v in kwargs.items():
             setattr(self, k, v)
 
@@ -74,36 +78,32 @@ class TrialBot:
     """
     def __init__(self, args: argparse.Namespace | None = None,
-                 trial_name: str = "default_savedir",
-                 get_model_func: Optional[Callable[[Any, NSVocabulary], torch.nn.Module]] = None,
+                 trial_name: str = "trial",
+                 get_model_func: Callable[[Any, NSVocabulary], torch.nn.Module] | None = None,
                  clean_engine: bool = False,
+                 runtime_hooks: dict[str, Callable] | None = None,
                  ):
-        if args is None:
-            parser = TrialBot.get_default_parser()
-            args = parser.parse_args()
+        self.runtime_hooks = dict() if runtime_hooks is None else runtime_hooks
 
-        self.args = args
+        self.args = self._init_args(args)  # if given, use the given args, otherwise it will parse from sys.argv
         self.name = trial_name
-
-        # self.translator = None
-        # self.vocab = None
-        # self.train_set = None
-        # self.dev_set = None
-        # self.test_set = None
-        self.hparams = None
-        self.savepath = "."
+        self.hparams = self._init_hparams()
+        self.savepath = self._init_savepath()
 
         self.logger = logging.getLogger(__name__)
         self.updater = None
         self.state = State(epoch=0, iteration=0, output=None)
-        self.get_model = get_model_func
+        self.engine = self._init_engine(clean=clean_engine)
+        self.datasets = self._init_dataset()
+        self.translator = self._init_translator()
+        self.vocab = self._init_vocab(self.train_set, self.translator)
+        if self.translator and self.vocab:
+            self.translator.index_with_vocab(self.vocab)
+        self.models = self._init_models(self.hparams, self.vocab)
 
-        self._engine = self._make_engine(clean=clean_engine)
-        self._init_components()
-
-    def _make_engine(self, clean: bool = False):
+    def _init_engine(self, clean: bool = False):
         engine = Engine()
         engine.register_events(*Events)
         # events with greater priorities will get processed earlier.
@@ -123,6 +123,7 @@ def _make_engine(self, clean: bool = False):
         engine.add_event_handler(Events.COMPLETED, exts.ext_write_info, 100, msg="TrailBot completed.")
         engine.add_event_handler(Events.ITERATION_COMPLETED, exts.loss_reporter, 100)
         engine.add_event_handler(Events.ITERATION_COMPLETED, exts.end_with_nan_loss, 100)
+        self.engine = engine
         return engine
 
     @property
@@ -141,11 +142,10 @@ def get_default_parser():
                             help='overwrite the hparamset.TRAINING_LIMIT if specified at CLI')
         return parser
 
-    def _init_components(self):
-        """
-        Start a trial directly.
-        """
-        args = self.args
+    def _init_args(self, args=None):
+        if args is None:
+            parser = self.get_default_parser()
+            args = parser.parse_args()
 
         # logging args
         if args.quiet:
@@ -155,6 +155,11 @@ def _init_components(self):
         else:
             self.logger.setLevel(logging.INFO)
 
+        self.args = args
+        return args
+
+    def _init_hparams(self):
+        args = self.args
         hparams = Registry.get_hparamset(args.hparamset)
         if args.batch_size > 0:
             hparams.batch_sz = args.batch_size
@@ -162,31 +167,27 @@ def _init_components(self):
             hparams.TRAINING_LIMIT = args.epoch
         if args.translator:
             hparams.TRANSLATOR = args.translator
+        if 'hparams' in self.runtime_hooks:
+            hparams = self.runtime_hooks['hparams'](hparams)
         self.hparams = hparams
-        self.savepath = args.snapshot_dir if args.snapshot_dir else self._default_savepath()
-
-        self._init_dataset()
-        self._init_translator()
-
-        self.vocab = self._init_vocab(self.train_set, self.translator)
-
-        self.translator.index_with_vocab(self.vocab)
-        self.models = self._init_models(hparams, self.vocab)
+        return hparams
 
     def _init_dataset(self):
         args = self.args
         self.datasets = Registry.get_dataset(args.dataset)
         return self.datasets
 
-    def _default_savepath(self):
+    def _init_savepath(self):
         args = self.args
-
         if args.test and len(args.models) > 0:
             # use the first model path as the savepath,
             # if this is not expected, please specify manually the
             return os.path.abspath(os.path.expanduser(os.path.dirname(args.models[0])))
 
+        if args.snapshot_dir:
+            return args.snapshot_dir
+
         return os.path.join(
             self.hparams.SNAPSHOT_PATH,
             args.dataset,
@@ -217,9 +218,7 @@ def _init_translator(self):
         self.translator = translator
         return translator
 
-    def _init_vocab(self,
-                    dataset: Dataset,
-                    translator: Translator):
+    def _init_vocab(self, dataset: Dataset, translator: Translator):
         args, logger = self.args, self.logger
         hparams = self.hparams
 
@@ -256,7 +255,6 @@ def _init_vocab(self,
             vocab.save_to_files(vocab_path)
 
         logger.info(str(vocab))
-
         return vocab
 
     def _init_models(self, hparams, vocab):
@@ -305,55 +303,45 @@ def run(self, training_epoch: int = 0):
         self._training_engine_loop(self.updater, training_epoch)
 
     def _training_engine_loop(self, updater, max_epoch):
-        engine = self._engine
+        engine = self.engine
         engine.fire_event(Events.STARTED, bot=self)
         while self.state.epoch < max_epoch:
             self.state.epoch += 1
-            engine.fire_event(Events.EPOCH_STARTED, bot=self)
-            updater.start_epoch()
-            while True:
-                self.state.iteration += 1
-                engine.fire_event(Events.ITERATION_STARTED, bot=self)
-                try:
-                    self.state.output = next(updater)
-                    engine.fire_event(Events.ITERATION_COMPLETED, bot=self)
-                except StopIteration:
-                    self.state.iteration -= 1
-                    self.state.output = None
-                    break
-
-            engine.fire_event(Events.EPOCH_COMPLETED, bot=self)
+            self._epoch_loop(updater)
         engine.fire_event(Events.COMPLETED, bot=self)
 
     def _testing_engine_loop(self, updater):
-        engine = self._engine
+        engine = self.engine
         with torch.no_grad():
             engine.fire_event(Events.STARTED, bot=self)
-            engine.fire_event(Events.EPOCH_STARTED, bot=self)
-            updater.start_epoch()
-            while True:
-                self.state.iteration += 1
-                engine.fire_event(Events.ITERATION_STARTED, bot=self)
-                try:
-                    self.state.output = next(updater)
-                    engine.fire_event(Events.ITERATION_COMPLETED, bot=self)
-                except StopIteration:
-                    self.state.iteration -= 1
-                    self.state.output = None
-                    break
-
-            engine.fire_event(Events.EPOCH_COMPLETED, bot=self)
             engine.fire_event(Events.COMPLETED, bot=self)
 
+    def _epoch_loop(self, updater):
+        engine = self.engine
+        engine.fire_event(Events.EPOCH_STARTED, bot=self)
+        updater.start_epoch()
+        while True:
+            self.state.iteration += 1
+            engine.fire_event(Events.ITERATION_STARTED, bot=self)
+            try:
+                self.state.output = next(updater)
+                engine.fire_event(Events.ITERATION_COMPLETED, bot=self)
+            except StopIteration:
+                self.state.iteration -= 1
+                self.state.output = None
+                break
+        engine.fire_event(Events.EPOCH_COMPLETED, bot=self)
+
     def attach_extension(self, event_name: str = Events.ITERATION_COMPLETED, priority: int = 100,):
         """Used as a decorator only. To add extension directly, use add_event_handler instead."""
         def decorator(handler):
-            self._engine.add_event_handler(event_name, handler, priority)
+            self.engine.add_event_handler(event_name, handler, priority)
             return handler
         return decorator
 
     def add_event_handler(self, event_name, handler, priority=100, *args, **kwargs):
-        self._engine.add_event_handler(event_name, handler, priority, *args, **kwargs)
+        self.engine.add_event_handler(event_name, handler, priority, *args, **kwargs)
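
Below is a minimal usage sketch of the refactored API; it is not part of the patch. It shows how the new runtime_hooks argument and the decorator-style attach_extension fit together. The module path follows the file touched above; build_model, halve_batch, and log_epoch are hypothetical names, and the handler signature assumes the Engine forwards the bot= keyword used by fire_event.

# Usage sketch only -- hypothetical names are marked; nothing here is shipped by the patch.
import torch
from trialbot.training.trial_bot import TrialBot, Events

def build_model(hparams, vocab):
    # hypothetical factory matching get_model_func's (hparams, vocab) -> torch.nn.Module contract
    return torch.nn.Linear(8, 8)

def halve_batch(hparams):
    # runtime hook: _init_hparams calls runtime_hooks['hparams'](hparams) after the CLI overrides,
    # so the hook can tweak the registered hparamset before datasets, vocab, and models are built
    hparams.batch_sz = max(1, hparams.batch_sz // 2)
    return hparams

# args=None means the constructor parses sys.argv via _init_args / get_default_parser()
bot = TrialBot(trial_name="demo",
               get_model_func=build_model,
               runtime_hooks={'hparams': halve_batch})

@bot.attach_extension(Events.EPOCH_COMPLETED)
def log_epoch(bot):
    # fired from _epoch_loop; assumes the engine passes fire_event's bot= keyword through to handlers
    bot.logger.info("epoch %d done at iteration %d", bot.state.epoch, bot.state.iteration)

bot.run()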