diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py index 1c8a9b90..b8b265d4 100644 --- a/wtpsplit/train/evaluate.py +++ b/wtpsplit/train/evaluate.py @@ -103,8 +103,10 @@ def evaluate_sentence( batch_size=batch_size, verbose=True, ) - logits, offsets_mapping = logits[0], offsets_mapping[0] # FIXME - + logits = logits[0] + if offsets_mapping is not None: + offsets_mapping = offsets_mapping[0] + true_end_indices = np.cumsum(np.array([len(s) for s in sentences])) + np.arange(len(sentences)) * len(separator) newline_labels = np.zeros(len(text)) newline_labels[true_end_indices - 1] = 1