From aef24e319f1dae8686f5fb4562fba3ca7f4b0622 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 8 Jul 2020 23:07:53 +0530 Subject: [PATCH] =?UTF-8?q?and=20done=20=F0=9F=9A=80=F0=9F=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pipelines.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pipelines.py b/pipelines.py index 5feafe3..7750e46 100644 --- a/pipelines.py +++ b/pipelines.py @@ -15,6 +15,7 @@ logger = logging.getLogger(__name__) class QGPipeline: + """Poor man's QG pipeline""" def __init__( self, model: PreTrainedModel, @@ -234,13 +235,13 @@ def __call__(self, context: str, **generate_kwargs): input_length = inputs["input_ids"].shape[-1] - max_length = generate_kwargs.get("max_length", 256) - if input_length < max_length: - logger.warning( - "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format( - max_length, input_length - ) - ) + # max_length = generate_kwargs.get("max_length", 256) + # if input_length < max_length: + # logger.warning( + # "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format( + # max_length, input_length + # ) + # ) outs = self.model.generate( input_ids=inputs['input_ids'].cuda(),