diff --git a/examples/research_projects/tools/python_interpreter.py b/examples/research_projects/tools/python_interpreter.py
index 1b9d727c6f..8f319b2d68 100644
--- a/examples/research_projects/tools/python_interpreter.py
+++ b/examples/research_projects/tools/python_interpreter.py
@@ -154,7 +154,7 @@ def solution():
     optimize_cuda_cache=True,
 )
 
-ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
+ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
 test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
 
 # text env
diff --git a/examples/research_projects/tools/triviaqa.py b/examples/research_projects/tools/triviaqa.py
index bdf2c82287..def5013582 100644
--- a/examples/research_projects/tools/triviaqa.py
+++ b/examples/research_projects/tools/triviaqa.py
@@ -105,7 +105,7 @@ class ScriptArguments:
     seed=script_args.seed,
     optimize_cuda_cache=True,
 )
-ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
+ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
 dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
 local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003  # Prime
 dataset = dataset.shuffle(local_seed)
diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py
index e0dd07bb5a..b78286eb34 100644
--- a/examples/scripts/ppo/ppo.py
+++ b/examples/scripts/ppo/ppo.py
@@ -152,7 +152,7 @@ def tokenize(element):
     # Training
     ################
     trainer = PPOTrainer(
-        config=training_args,
+        args=training_args,
         processing_class=tokenizer,
         policy=policy,
         ref_policy=ref_policy,
diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py
index 73ea5cd852..40fca6ff88 100644
--- a/examples/scripts/ppo/ppo_tldr.py
+++ b/examples/scripts/ppo/ppo_tldr.py
@@ -163,7 +163,7 @@ def tokenize(element):
     # Training
     ################
     trainer = PPOTrainer(
-        config=training_args,
+        args=training_args,
         processing_class=tokenizer,
         policy=policy,
         ref_policy=ref_policy,
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 50f17ddece..3c3ab9504b 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -51,20 +51,21 @@
 from ..core import masked_mean, masked_whiten
 from ..models import create_reference_model
 from ..models.utils import unwrap_model_for_generation
-from ..trainer.utils import (
+from .ppo_config import PPOConfig
+from .utils import (
     OnlineTrainerState,
     batch_generation,
     disable_dropout_in_model,
     exact_div,
     first_true_indices,
     forward,
+    generate_model_card,
     get_reward,
+    peft_module_casting_to_bf16,
     prepare_deepspeed,
     print_rich_table,
     truncate_response,
 )
-from .ppo_config import PPOConfig
-from .utils import generate_model_card, peft_module_casting_to_bf16
 
 
 if is_peft_available():
@@ -97,10 +98,11 @@ def forward(self, **kwargs):
 class PPOTrainer(Trainer):
     _tag_names = ["trl", "ppo"]
 
+    @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True)
     @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
     def __init__(
         self,
-        config: PPOConfig,
+        args: PPOConfig,
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ],
@@ -122,8 +124,7 @@ def __init__(
                 "same as `policy`, you must make a copy of it, or `None` if you use peft."
             )
 
-        self.args = config
-        args = config
+        self.args = args
         self.processing_class = processing_class
         self.policy = policy
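For reference, the backward-compatibility shim in this patch is the stacked `deprecate_kwarg` decorator on `PPOTrainer.__init__`, so existing call sites that still pass `config=` keep working with a deprecation warning until v0.15.0. Below is a minimal standalone sketch of that mechanism, assuming the import path TRL itself uses (`transformers.utils.deprecation`); the `Toy` class is hypothetical and only the decorator call mirrors the patch:

from transformers.utils.deprecation import deprecate_kwarg

class Toy:
    # Same decorator call as in the patch: the old keyword `config` is remapped to
    # `args`, a deprecation warning is emitted, and passing both names raises an error.
    @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True)
    def __init__(self, args=None):
        self.args = args

t = Toy(config="ppo-config")  # warns, value is forwarded to `args`
print(t.args)                 # -> "ppo-config"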