Skip to content

Commit

Permalink
and done πŸš€πŸŽ‰
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj committed Jul 8, 2020
1 parent 70535a6 commit aef24e3
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
logger = logging.getLogger(__name__)

class QGPipeline:
"""Poor man's QG pipeline"""
def __init__(
self,
model: PreTrainedModel,
Expand Down Expand Up @@ -234,13 +235,13 @@ def __call__(self, context: str, **generate_kwargs):

input_length = inputs["input_ids"].shape[-1]

max_length = generate_kwargs.get("max_length", 256)
if input_length < max_length:
logger.warning(
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
max_length, input_length
)
)
# max_length = generate_kwargs.get("max_length", 256)
# if input_length < max_length:
# logger.warning(
# "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
# max_length, input_length
# )
# )

outs = self.model.generate(
input_ids=inputs['input_ids'].cuda(),
Expand Down

0 comments on commit aef24e3

Please sign in to comment.