class DropIfRegexMatch(BaseFilter):
    """Drop an entry when its text contains any of the configured patterns.

    Each pattern is passed through ``re.escape`` before searching, so the
    patterns are matched as literal substrings: characters such as ``(``,
    ``.`` or ``*`` carry no special regex meaning here.

    Args:
        regex_patterns: Patterns to look for in the entry's text. The entry
            is dropped as soon as one of them is found.
        text_key: Key of ``data_entry`` whose value is searched.
        **kwargs: Forwarded unchanged to ``BaseFilter.__init__``.
    """

    def __init__(
        self,
        regex_patterns: List[str],
        text_key: str = "text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.regex_patterns = regex_patterns
        self.text_key = text_key

    def process_dataset_entry(self, data_entry) -> List:
        """Return the entry unchanged, or a ``data=None`` entry on a match.

        Args:
            data_entry: Mapping that holds the text under ``self.text_key``
                (presumably a dict built from one manifest line -- confirm
                against ``BaseFilter``).

        Returns:
            A one-element list with a ``DataEntry``. On a match ``data`` is
            ``None`` and the ``num_removed`` metric is 1; otherwise ``data``
            is the original entry and ``num_removed`` is 0.
        """
        for regex_pattern in self.regex_patterns:
            # Escaping makes the search a literal-substring check.
            if re.search(re.escape(regex_pattern), data_entry[self.text_key]):
                return [DataEntry(data=None, metrics=dict(num_removed=1))]
        # Use the same metric key as the drop path so the counts aggregate
        # under a single name (the key was previously misspelled here).
        return [DataEntry(data=data_entry, metrics=dict(num_removed=0))]
nemo_skills.training.data_preparation_utils.filters.DropIfRegexMatch should_run: ${filters.remove_verification_code} text_key: ${output_key} regex_patterns: