diff --git a/pipelines.py b/pipelines.py index 5feafe3..7750e46 100644 --- a/pipelines.py +++ b/pipelines.py @@ -15,6 +15,7 @@ logger = logging.getLogger(__name__) class QGPipeline: + """Poor man's QG pipeline""" def __init__( self, model: PreTrainedModel, @@ -234,13 +235,13 @@ def __call__(self, context: str, **generate_kwargs): input_length = inputs["input_ids"].shape[-1] - max_length = generate_kwargs.get("max_length", 256) - if input_length < max_length: - logger.warning( - "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format( - max_length, input_length - ) - ) + # max_length = generate_kwargs.get("max_length", 256) + # if input_length < max_length: + # logger.warning( + # "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format( + # max_length, input_length + # ) + # ) outs = self.model.generate( input_ids=inputs['input_ids'].cuda(),