Skip to content

Commit

Permalink
support DatasetFormats.C: only positive pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Feb 6, 2024
1 parent f270cf4 commit 2b0af8d
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,18 @@ class DatasetFormats:
"""
B = 'text,positive,negative'

"""
format C: text,positive
input format: [
text[0],
positive[0],
text[1],
positive[1],
...
]
"""
C = 'text,positive'

@classmethod
def list_formats(cls):
for key, val in DatasetFormats.__dict__.items():
Expand Down Expand Up @@ -414,15 +426,21 @@ def __call__(self, data: Dict) -> Dict:
elif 'text' in data and 'positive' in data and 'negative' in data:
self.dataset_format = DatasetFormats.B
logger.info(f'Detect DatasetFormats.B: {DatasetFormats.B}')
elif 'text' in data and 'positive' in data and 'negative' not in data and 'label' not in data:
self.dataset_format = DatasetFormats.C
logger.info(f'Detect DatasetFormats.C: {DatasetFormats.C}')
else:
raise NotImplementedError('Currently only support two dataset formats'
'DatasetFormats A: must include three columns: `text1`, `text2`, and `label`.'
'DatasetFormats B: mut include three columns: `text`, `positive`, `negative`')
'DatasetFormats B: mut include three columns: `text`, `positive`, `negative`'
'DatasetFormats C: mut include three columns: `text`, `positive`')
text_columns = None
if self.dataset_format == DatasetFormats.A:
text_columns = ['text1', 'text2']
elif self.dataset_format == DatasetFormats.B:
text_columns = ['text', 'positive', 'negative']
elif self.dataset_format == DatasetFormats.C:
text_columns = ['text', 'positive']

extra_length = 0
extra_placeholder = {}
Expand Down Expand Up @@ -720,6 +738,10 @@ def __call__(self,
loss += self.w2 * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau)
if self.w3 > 0:
loss += self.w3 * angle_loss(combined_labels, combined_inputs, self.angle_tau)
elif self.dataset_format == DatasetFormats.C:
text = outputs[::2]
positive = outputs[1::2]
loss = contrastive_with_negative_loss(text, positive, negative=None, tau=self.ibn_tau)
else:
raise NotImplementedError
return loss
Expand Down

0 comments on commit 2b0af8d

Please sign in to comment.