Skip to content

Commit

Permalink
v0.5.0;
Browse files Browse the repository at this point in the history
add hooks after hparamset init.;
refactor the components init.;
rename the _engine to engine by removing the proceeding underscore;
  • Loading branch information
zxteloiv committed Sep 25, 2023
1 parent 04d41c7 commit 75664dd
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 73 deletions.
2 changes: 1 addition & 1 deletion trialbot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

__name__ = "TrialBot"
__version__ = "0.4.0"
__version__ = "0.5.0"

132 changes: 60 additions & 72 deletions trialbot/training/trial_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ class Events(Enum):
EXCEPTION_RAISED = "exception_raised"


class State(object):
class State:
"""An object that is used to pass internal and user-defined state between event handlers."""
def __init__(self, **kwargs):
self.epoch = 0
self.iteration = 0
self.output = None

for k, v in kwargs.items():
setattr(self, k, v)

Expand Down Expand Up @@ -74,36 +78,32 @@ class TrialBot:
"""
def __init__(self,
args: argparse.Namespace | None = None,
trial_name: str = "default_savedir",
get_model_func: Optional[Callable[[Any, NSVocabulary], torch.nn.Module]] = None,
trial_name: str = "trial",
get_model_func: Callable[[Any, NSVocabulary], torch.nn.Module] | None = None,
clean_engine: bool = False,
runtime_hooks: dict[str, Callable] | None = None,
):
if args is None:
parser = TrialBot.get_default_parser()
args = parser.parse_args()
self.runtime_hooks = dict() if runtime_hooks is None else runtime_hooks

self.args = args
self.args = self._init_args(args) # if given, use the given args, otherwise it will parse from sys.argv
self.name = trial_name

# self.translator = None
# self.vocab = None
# self.train_set = None
# self.dev_set = None
# self.test_set = None
self.hparams = None
self.savepath = "."
self.hparams = self._init_hparams()
self.savepath = self._init_savepath()
self.logger = logging.getLogger(__name__)

self.updater = None

self.state = State(epoch=0, iteration=0, output=None)

self.get_model = get_model_func
self.engine = self._init_engine(clean=clean_engine)
self.datasets = self._init_dataset()
self.translator = self._init_translator()
self.vocab = self._init_vocab(self.train_set, self.translator)
if self.translator and self.vocab:
self.translator.index_with_vocab(self.vocab)
self.models = self._init_models(self.hparams, self.vocab)

self._engine = self._make_engine(clean=clean_engine)
self._init_components()

def _make_engine(self, clean: bool = False):
def _init_engine(self, clean: bool = False):
engine = Engine()
engine.register_events(*Events)
# events with greater priorities will get processed earlier.
Expand All @@ -123,6 +123,7 @@ def _make_engine(self, clean: bool = False):
engine.add_event_handler(Events.COMPLETED, exts.ext_write_info, 100, msg="TrailBot completed.")
engine.add_event_handler(Events.ITERATION_COMPLETED, exts.loss_reporter, 100)
engine.add_event_handler(Events.ITERATION_COMPLETED, exts.end_with_nan_loss, 100)
self.engine = engine
return engine

@property
Expand All @@ -141,11 +142,10 @@ def get_default_parser():
help='overwrite the hparamset.TRAINING_LIMIT if specified at CLI')
return parser

def _init_components(self):
"""
Start a trial directly.
"""
args = self.args
def _init_args(self, args=None):
if args is None:
parser = self.get_default_parser()
args = parser.parse_args()

# logging args
if args.quiet:
Expand All @@ -155,38 +155,39 @@ def _init_components(self):
else:
self.logger.setLevel(logging.INFO)

self.args = args
return args

def _init_hparams(self):
args = self.args
hparams = Registry.get_hparamset(args.hparamset)
if args.batch_size > 0:
hparams.batch_sz = args.batch_size
if args.epoch > 0:
hparams.TRAINING_LIMIT = args.epoch
if args.translator:
hparams.TRANSLATOR = args.translator
if 'hparams' in self.runtime_hooks:
hparams = self.runtime_hooks['hparams'](hparams)

self.hparams = hparams
self.savepath = args.snapshot_dir if args.snapshot_dir else self._default_savepath()

self._init_dataset()
self._init_translator()

self.vocab = self._init_vocab(self.train_set, self.translator)

self.translator.index_with_vocab(self.vocab)
self.models = self._init_models(hparams, self.vocab)
return hparams

def _init_dataset(self):
args = self.args
self.datasets = Registry.get_dataset(args.dataset)
return self.datasets

def _default_savepath(self):
def _init_savepath(self):
args = self.args

if args.test and len(args.models) > 0:
# use the first model path as the savepath,
# if this is not expected, please specify manually the
return os.path.abspath(os.path.expanduser(os.path.dirname(args.models[0])))

if args.snapshot_dir:
return args.snapshot_dir

return os.path.join(
self.hparams.SNAPSHOT_PATH,
args.dataset,
Expand Down Expand Up @@ -217,9 +218,7 @@ def _init_translator(self):
self.translator = translator
return translator

def _init_vocab(self,
dataset: Dataset,
translator: Translator):
def _init_vocab(self, dataset: Dataset, translator: Translator):
args, logger = self.args, self.logger
hparams = self.hparams

Expand Down Expand Up @@ -256,7 +255,6 @@ def _init_vocab(self,
vocab.save_to_files(vocab_path)

logger.info(str(vocab))

return vocab

def _init_models(self, hparams, vocab):
Expand Down Expand Up @@ -305,55 +303,45 @@ def run(self, training_epoch: int = 0):
self._training_engine_loop(self.updater, training_epoch)

def _training_engine_loop(self, updater, max_epoch):
engine = self._engine
engine = self.engine
engine.fire_event(Events.STARTED, bot=self)
while self.state.epoch < max_epoch:
self.state.epoch += 1
engine.fire_event(Events.EPOCH_STARTED, bot=self)
updater.start_epoch()
while True:
self.state.iteration += 1
engine.fire_event(Events.ITERATION_STARTED, bot=self)
try:
self.state.output = next(updater)
engine.fire_event(Events.ITERATION_COMPLETED, bot=self)
except StopIteration:
self.state.iteration -= 1
self.state.output = None
break

engine.fire_event(Events.EPOCH_COMPLETED, bot=self)
self._epoch_loop(updater)
engine.fire_event(Events.COMPLETED, bot=self)

def _testing_engine_loop(self, updater):
engine = self._engine
engine = self.engine
with torch.no_grad():
engine.fire_event(Events.STARTED, bot=self)
engine.fire_event(Events.EPOCH_STARTED, bot=self)
updater.start_epoch()
while True:
self.state.iteration += 1
engine.fire_event(Events.ITERATION_STARTED, bot=self)
try:
self.state.output = next(updater)
engine.fire_event(Events.ITERATION_COMPLETED, bot=self)
except StopIteration:
self.state.iteration -= 1
self.state.output = None
break

engine.fire_event(Events.EPOCH_COMPLETED, bot=self)
self._epoch_loop(updater)
engine.fire_event(Events.COMPLETED, bot=self)

def _epoch_loop(self, updater):
engine = self.engine
engine.fire_event(Events.EPOCH_STARTED, bot=self)
updater.start_epoch()
while True:
self.state.iteration += 1
engine.fire_event(Events.ITERATION_STARTED, bot=self)
try:
self.state.output = next(updater)
engine.fire_event(Events.ITERATION_COMPLETED, bot=self)
except StopIteration:
self.state.iteration -= 1
self.state.output = None
break
engine.fire_event(Events.EPOCH_COMPLETED, bot=self)

def attach_extension(self,
event_name: str = Events.ITERATION_COMPLETED,
priority: int = 100,):
"""Used as a decorator only. To add extension directly, use add_event_handler instead."""
def decorator(handler):
self._engine.add_event_handler(event_name, handler, priority)
self.engine.add_event_handler(event_name, handler, priority)
return handler
return decorator

def add_event_handler(self, event_name, handler, priority=100, *args, **kwargs):
self._engine.add_event_handler(event_name, handler, priority, *args, **kwargs)
self.engine.add_event_handler(event_name, handler, priority, *args, **kwargs)

0 comments on commit 75664dd

Please sign in to comment.