diff --git a/mimir/attacks/reference.py b/mimir/attacks/reference.py index 077d05e..6ea87ec 100644 --- a/mimir/attacks/reference.py +++ b/mimir/attacks/reference.py @@ -19,4 +19,4 @@ def _attack(self, document, probs, tokens=None, **kwargs): if loss is None: loss = self.model.get_ll(document, probs=probs, tokens=tokens) ref_loss = self.ref_model.get_ll(document, probs=probs, tokens=tokens) - return ref_loss - loss + return loss - ref_loss