diff --git a/metrics/rouge/rouge.py b/metrics/rouge/rouge.py
index 353301cca..4d4ba3a06 100644
--- a/metrics/rouge/rouge.py
+++ b/metrics/rouge/rouge.py
@@ -119,7 +119,7 @@ def _info(self):
         )
 
     def _compute(
-        self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False, tokenizer=None
+        self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False, tokenizer=None, detailed=False
     ):
         if rouge_types is None:
             rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
@@ -148,11 +148,17 @@ def _compute(
         if use_aggregator:
             result = aggregator.aggregate()
             for key in result:
-                result[key] = result[key].mid.fmeasure
+                if not detailed:
+                    result[key] = result[key].mid.fmeasure
+                else:
+                    result[key] = result[key].mid  # whole mid score, not only its fmeasure
 
         else:
             result = {}
             for key in scores[0]:
-                result[key] = list(score[key].fmeasure for score in scores)
+                if not detailed:
+                    result[key] = list(score[key].fmeasure for score in scores)
+                else:
+                    result[key] = [score[key] for score in scores]  # one score per sample
 
         return result