Allow passing a custom batch sampler to the trainer #3162
base: master
@@ -77,21 +97,17 @@ def __init__(
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] = None,
        generator: torch.Generator = None,
        seed: int = 0,
I believe there was an issue with the typing and usage of the seed parameter: because 0 is falsy, a seed of 0 would be ignored when passed.
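To illustrate the falsy-seed problem described in this comment, here is a minimal sketch. It is not the code from sampler.py: the helper names `make_rng_buggy` and `make_rng_fixed` are hypothetical, and the standard-library `random.Random` stands in for `torch.Generator` so the example runs without extra dependencies.

```python
import random
from typing import Optional


def make_rng_buggy(seed: Optional[int] = None) -> random.Random:
    """Hypothetical helper: uses a truthiness check on the seed."""
    rng = random.Random()
    if seed:  # bug: 0 is falsy, so seed=0 is silently skipped
        rng.seed(seed)
    return rng


def make_rng_fixed(seed: Optional[int] = None) -> random.Random:
    """Hypothetical helper: only skips seeding when no seed was given."""
    rng = random.Random()
    if seed is not None:  # 0 is a valid seed and is applied
        rng.seed(seed)
    return rng


# With the fixed check, seed=0 is reproducible across instances.
first = make_rng_fixed(0).random()
second = make_rng_fixed(0).random()
print(first == second)
```

The same distinction applies to any optional numeric argument where 0 is a meaningful value: compare against `None` rather than relying on truthiness.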
if self._batch_sampler:
    return self._batch_sampler
I like the changes in sampler.py, but I'm not sure this is the best option.
This approach requires the user to initialize the batch sampler themselves, which prevents Sentence Transformers from updating the dataset(s). That breaks the prompts feature here:
sentence-transformers/sentence_transformers/trainer.py
Lines 284 to 291 in c68bf68
if self.train_dataset is not None:
    self.train_dataset = self.maybe_add_prompts_or_dataset_name_column(
        train_dataset, args.prompts, dataset_name="train"
    )
if self.eval_dataset is not None:
    self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column(
        eval_dataset, args.prompts, dataset_name="eval"
    )
Additionally, it prevents multi-dataset training setups, because there's only one batch sampler possible.
So, I think a good solution is for the batch_sampler argument to be either 1) a class (not an instance) that subclasses DefaultBatchSampler, or 2) a function that returns an instance of a DefaultBatchSampler subclass, given dataset, batch_size, drop_last, valid_label_columns, generator, seed, and *args & **kwargs.
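A rough sketch of how this proposal could work, under stated assumptions: the `DefaultBatchSampler` below is a simplified stand-in written for this example (not the library class), and `ReversedBatchSampler`, `sampler_factory`, and `build_batch_sampler` are hypothetical names. The point is that a class and a factory function are both callables, so the trainer can construct one sampler per dataset after it has finished modifying the datasets.

```python
from typing import Any, Callable, Iterator, List, Optional, Sequence


class DefaultBatchSampler:
    """Simplified stand-in for the library's DefaultBatchSampler (assumption)."""

    def __init__(
        self,
        dataset: Sequence[Any],
        batch_size: int,
        drop_last: bool,
        valid_label_columns: Optional[List[str]] = None,
        generator: Any = None,
        seed: int = 0,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.valid_label_columns = valid_label_columns
        self.generator = generator
        self.seed = seed

    def __iter__(self) -> Iterator[List[int]]:
        # Yield consecutive index batches, optionally dropping a short last batch.
        for start in range(0, len(self.dataset), self.batch_size):
            batch = list(range(start, min(start + self.batch_size, len(self.dataset))))
            if len(batch) < self.batch_size and self.drop_last:
                continue
            yield batch


class ReversedBatchSampler(DefaultBatchSampler):
    """Hypothetical custom sampler: same batches, reversed order."""

    def __iter__(self) -> Iterator[List[int]]:
        yield from reversed(list(super().__iter__()))


def sampler_factory(dataset: Sequence[Any], **kwargs: Any) -> DefaultBatchSampler:
    """Hypothetical factory function (option 2 in the comment above)."""
    return ReversedBatchSampler(dataset, **kwargs)


def build_batch_sampler(
    batch_sampler: Callable[..., DefaultBatchSampler],
    dataset: Sequence[Any],
    **kwargs: Any,
) -> DefaultBatchSampler:
    # A class and a factory function are both callables, so one call covers both.
    sampler = batch_sampler(dataset=dataset, **kwargs)
    if not isinstance(sampler, DefaultBatchSampler):
        raise TypeError("batch_sampler must produce a DefaultBatchSampler instance")
    return sampler


# One sampler per dataset, built only after the datasets are final.
datasets = {"a": list(range(5)), "b": list(range(4))}
samplers = {
    name: build_batch_sampler(ReversedBatchSampler, ds, batch_size=2, drop_last=True)
    for name, ds in datasets.items()
}
```

Because construction is deferred to the trainer, the sampler always sees the dataset after any prompt or dataset-name columns have been added, and multi-dataset setups get a separate sampler for each dataset.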
What do you think?
resolves #3152
Also, I believe there was an issue with the typing and usage of the seed parameter: because 0 is falsy, a seed of 0 would be ignored when passed.
Missing: