Skip to content

Commit

Permalink
[fix] Use HfArgumentParser-compatible typing for prompts (#3178)
Browse files Browse the repository at this point in the history
* Use HfArgumentParser-compatible typing for prompts

* Add a simple test case

* Use typing.Optional and Union for Python 3.9
  • Loading branch information
tomaarsen authored Jan 17, 2025
1 parent b2a5061 commit c68bf68
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
14 changes: 11 additions & 3 deletions sentence_transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from dataclasses import dataclass, field
from typing import Optional, Union

from transformers import TrainingArguments as TransformersTrainingArguments
from transformers.training_args import ParallelMode
Expand Down Expand Up @@ -170,11 +171,18 @@ class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``.
"""

prompts: dict[str, dict[str, str]] | dict[str, str] | str | None = None
batch_sampler: BatchSamplers | str = field(
prompts: Optional[str] = field( # noqa: UP007
default=None,
metadata={
"help": "The prompts to use for each column in the datasets. "
"Either 1) a single string prompt, 2) a mapping of column names to prompts, 3) a mapping of dataset names "
"to prompts, or 4) a mapping of dataset names to a mapping of column names to prompts."
},
)
batch_sampler: Union[BatchSamplers, str] = field( # noqa: UP007
default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."}
)
multi_dataset_batch_sampler: MultiDatasetBatchSamplers | str = field(
multi_dataset_batch_sampler: Union[MultiDatasetBatchSamplers, str] = field( # noqa: UP007
default=MultiDatasetBatchSamplers.PROPORTIONAL, metadata={"help": "The multi-dataset batch sampler to use."}
)

Expand Down
27 changes: 27 additions & 0 deletions tests/test_training_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

import json

from transformers import HfArgumentParser

from sentence_transformers import SentenceTransformerTrainingArguments


def test_hf_argument_parser():
# See https://github.com/UKPLab/sentence-transformers/issues/3090;
# Ensure that the HfArgumentParser can be used to parse SentenceTransformerTrainingArguments.
parser = HfArgumentParser(SentenceTransformerTrainingArguments)
args = parser.parse_args(
args=[
"--output_dir",
"test_output_dir",
"--prompts",
'{"query_column": "query_prompt", "positive_column": "positive_prompt", "negative_column": "negative_prompt"}',
]
)
assert args.output_dir == "test_output_dir"
assert json.loads(args.prompts) == {
"query_column": "query_prompt",
"positive_column": "positive_prompt",
"negative_column": "negative_prompt",
}

0 comments on commit c68bf68

Please sign in to comment.