fix(ppo_trainer): default gen kwargs (#510)
* fix(ppo_trainer): force `use_cache=True` by default

* fix(ppo_trainer): `batch_size` -> `chunk_size` for evaluation

* fix(base_trainer): force pad_token regardless of architecture
maxreciprocate authored Jun 23, 2023
1 parent fbc9e04 commit 171357b
Showing 2 changed files with 15 additions and 32 deletions.
5 changes: 2 additions & 3 deletions trlx/trainer/accelerate_base_trainer.py
@@ -73,9 +73,8 @@ def __init__(self, config, **kwargs): # noqa: C901
         self.tokenizer.padding_side = config.tokenizer.padding_side
         self.tokenizer.truncation_side = config.tokenizer.truncation_side
         self.tokenizer.sep_token = "<sep>"
-        if config.model.model_arch_type != "seq2seq":
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = "<|padding|>"
 
         script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0]
         if not isinstance(config.model.model_path, str):
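For reference, a minimal sketch (not part of the commit) of what the new fallback does, written against a stand-in object rather than a real transformers tokenizer; only the pad_token attribute and the "<|padding|>" value come from the diff above, everything else is illustrative.

from dataclasses import dataclass
from typing import Optional


@dataclass
class StubTokenizer:
    # Stand-in for a Hugging Face tokenizer; only the attribute touched by the diff.
    pad_token: Optional[str] = None


def ensure_pad_token(tokenizer: StubTokenizer) -> StubTokenizer:
    # Mirrors the new base_trainer behavior: fill in pad_token only when it is missing,
    # regardless of whether the model is causal or seq2seq.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = "<|padding|>"
    return tokenizer


print(ensure_pad_token(StubTokenizer()).pad_token)                   # "<|padding|>"
print(ensure_pad_token(StubTokenizer(pad_token="<pad>")).pad_token)  # "<pad>" (kept as-is)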
42 changes: 13 additions & 29 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -83,34 +83,18 @@ def __init__(self, config: TRLConfig, **kwargs):
         # Create the parameters for the Hugging Face language model's generator
         # method (that generates new tokens from a prompt).
         # https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate
-        if config.model.model_arch_type == "seq2seq":
-            self.generate_kwargs = dict(
-                config.method.gen_kwargs,
-                eos_token_id=self.tokenizer.eos_token_id,
-                pad_token_id=self.tokenizer.pad_token_id,
-            )
-            if config.method.gen_experience_kwargs is not None:
-                self.generate_experience_kwargs = dict(
-                    config.method.gen_experience_kwargs,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                )
-            else:
-                self.generate_experience_kwargs = None
+        generate_kwargs = dict(
+            do_sample=True,
+            use_cache=True,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+        )
+        self.generate_kwargs = {**generate_kwargs, **config.method.gen_kwargs}
+
+        if config.method.gen_experience_kwargs is not None:
+            self.generate_experience_kwargs = {**generate_kwargs, **config.method.gen_experience_kwargs}
         else:
-            self.generate_kwargs = dict(
-                config.method.gen_kwargs,
-                eos_token_id=self.tokenizer.eos_token_id,
-                pad_token_id=self.tokenizer.eos_token_id,
-            )
-            if config.method.gen_experience_kwargs is not None:
-                self.generate_experience_kwargs = dict(
-                    config.method.gen_experience_kwargs,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    pad_token_id=self.tokenizer.eos_token_id,
-                )
-            else:
-                self.generate_experience_kwargs = None
+            self.generate_experience_kwargs = None
 
         # Setup stats tracker
         self.running_moments = RunningMoments()
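To make the merge order explicit: the new code builds a dict of defaults and then unpacks config.method.gen_kwargs over it, so user-supplied keys win while do_sample=True and use_cache=True are filled in whenever the config omits them. A minimal sketch with placeholder values (not the trainer's actual config):

# Illustration (not from the repo) of the dict-merge precedence used above:
# keys from the user's gen_kwargs override the defaults because they are unpacked last.

defaults = dict(
    do_sample=True,
    use_cache=True,
    eos_token_id=0,  # placeholder ids; the trainer uses self.tokenizer.* here
    pad_token_id=0,
)

# e.g. a config that only tunes sampling, leaving use_cache to its new default of True
user_gen_kwargs = {"max_new_tokens": 64, "top_k": 0, "top_p": 1.0, "do_sample": True}

generate_kwargs = {**defaults, **user_gen_kwargs}
print(generate_kwargs["use_cache"])        # True (filled in by the default)
print(generate_kwargs["max_new_tokens"])   # 64   (comes from the user config)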
@@ -236,12 +220,12 @@ def post_backward_callback(self):
         self.kl_ctl.update(self.mean_kl, n_steps=self.config.train.batch_size)
 
     def prepare_learning(self):
-        eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
+        eval_dataloader = self.eval_pipeline.create_loader(self.config.method.chunk_size)
         self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader)
 
         self.make_experience(self.config.method.num_rollouts)
 
-        self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True)
+        self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False)
 
         self.n_updates_per_batch = self.config.method.ppo_epochs
         self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader)
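As a rough illustration of what changes at runtime: evaluation batches now follow config.method.chunk_size while training batches keep config.train.batch_size. The sizes and plain lists below are made up and only stand in for trlx's pipelines and rollout store.

# Toy sketch of the dataloader sizing after this change (made-up numbers).
from torch.utils.data import DataLoader

train_batch_size = 32  # stand-in for config.train.batch_size
chunk_size = 8         # stand-in for config.method.chunk_size

eval_prompts = list(range(100))   # stand-in for the eval pipeline's prompts
rollout_store = list(range(256))  # stand-in for the PPO rollout store

eval_dataloader = DataLoader(eval_prompts, batch_size=chunk_size)
train_dataloader = DataLoader(rollout_store, batch_size=train_batch_size, shuffle=False)

print(len(eval_dataloader))   # 13 batches of at most 8 prompts
print(len(train_dataloader))  # 8 batches of 32 rollouts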
