# VLM: special multimodal Tokenizer #34461

Merged 24 commits on Nov 4, 2024. Changes shown from 1 commit.
**docs/source/en/main_classes/tokenizer.md** (19 additions, 0 deletions)

@@ -51,6 +51,25 @@
token space (e.g., getting the index of the token comprising a given character or the span of characters corresponding
to a given token).


# Multimodal Tokenizer

Apart from that, each tokenizer can be a "multimodal" tokenizer, which means that the tokenizer will hold all relevant special tokens
as part of its attributes for easier access. For example, if the tokenizer is loaded from a vision-language model like LLaVA, you will
be able to access `tokenizer.image_token_id` to obtain the special image token used as a placeholder.

To enable extra special tokens for any type of tokenizer, you have to add the following lines and save the tokenizer. Extra special tokens do not
have to be modality-related and can be anything that the model often needs access to. In the code below, the tokenizer at `output_dir` will have direct access
to three more special tokens.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.extra_special_tokens = ["image_token", "boi_token", "eoi_token"]
```
**Collaborator** left a suggested change:

```diff
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-tokenizer.extra_special_tokens = ["image_token", "boi_token", "eoi_token"]
+tokenizer = AutoTokenizer.from_pretrained(model_id, extra_special_tokens=["image_token", "boi_token", "eoi_token"])
```

**Collaborator:** let's add a small test for this

**Member (author):** yes, this is actually not correct anymore hehe, forgot to update the docs. And it has a test for that already, so we are good.

The new way of adding extra special tokens is `tokenizer.extra_special_tokens = {"eoi_token": "<s>", "image_token": "<image>"}`. After adding this line and saving the tokenizer, loading it back will do the magic and the tokenizer will have a `self.image_token` attribute.
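A minimal sketch of that save/load round trip in the dict style described above (the checkpoint id and token strings here are placeholders, not values from this PR):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint

# dict form: attribute name -> token string
tokenizer.extra_special_tokens = {"image_token": "<image>", "eoi_token": "</s>"}
tokenizer.save_pretrained("vlm_tokenizer")

# loading back wires the named tokens up as attributes
reloaded = AutoTokenizer.from_pretrained("vlm_tokenizer")
print(reloaded.image_token)  # "<image>"
```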

**Collaborator:** we should be able to pass it as input as well instead of forcing people to use the setter! 🤗

**Member (author):** yeap, realized later and added that in the docs instead of "saving-loading back". Plus extended the test.

```python
tokenizer.save_pretrained(output_dir)

vision_tokenizer = AutoTokenizer.from_pretrained(output_dir)
vision_tokenizer.image_token = "IMAGE"
```
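As the thread above suggests, the extra tokens can also be passed directly at load time instead of via the setter. A sketch of that one-step form, assuming the dict style from the author's comment (`model_id` is the same placeholder as in the docs example):

```python
from transformers import AutoTokenizer

# one-step form: no separate setter or save/load round trip needed
vision_tokenizer = AutoTokenizer.from_pretrained(
    model_id, extra_special_tokens={"image_token": "<image>"}
)
print(vision_tokenizer.image_token)  # "<image>"
```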

## PreTrainedTokenizer

[[autodoc]] PreTrainedTokenizer
**src/transformers/tokenization_utils_base.py** (4 additions, 6 deletions)

```diff
@@ -1400,7 +1400,6 @@ def __init__(self, **kwargs):
         self.init_kwargs = copy.deepcopy(kwargs)
         self.name_or_path = kwargs.pop("name_or_path", "")
         self._processor_class = kwargs.pop("processor_class", None)
-        self.is_multimodal = kwargs.pop("is_multimodal", False)

         # For backward compatibility we fallback to set model_max_length from max_len if provided
         model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
@@ -1440,9 +1439,8 @@ def __init__(self, **kwargs):

         super().__init__(**kwargs)

-        if self.is_multimodal:
-            extra_special_tokens = ["image_token", "video_token", "boi_token", "eoi_token", "image_boundary_token"]
-            self._set_model_specific_special_tokens(special_tokens=extra_special_tokens)
+        self.extra_special_tokens = kwargs.pop("extra_special_tokens", [])
+        self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens)

     @property
     def max_len_single_sentence(self) -> int:
```
**Collaborator:** when we do this, we don't add them to the tokenizer vocab, right?

**Collaborator:** I think you are already checking that these tokens are added to the vocab if not already present, right?

**Member (author):** if the special token is not present in the vocab, we do add it as a new token to the tokenizer vocab. Should we prevent users from adding new tokens and only allow tokens that are already available?

It happens because the tokenizer is wired to do that from the start, irrespective of the current changes:

```python
# 4. If some of the special tokens are not part of the vocab, we add them, at the end.
# the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers`
```

**Collaborator:** no, it's alright IMO, we have not really seen reports about that


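The helper that the new `__init__` lines call is not shown in this diff. Below is a rough sketch of what a `_set_model_specific_special_tokens`-style helper might do, purely as an illustration of the mechanism rather than the actual implementation:

```python
def _set_model_specific_special_tokens(self, special_tokens):
    # Sketch only: register each extra special token as a first-class
    # attribute so that e.g. `tokenizer.image_token` works. Assumes either
    # a list of attribute names or a dict mapping names to token strings.
    if isinstance(special_tokens, dict):
        for name, token in special_tokens.items():
            setattr(self, name, token)
    else:
        for name in special_tokens:
            # attribute is created; its value is expected to be filled in later
            setattr(self, name, None)
```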
```diff
@@ -2404,8 +2402,8 @@ def save_pretrained(

         # Let's make sure we properly save the special tokens and flag whether it is a multimodal tokenizer.
         tokenizer_config.update(self.special_tokens_map)
-        if self.is_multimodal and "is_multimodal" not in tokenizer_config:
-            tokenizer_config["is_multimodal"] = True
+        if "extra_special_tokens" not in tokenizer_config:
+            tokenizer_config["extra_special_tokens"] = self.extra_special_tokens

         if self.chat_template is not None:
             if isinstance(self.chat_template, dict):
```
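With this change, the extra tokens are persisted in `tokenizer_config.json`, which is what makes the reload work. A quick way to inspect what gets written, reusing `tokenizer` and `output_dir` from the docs example above (sketch only):

```python
import json
import os

tokenizer.save_pretrained(output_dir)
with open(os.path.join(output_dir, "tokenizer_config.json")) as f:
    config = json.load(f)

# the extra tokens are written alongside the standard special tokens
print(config["extra_special_tokens"])
```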
**tests/tokenization/test_tokenization_utils.py** (2 additions, 2 deletions)

```diff
@@ -281,7 +281,7 @@ def test_decoding_single_token(self):
         self.assertEqual(decoded_flat, "##:")
         self.assertEqual(decoded_list, "##:")

-    def test_extra_sepcial_tokens_multimodal(self):
+    def test_extra_special_tokens_multimodal(self):
         special_tokens_list = [
             "bos_token",
             "eos_token",
@@ -293,7 +293,7 @@ def test_extra_sepcial_tokens_multimodal(self):
             "additional_special_tokens",
         ]
         llama_tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
-        llama_tokenizer.is_multimodal = True
+        llama_tokenizer.extra_special_tokens = ["image_token", "boi_token", "eoi_token"]
         self.assertListEqual(llama_tokenizer.SPECIAL_TOKENS_ATTRIBUTES, special_tokens_list)
         with tempfile.TemporaryDirectory() as tmpdirname:
             llama_tokenizer.save_pretrained(tmpdirname)
```
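The diff cuts off right after the save. A rough sketch of how the round trip presumably continues in the test (the assertion here is illustrative, not copied from the PR):

```python
import tempfile

from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
tok.extra_special_tokens = ["image_token", "boi_token", "eoi_token"]

with tempfile.TemporaryDirectory() as tmpdirname:
    tok.save_pretrained(tmpdirname)
    reloaded = LlamaTokenizerFast.from_pretrained(tmpdirname)
    # the extra tokens should come back as first-class attributes
    assert hasattr(reloaded, "image_token")
```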