Skip to content

Commit

Permalink
Add return_full_text option to generate
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcin Kardas committed Feb 13, 2023
1 parent e3e3448 commit 7ee7c97
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion galai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def generate(
penalty_alpha=None,
num_beams=1,
num_return_sequences=1,
return_full_text=True,
) -> Union[str, List[str], List[List[str]]]:
"""
Generates text using the model
Expand Down Expand Up @@ -301,8 +302,11 @@ def generate(
**options
)

out_tokens = out['sequences']
if not return_full_text:
out_tokens = out_tokens[:, input_v.shape[1]:]
# we keep special tokens such as [START_REF] or <work>
decoded = self.tokenizer.batch_decode(out['sequences'], skip_special_tokens=False)
decoded = self.tokenizer.batch_decode(out_tokens, skip_special_tokens=False)
# so we manually remove </s> and <pad>
decoded = [
text.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages

PACKAGE_NAME = 'galai'
VERSION = "1.1.4"
VERSION = "1.1.5"
DESCRIPTION = "API for the GALACTICA model"
KEYWORDS = "Scientific Intelligence"
URL = 'https://github.com/paperswithcode/galai'
Expand Down

0 comments on commit 7ee7c97

Please sign in to comment.