From c3143832cb305139b2551af2e00f008b4d64a981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 17 Sep 2024 12:09:16 +0200 Subject: [PATCH] `processor(prompt, images=image)` to `processor(images=image, text=prompt)` (#2076) * `prompt, images=image` to `images=image, text=prompt` * special case of model being str in BCO --- trl/trainer/bco_trainer.py | 2 +- trl/trainer/dpo_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 44a4ca03a2..c4ed97d3de 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -336,7 +336,7 @@ def __init__( if type(args) is TrainingArguments: raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.") - if ref_model is model: + if not isinstance(model, str) and ref_model is model: raise ValueError( "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " "same as `model`, you must mass a copy of it, or `None` if you use peft." diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 2189580c2e..3c9ff4624b 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -123,7 +123,7 @@ def _process_prompt( ) prompt_tokens = [] for prompt, image in zip(prompts, images): - tokens = processor(prompt, images=image, **processor_kwargs) + tokens = processor(images=image, text=prompt, **processor_kwargs) tokens = {k: v[0] for k, v in tokens.items()} if not isinstance(tokens["input_ids"], list): tokens["input_ids"] = tokens["input_ids"].tolist() @@ -302,7 +302,7 @@ def tokenize(text, images=None): if "add_special_tokens" in inspect.signature(processor).parameters else {} ) - tokenized = processor(text, images=images, **processor_kwargs) + tokenized = processor(images=images, text=text, **processor_kwargs) tokenized = {k: v[0] for k, v in tokenized.items()} if not isinstance(tokenized["input_ids"], list): tokenized["input_ids"] = tokenized["input_ids"].tolist()