Skip to content

Commit

Permalink
[TRAX] Fixing a few OSS test errors (There may be further errors, but …
Browse files Browse the repository at this point in the history
…this is good to go):

 - ConfigTest needs to iterate on .gin and not .gin + .yaml
 - `if dm_suite` then set eval_env and env to None, otherwise ConfigTest complains.
 - Explicitly pass in extra_ids to t5's SPC vocab - otherwise it computes with None.

PiperOrigin-RevId: 348188541
  • Loading branch information
afrozenator authored and copybara-github committed Dec 18, 2020
1 parent 864a5a5 commit 7397898
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
11 changes: 8 additions & 3 deletions trax/data/tf_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def _get_vocab(vocab_type='subword', vocab_file=None, vocab_dir=None):
return text_encoder.BertEncoder(path, do_lower_case=True)

assert vocab_type == 'sentencepiece'
return t5.data.SentencePieceVocabulary(sentencepiece_model_file=path)
return t5.data.SentencePieceVocabulary(sentencepiece_model_file=path,
extra_ids=0)


# Makes the function accessible in gin configs, even with all args denylisted.
Expand Down Expand Up @@ -788,8 +789,10 @@ def c4_bare_preprocess_fn(dataset,
})

# Vocabulary for tokenization.
extra_ids = 0
vocab = t5.data.SentencePieceVocabulary(
sentencepiece_model_file=spm_path or t5.data.DEFAULT_SPM_PATH)
sentencepiece_model_file=spm_path or t5.data.DEFAULT_SPM_PATH,
extra_ids=extra_ids)
feature = t5.data.Feature(vocab)
output_features = {'targets': feature, 'inputs': feature}

Expand Down Expand Up @@ -995,8 +998,10 @@ def print_examples(x):
dataset = dataset.map(print_examples)

# Vocabulary for tokenization.
extra_ids = 0
vocab = t5.data.SentencePieceVocabulary(
sentencepiece_model_file=spm_path or t5.data.DEFAULT_SPM_PATH)
sentencepiece_model_file=spm_path or t5.data.DEFAULT_SPM_PATH,
extra_ids=extra_ids)
feature = t5.data.Feature(vocab)
output_features = {'targets': feature, 'inputs': feature}

Expand Down
43 changes: 31 additions & 12 deletions trax/rl/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def list_configs():
return [
# (config name without extension, config path)
(os.path.splitext(config)[0], os.path.join(config_dir, config))
for config in os.listdir(config_dir)
for config in os.listdir(config_dir) if config.endswith('.gin')
]


Expand All @@ -43,18 +43,37 @@ def test_dry_run(self, config):
"""Dry-runs all gin configs."""
gin.clear_config(clear_constants=True)
gin.parse_config_file(config)
def run_config():
try:
rl_trainer.train_rl(
output_dir=self.create_tempdir().full_path,
# Don't run any actual training, just initialize all classes.
n_epochs=0,
train_batch_size=1,
eval_batch_size=1,
)
except Exception as e:
raise AssertionError(
'Error in gin config {}.'.format(os.path.basename(config))
) from e

# Some tests, ex: DM suite can't be run in OSS - so skip them.
should_skip = False
try:
should_skip = should_skip or gin.query_parameter('RLTask.dm_suite')
except ValueError as e:
pass
try:
rl_trainer.train_rl(
output_dir=self.create_tempdir().full_path,
# Don't run any actual training, just initialize all classes.
n_epochs=0,
train_batch_size=1,
eval_batch_size=1,
)
except Exception as e:
raise AssertionError(
'Error in gin config {}.'.format(os.path.basename(config))
) from e
env_name = gin.query_parameter('RLTask.env')
should_skip = (should_skip or env_name.startswith('DM-') or
env_name.startswith('LunarLander'))
except ValueError as e:
pass

if should_skip:
pass
else:
run_config()


if __name__ == '__main__':
Expand Down
3 changes: 2 additions & 1 deletion trax/rl/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ def __init__(self, env=gin.REQUIRED,
if isinstance(env, str):
self._env_name = env
if dm_suite:
pass
eval_env = None
env = None
else:
env = gym.make(self._env_name)
eval_env = gym.make(self._env_name)
Expand Down

0 comments on commit 7397898

Please sign in to comment.