From 2b0af8d5102a0659ef2521f176acbc53e1b9c574 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Tue, 6 Feb 2024 13:13:17 +0800 Subject: [PATCH] support DatasetFormats.C: only positive pairs --- angle_emb/angle.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 6da8ed8..d05924a 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -336,6 +336,18 @@ class DatasetFormats: """ B = 'text,positive,negative' + """ + format C: text,positive + input format: [ + text[0], + positive[0], + text[1], + positive[1], + ... + ] + """ + C = 'text,positive' + @classmethod def list_formats(cls): for key, val in DatasetFormats.__dict__.items(): @@ -414,15 +426,21 @@ def __call__(self, data: Dict) -> Dict: elif 'text' in data and 'positive' in data and 'negative' in data: self.dataset_format = DatasetFormats.B logger.info(f'Detect DatasetFormats.B: {DatasetFormats.B}') + elif 'text' in data and 'positive' in data and 'negative' not in data and 'label' not in data: + self.dataset_format = DatasetFormats.C + logger.info(f'Detect DatasetFormats.C: {DatasetFormats.C}') else: raise NotImplementedError('Currently only support two dataset formats' 'DatasetFormats A: must include three columns: `text1`, `text2`, and `label`.' - 'DatasetFormats B: mut include three columns: `text`, `positive`, `negative`') + 'DatasetFormats B: mut include three columns: `text`, `positive`, `negative`' + 'DatasetFormats C: mut include three columns: `text`, `positive`') text_columns = None if self.dataset_format == DatasetFormats.A: text_columns = ['text1', 'text2'] elif self.dataset_format == DatasetFormats.B: text_columns = ['text', 'positive', 'negative'] + elif self.dataset_format == DatasetFormats.C: + text_columns = ['text', 'positive'] extra_length = 0 extra_placeholder = {} @@ -720,6 +738,10 @@ def __call__(self, loss += self.w2 * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau) if self.w3 > 0: loss += self.w3 * angle_loss(combined_labels, combined_inputs, self.angle_tau) + elif self.dataset_format == DatasetFormats.C: + text = outputs[::2] + positive = outputs[1::2] + loss = contrastive_with_negative_loss(text, positive, negative=None, tau=self.ibn_tau) else: raise NotImplementedError return loss