Skip to content

Commit

Permalink
[TRAX] v1.3.3 and Store checkpoint with unreplicated weights/state in…
Browse files Browse the repository at this point in the history
… Loop.

PiperOrigin-RevId: 323172102
  • Loading branch information
afrozenator authored and copybara-github committed Jul 25, 2020
1 parent 023dbce commit 88b033c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 37 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

setup(
name='trax',
version='1.3.2',
version='1.3.3',
description='Trax',
long_description=(
'Trax helps you understand deep learning. We start with basic maths and'
Expand Down
83 changes: 47 additions & 36 deletions trax/supervised/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,36 +159,7 @@ def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
# unnecessary, i.e. random_seed was set.
if random_seed is None and self._n_hosts > 1:
logging.info('Syncing weights/state across %d hosts.', self._n_hosts)

if logging.vlog_is_on(1):
logging.info(
'Input training weights shape: %s',
fastmath.nested_map(lambda x: x.shape,
self._model_in_training.weights))
logging.info('Input training weights: %s',
self._model_in_training.weights)
logging.info('Input training state: %s', self._model_in_training.state)
logging.info('Input eval weights: %s', self._eval_model.weights)
logging.info('Input eval state: %s', self._eval_model.state)

(self._model_in_training.weights, self._model_in_training.state,
self._eval_model.weights, self._eval_model.state) = self._unreplicate(
_make_weights_and_state_same_across_hosts(
self._for_n_devices(
(self._model_in_training.weights,
self._model_in_training.state, self._eval_model.weights,
self._eval_model.state))))

if logging.vlog_is_on(1):
logging.info(
'Output training weights shape: %s',
fastmath.nested_map(lambda x: x.shape,
self._model_in_training.weights))
logging.info('Output training weights: %s',
self._model_in_training.weights)
logging.info('Output training state: %s', self._model_in_training.state)
logging.info('Output eval weights: %s', self._eval_model.weights)
logging.info('Output eval state: %s', self._eval_model.state)
self._sync_weights_and_state_across_hosts()

self._task.optimizer.tree_init(self._model_in_training.weights)

Expand Down Expand Up @@ -236,6 +207,39 @@ def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
if self._output_dir is None:
_log('Will not write evaluation metrics, because output_dir is None.')

def _sync_weights_and_state_across_hosts(self):
  """Synchronizes model weights and state across the hosts of the computation.

  The weights and state of both the training model and the eval model are
  bundled into one tuple, passed through `self._for_n_devices`, then
  `_make_weights_and_state_same_across_hosts`, then `self._unreplicate`, and
  the results are written back onto the two models.

  When `logging.vlog_is_on(1)` is true, the values are dumped with
  `logging.debug` both before and after the synchronization.
  """
  train_model = self._model_in_training
  eval_model = self._eval_model

  # Snapshot of what goes into the sync, only when verbose logging is on.
  if logging.vlog_is_on(1):
    input_shapes = fastmath.nested_map(lambda w: w.shape, train_model.weights)
    logging.debug('Input training weights shape: %s', input_shapes)
    logging.debug('Input training weights: %s', train_model.weights)
    logging.debug('Input training state: %s', train_model.state)
    logging.debug('Input eval weights: %s', eval_model.weights)
    logging.debug('Input eval state: %s', eval_model.state)

  # Bundle everything so all four trees go through the sync in a single call.
  bundle = (train_model.weights, train_model.state,
            eval_model.weights, eval_model.state)
  replicated = self._for_n_devices(bundle)
  synced = _make_weights_and_state_same_across_hosts(replicated)
  (train_weights, train_state,
   eval_weights, eval_state) = self._unreplicate(synced)

  # Write back in the same order as the bundle was built.
  train_model.weights = train_weights
  train_model.state = train_state
  eval_model.weights = eval_weights
  eval_model.state = eval_state

  # Snapshot of what came out of the sync, only when verbose logging is on.
  if logging.vlog_is_on(1):
    output_shapes = fastmath.nested_map(lambda w: w.shape, train_model.weights)
    logging.debug('Output training weights shape: %s', output_shapes)
    logging.debug('Output training weights: %s', train_model.weights)
    logging.debug('Output training state: %s', train_model.state)
    logging.debug('Output eval weights: %s', eval_model.weights)
    logging.debug('Output eval state: %s', eval_model.state)

def run(self, n_steps=1):
"""Runs this training loop for n steps.
Expand Down Expand Up @@ -280,12 +284,20 @@ def run(self, n_steps=1):
step_acc += 1
for metric_name, value in optimizer_metrics.items():
optimizer_metrics_acc[metric_name] += value
if self._checkpoint_at(self.step):
self.save_checkpoint(weights, state, slots)
if self._eval_at(self.step):

should_checkpoint = self._checkpoint_at(self.step)
should_eval = self._eval_at(self.step)
unr_weights, unr_state, unr_slots = None, None, None
if should_checkpoint or should_eval:
unr_weights, unr_state, unr_slots = self._unreplicate(
(weights, state, slots))

if should_checkpoint:
self.save_checkpoint(unr_weights, unr_state, unr_slots)
if should_eval:
elapsed_time = time.time() - start_time
self._model_in_training.weights = weights
self._model_in_training.state = state
self._model_in_training.weights = unr_weights
self._model_in_training.state = unr_state
self._eval_model.weights = self._model.weights
self._log_training_progress(
total_loss=loss_acc, n_steps=step_acc, elapsed_time=elapsed_time,
Expand Down Expand Up @@ -387,7 +399,6 @@ def _run_one_step(self, weights, state, slots, opt_params):
if logging.vlog_is_on(1) and ((step & step - 1) == 0):
# Prints every power of two, if debugging is enabled.
logging.info('step[%d]', step)
# logging.info('batch[%s]', batch)
logging.info('opt_params[%s]', opt_params)
logging.info('weights[%s]', weights)

Expand Down

0 comments on commit 88b033c

Please sign in to comment.