-
Notifications
You must be signed in to change notification settings - Fork 0
/
LED.py
38 lines (32 loc) · 1.05 KB
/
LED.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from unml.models.model import Model
from unml.utils.consts.summarize import SummarizationConsts
class LED(Model):
    """
    Summarization wrapper for the LED model (Beltagy et al., 2020).

    Paper: https://arxiv.org/pdf/2004.05150

    The checkpoint named by `MODEL_NAME` is used unless the caller
    supplies a different one.
    """

    # Checkpoint identifier used when no `modelName` is passed to `__init__`.
    MODEL_NAME = "pszemraj/led-base-book-summary"

    def __init__(self, modelName: str = MODEL_NAME) -> None:
        """
        Initialise the base `Model` for the "summarization" task.

        Args:
            modelName: Identifier of the checkpoint to use; defaults to
                `MODEL_NAME`.
        """
        super().__init__(task="summarization", modelName=modelName)

    def summarize(
        self,
        text: str,
        minLength: int = SummarizationConsts.SUMMARY_MIN_LENGTH,
        maxLength: int = SummarizationConsts.SUMMARY_MAX_TOKEN_LENGTH,
        doSample: bool = False,
    ) -> str:
        """
        Summarize `text` (see doc for `Summarizer` class).

        Args:
            text: The text to summarize.
            minLength: Forwarded to the model as `min_length`.
            maxLength: Forwarded to the model as `max_length`.
            doSample: Forwarded to the model as `do_sample`.

        Returns:
            The `summary_text` entry of the first result, as a `str`.
        """
        # NOTE(review): `self.model` is set up by the `Model` base class and
        # is presumably a summarization pipeline callable -- confirm there.
        generationArgs = {
            # Caller-controlled settings.
            "min_length": minLength,
            "max_length": maxLength,
            "do_sample": doSample,
            # Decoding settings fixed for this model; not exposed to callers.
            "no_repeat_ngram_size": 3,
            "encoder_no_repeat_ngram_size": 3,
            "repetition_penalty": 3.5,
            "num_beams": 4,
            "early_stopping": True,
        }
        results = self.model(text, **generationArgs)

        # Only the first result is used.
        firstResult = results[0]
        return str(firstResult["summary_text"])