Allow passing a custom batch sampler to the trainer #3162

Open · wants to merge 1 commit into master
Conversation

@alonme commented on Jan 11, 2025

resolves #3152

Also, I believe there was an issue with the typing and usage of the seed parameter: since 0 is falsy, an explicit seed of 0 would be ignored.
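
For illustration, the kind of pattern that causes this: a truthiness check treats seed=0 the same as "no seed". A minimal sketch of the pitfall and its fix (make_generator_buggy/make_generator_fixed are hypothetical helpers, not the exact code in sampler.py):

from __future__ import annotations

import torch


def make_generator_buggy(seed: int = 0) -> torch.Generator:
    generator = torch.Generator()
    # Buggy: `if seed` is False when seed == 0, so an explicit seed of 0
    # is silently ignored and the generator stays unseeded.
    if seed:
        generator.manual_seed(seed)
    return generator


def make_generator_fixed(seed: int | None = None) -> torch.Generator:
    generator = torch.Generator()
    # Fixed: compare against None, so 0 is accepted as a real seed.
    if seed is not None:
        generator.manual_seed(seed)
    return generator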

Missing:

  1. Documentation - I wanted to make sure this approach makes sense before writing docs.

@@ -77,21 +97,17 @@ def __init__(
    dataset: Dataset,
    batch_size: int,
    drop_last: bool,
    valid_label_columns: list[str] = None,
    generator: torch.Generator = None,
    seed: int = 0,
@alonme (Author):

I believe there was an issue with the typing and usage of the seed parameter: since 0 is falsy, an explicit seed of 0 would be ignored.
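
As an illustration of the typing fix implied here, the signature could use None defaults so that 0 remains a usable seed. A sketch under that assumption, not necessarily the PR's final diff:

from __future__ import annotations

import torch
from torch.utils.data import Dataset


class DefaultBatchSampler:
    # Signature sketch only; the full implementation is elided.
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
        seed: int | None = None,  # None means "unseeded", so seed=0 stays meaningful
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.valid_label_columns = valid_label_columns or []
        self.generator = generator
        self.seed = seed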

Comment on lines +587 to +588
if self._batch_sampler:
    return self._batch_sampler
Collaborator:

I like the changes in sampler.py, but I'm not sure this is the best option. It requires the user to initialize the batch sampler themselves, which prevents Sentence Transformers from updating the dataset(s). That breaks the prompts feature here:

if self.train_dataset is not None:
    self.train_dataset = self.maybe_add_prompts_or_dataset_name_column(
        train_dataset, args.prompts, dataset_name="train"
    )
if self.eval_dataset is not None:
    self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column(
        eval_dataset, args.prompts, dataset_name="eval"
    )

Additionally, it prevents multi-dataset training setups, because there's only one batch sampler possible.

So, I think a good solution is for the batch_sampler argument to be either 1) a class (not an instance) that subclasses DefaultBatchSampler, or 2) a function that, given dataset, batch_size, drop_last, valid_label_columns, generator, seed, and *args & **kwargs, returns an instance of a DefaultBatchSampler subclass.

What do you think?
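
To make option 2 concrete, here is a minimal sketch of a factory-style batch_sampler argument. The DefaultBatchSampler stand-in, the factory name, and the trainer wiring described in the comments are assumptions drawn from this comment, not the library's actual API:

from __future__ import annotations

from typing import Any

import torch
from torch.utils.data import Dataset


class DefaultBatchSampler:
    # Stand-in for sentence_transformers' DefaultBatchSampler; body elided.
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
        seed: int | None = None,
    ) -> None:
        ...


def my_batch_sampler_factory(
    dataset: Dataset,
    batch_size: int,
    drop_last: bool,
    valid_label_columns: list[str] | None = None,
    generator: torch.Generator | None = None,
    seed: int | None = None,
    *args: Any,
    **kwargs: Any,
) -> DefaultBatchSampler:
    # The trainer would call this once per dataset, after prompts or
    # dataset-name columns have been added, so multi-dataset setups and
    # the prompts feature keep working.
    return DefaultBatchSampler(
        dataset, batch_size, drop_last, valid_label_columns, generator, seed
    )

Under this design, the trainer would invoke the class or callable with the per-dataset arguments itself, instead of receiving a single pre-built sampler instance.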

Successfully merging this pull request may close issue #3152: Improve API to use a custom batch_sampler in a SentenceTransformerTrainer.