-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
VLM: special multimodal Tokenizer #34461
Changes from 1 commit
d4778c9
0ba73ec
116a49e
661b5df
5ed9379
a284f31
9ad5362
1def73c
8dfa536
200879b
e0bf53b
7e5c4ba
b77f3be
8b61969
276c55e
3177519
46a30e5
d7c4eb5
f3b102e
5240129
58bf9e7
76b39aa
9994aed
c21997d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1400,7 +1400,6 @@ def __init__(self, **kwargs): | |
self.init_kwargs = copy.deepcopy(kwargs) | ||
self.name_or_path = kwargs.pop("name_or_path", "") | ||
self._processor_class = kwargs.pop("processor_class", None) | ||
self.is_multimodal = kwargs.pop("is_multimodal", False) | ||
|
||
# For backward compatibility we fallback to set model_max_length from max_len if provided | ||
model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) | ||
|
@@ -1440,9 +1439,8 @@ def __init__(self, **kwargs): | |
|
||
super().__init__(**kwargs) | ||
|
||
if self.is_multimodal: | ||
extra_special_tokens = ["image_token", "video_token", "boi_token", "eoi_token", "image_boundary_token"] | ||
self._set_model_specific_special_tokens(special_tokens=extra_special_tokens) | ||
self.extra_special_tokens = kwargs.pop("extra_special_tokens", []) | ||
self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when we do this, we don't add them to the tokenizer vocab right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you are already checking that these tokens are added to the vocab if not already present right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if the special token is not present in the vocab, we do add them as new tokens to the tokenizer vocab. Should we prevent users from adding new tokens and allow to use only available tokens? It happens because the Tokenizer initially is wired to do that, irrespective of current changes
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NO it's alright IMO we have not really seen reports about that |
||
|
||
@property | ||
def max_len_single_sentence(self) -> int: | ||
|
@@ -2404,8 +2402,8 @@ def save_pretrained( | |
|
||
# Let's make sure we properly save the special tokens and flag whether it is a multimodal tokenizer. | ||
tokenizer_config.update(self.special_tokens_map) | ||
if self.is_multimodal and "is_multimodal" not in tokenizer_config: | ||
tokenizer_config["is_multimodal"] = True | ||
if "extra_special_tokens" not in tokenizer_config: | ||
tokenizer_config["extra_special_tokens"] = self.extra_special_tokens | ||
|
||
if self.chat_template is not None: | ||
if isinstance(self.chat_template, dict): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's add a small test for this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, this is actually not correct anymore hehe, forgot to update the docs. And it has a test for that already so we are good
new way of adding extra special tokens is like
tokenizer.extra_special_tokens = {"eoi_token": "<s>", "image_token": "<image>"}
. After adding this line and saving the tokenizer, loading back will do the magic and tokenizer will haveself.image_token
attributeThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should be able to pass it as input as well instead of forcing people to use the setter! 🤗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeap, realized later and added that in the docs instead of "saving-loading back". Plus extended the test