-
Notifications
You must be signed in to change notification settings - Fork 4
/
prumerge_llava_next.py
1239 lines (1039 loc) · 61.9 KB
/
prumerge_llava_next.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# =================================== PruMerge and PruMerge+ ======================================
# Paste this code into src/transformers/models/llava_next/modeling_llava_next.py to use PruMerge and PruMerge+
# =====================================================================================================
# Credits:
# Code is copied from: https://github.com/42Shawn/LLaVA-PruMerge
# paper: https://arxiv.org/abs/2403.15388
import math, time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_llava_next import LlavaNextConfig
# Module-level logger created through the `logging` helper imported from `...utils` above.
logger = logging.get_logger(__name__)
# Config class name for the docstring helpers; presumably consumed by `replace_return_docstrings`
# on a `forward` method that is not visible in this chunk — TODO confirm.
_CONFIG_FOR_DOC = "LlavaNextConfig"
# =================================== PruMerge and PruMerge+ HELPER ======================================
# for prumerge
# Module-level scratch space written by the two hooks below: hook_k stores the latest output it
# sees under "desired_k", hook_q under "desired_q". Each call overwrites the previous value.
# Where the hooks are registered is not visible here (presumably the vision tower's attention
# key/query projections — TODO confirm).
outputs = {}


def hook_k(module, input, output):
    """Hook with the `(module, input, output)` forward-hook signature; records `output` as `outputs["desired_k"]`.

    `module` and `input` are ignored; nothing is returned, so the hooked module's output is left unchanged.
    """
    outputs.update({"desired_k": output})


def hook_q(module, input, output):
    """Hook with the `(module, input, output)` forward-hook signature; records `output` as `outputs["desired_q"]`.

    `module` and `input` are ignored; nothing is returned, so the hooked module's output is left unchanged.
    """
    outputs.update({"desired_q": output})
def complement_idx(idx, dim):
    """
    Return, for every row of `idx`, the sorted indices in `[0, dim)` that do not appear in that row.

    Args:
        idx (`torch.Tensor`):
            Integer index tensor of shape `(..., n_idx)`. Entries are assumed to be unique within a
            row; duplicates leave extra zeros behind and make the result wrong.
        dim (`int`):
            Size of the index space.

    Returns:
        `torch.Tensor` of shape `(..., dim - n_idx)` on `idx.device`, ascending along the last axis.
    """
    n_idx = idx.shape[-1]
    # 0..dim-1, broadcast over all leading (batch) dimensions of idx.
    candidates = torch.arange(dim, device=idx.device).expand(*idx.shape[:-1], -1)
    # Overwrite every selected index with 0. Position 0 already holds 0, so after sorting the first
    # n_idx entries are always zeros that stand for selected indices (plus, if 0 was not
    # selected, one genuine 0 survives right after them).
    zeroed = torch.scatter(candidates, -1, idx, 0)
    ordered = torch.sort(zeroed, dim=-1, descending=False).values
    return ordered[..., n_idx:]
def outlier_dectection(attn):
    """
    Return the fraction of elements of `attn` that lie strictly above the upper IQR fence.

    The tensor is cast to float32, copied to the CPU and flattened. The fence is
    `Q3 + 1.5 * (Q3 - Q1)`, with quartiles from `np.percentile` (default interpolation).
    Only high outliers are counted; there is no lower fence.

    Args:
        attn (`torch.Tensor`): values of any shape.

    Returns:
        `float`: (number of elements above the fence) / (total number of elements).
    """
    flat = attn.to(dtype=torch.float32).cpu().numpy().flatten()
    first_quartile = np.percentile(flat, 25)
    third_quartile = np.percentile(flat, 75)
    spread = third_quartile - first_quartile
    fence = third_quartile + 1.5 * spread
    return np.count_nonzero(flat > fence) / flat.size
# =================================== PruMerge and PruMerge+ HELPER END ======================================
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`, `list`, `torch.Tensor` or `np.ndarray`):
            The size of the input image in the format (height, width). Tensors and arrays are
            converted with `.tolist()` first.
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width), i.e.
        `(best_height // patch_size, best_width // patch_size)` for the resolution chosen by
        `select_best_resolution`.

    Raises:
        ValueError: if `grid_pinpoints` is not a list or `image_size` has an unsupported type.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise ValueError(
                f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
            )
        image_size = image_size.tolist()

    height, width = select_best_resolution(image_size, grid_pinpoints)
    # NOTE(review): the result is (rows, cols) == (height, width). pack_image_features unpacks it as
    # `num_patch_width, num_patch_height`, which swaps the axes for non-square grids — verify the caller.
    return height // patch_size, width // patch_size
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """
    Calculate the number of patches after the preprocessing for images of any resolution.

    The resolution chosen by `select_best_resolution` is tiled into
    `ceil(height / patch_size) * ceil(width / patch_size)` patches (a partial tile on the right or
    bottom edge still counts), and one more is added for the base patch.

    Args:
        image_size (`Union[torch.LongTensor, np.ndarray, Tuple[int, int]]`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch. Must be positive.

    Returns:
        int: the number of patches

    Raises:
        ValueError: if `grid_pinpoints` is not a list, `image_size` has an unsupported type, or
            `patch_size` is not positive.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise ValueError(f"image_size invalid type {type(image_size)} with value {image_size}")
        image_size = image_size.tolist()

    # Reject non-positive sizes explicitly instead of failing (or silently miscounting) in the arithmetic below.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    # Ceil division, clamped at 0: the same count as stepping range(0, size, patch_size), in O(1).
    num_rows = max(0, (height + patch_size - 1) // patch_size)
    num_cols = max(0, (width + patch_size - 1) // patch_size)
    # add the base patch
    return num_rows * num_cols + 1
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image or mask.

    A simple hack on top of the original implementation so that this
    function can handle both images and masks.

    The last two dimensions are treated as (height, width). Equal amounts of padding are trimmed
    from both ends of whichever axis has relatively more room than the original aspect ratio needs.

    Args:
        tensor (`torch.Tensor`):
            The image or mask tensor: (num_channels, height, width) for images, (height, width) for
            masks. Any leading dimensions are kept as-is.
        original_size (`tuple`, `list`, `torch.Tensor` or `np.ndarray`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image or mask tensor (a view of `tensor`).

    Raises:
        ValueError: if `tensor` has fewer than two dimensions.
    """
    # A tensor/array size (e.g. one row of `image_sizes`) would otherwise produce 0-d tensors and
    # reduced-precision arithmetic below; use plain Python numbers, as the sibling helpers do.
    if isinstance(original_size, (torch.Tensor, np.ndarray)):
        original_size = original_size.tolist()
    if tensor.ndim < 2:
        raise ValueError(
            f"Expected a tensor with at least 2 dimensions (height, width), got shape {tuple(tensor.shape)}"
        )

    original_height, original_width = original_size
    current_height, current_width = tensor.shape[-2:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Content spans the full width; trim padding rows from top and bottom.
        scale_factor = current_width / original_width
        # round(..., 7) absorbs float error such as 49 * (1 / 49) == 0.9999999999999999, which a
        # bare int() would truncate to one pixel too few and shift the crop.
        new_height = int(round(original_height * scale_factor, 7))
        padding = (current_height - new_height) // 2
        return tensor[..., padding : current_height - padding, :]

    # Content spans the full height; trim padding columns from left and right.
    scale_factor = current_height / original_height
    new_width = int(round(original_width * scale_factor, 7))
    padding = (current_width - new_width) // 2
    return tensor[..., padding : current_width - padding]
@dataclass
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->LlavaNext
class LlavaNextCausalLMOutputWithPast(ModelOutput):
    """
    Base class for LlavaNext causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape
            `(batch_size, num_images, sequence_length, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder.
    """

    # Field order matters: ModelOutput also exposes these positionally / as a tuple.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
    """Two-layer MLP (linear -> activation -> linear) from the vision hidden size to the text hidden size."""

    def __init__(self, config: LlavaNextConfig):
        super().__init__()
        vision_dim = config.vision_config.hidden_size
        text_dim = config.text_config.hidden_size
        # Attribute names and creation order are kept: they fix the state_dict keys and the
        # order in which parameters are created.
        self.linear_1 = nn.Linear(vision_dim, text_dim, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_dim, text_dim, bias=True)

    def forward(self, image_features):
        """Map `image_features` (last dim = vision hidden size) to the text hidden size."""
        return self.linear_2(self.act(self.linear_1(image_features)))
# Shared class-docstring prefix; passed to `add_start_docstrings` on the model classes below.
LLAVA_NEXT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAVA_NEXT_START_DOCSTRING,
)
# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNext,llava->llava_next
class LlavaNextPreTrainedModel(PreTrainedModel):
    # Class attributes consumed by the transformers `PreTrainedModel` base class.
    config_class = LlavaNextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaNextVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """
        Initialise `module` in place.

        Applies normal(0, std) to a `class_embedding` attribute and to Linear/Conv2d/Embedding weights.
        Zeroes Linear/Conv2d biases and the `padding_idx` row of embeddings. `std` comes from
        `config.initializer_range`, falling back to `config.text_config.initializer_range`.
        """
        # important: this ported version of LlavaNext isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        is_dense = isinstance(module, (nn.Linear, nn.Conv2d))
        is_embedding = (not is_dense) and isinstance(module, nn.Embedding)
        if not (is_dense or is_embedding):
            return

        # Weight is drawn before the bias / padding row is touched, as before.
        module.weight.data.normal_(mean=0.0, std=std)
        if is_dense:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        SDPA or not.
        """
        return self.language_model._supports_sdpa
# Shared `forward` argument documentation; presumably injected via
# `add_start_docstrings_to_model_forward` on a forward method not visible in this chunk.
LLAVA_NEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
[`LlavaNextImageProcessor`] for processing images.
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
The sizes of the images in the batch, being (height, width) for each image.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`int`, *optional*, defaults to -2):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_START_DOCSTRING,
)
class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
def __init__(self, config: LlavaNextConfig):
    """
    Build the vision tower, multi-modal projector, learned `image_newline` vector and language model.

    Args:
        config (`LlavaNextConfig`):
            Model configuration; `vision_config` and `text_config` instantiate the two backbones.
    """
    super().__init__(config)
    # Submodules are created in the same order as before, so parameter registration and the RNG
    # draw for `image_newline` are unchanged.
    self.vision_tower = AutoModel.from_config(config.vision_config)
    self.multi_modal_projector = LlavaNextMultiModalProjector(config)

    text_hidden = config.text_config.hidden_size
    newline_scale = 1 / math.sqrt(text_hidden)
    self.image_newline = nn.Parameter(torch.randn(text_hidden, dtype=self.dtype) * newline_scale)

    self.vocab_size = config.text_config.vocab_size
    self.language_model = AutoModelForCausalLM.from_config(
        config.text_config, attn_implementation=config._attn_implementation
    )
    # Placeholder: nothing in this constructor fills it (presumably set elsewhere — TODO confirm).
    self.vit_to_llm_mapping = None

    configured_pad_id = self.config.pad_token_id
    # -1 stands in for "no pad token configured".
    self.pad_token_id = -1 if configured_pad_id is None else configured_pad_id
    # set it to left by default, user can use setter to change padding_sides
    self._padding_side = "left"
    self.post_init()
@property
def padding_side(self):
    """Padding side ("left" or "right") that `_merge_input_ids_with_image_features` falls back to when the mask is ambiguous."""
    return self._padding_side

@padding_side.setter
def padding_side(self, padding_side: str):
    """Set the padding side; raises ValueError for anything other than "left" or "right"."""
    if padding_side in ("left", "right"):
        self._padding_side = padding_side
        return
    raise ValueError(f"{padding_side} is not `left` or `right`.")
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
    """Delegate to `self.language_model.get_input_embeddings()`."""
    return self.language_model.get_input_embeddings()

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
    """Delegate to `self.language_model.set_input_embeddings(value)`."""
    self.language_model.set_input_embeddings(value)

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
    """Delegate to `self.language_model.get_output_embeddings()`."""
    return self.language_model.get_output_embeddings()

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
    """Delegate to `self.language_model.set_output_embeddings(new_embeddings)`."""
    self.language_model.set_output_embeddings(new_embeddings)

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
    """Delegate to `self.language_model.set_decoder(decoder)`."""
    self.language_model.set_decoder(decoder)

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
def get_decoder(self):
    """Delegate to `self.language_model.get_decoder()`."""
    return self.language_model.get_decoder()

# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
def tie_weights(self):
    """Delegate weight tying to `self.language_model.tie_weights()` and return its result."""
    return self.language_model.tie_weights()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
    """
    Resize the language model's token embeddings and keep the cached vocab sizes in sync.

    Args:
        new_num_tokens (`int`, *optional*): forwarded to `language_model.resize_token_embeddings`.
        pad_to_multiple_of (*optional*): forwarded unchanged.

    Returns:
        The embedding object returned by `self.language_model.resize_token_embeddings`. Its
        `num_embeddings` is written to both `config.text_config.vocab_size` and `self.vocab_size`.
    """
    resized = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    new_vocab = resized.num_embeddings
    # update vocab size
    self.config.text_config.vocab_size = new_vocab
    self.vocab_size = new_vocab
    return resized
def _merge_input_ids_with_image_features(
    self,
    image_features,
    feature_lens,
    inputs_embeds,
    input_ids,
    attention_mask,
    position_ids=None,
    labels=None,
    image_token_index=None,
    ignore_index=-100,
):
    """
    Merge input_ids with image features into final embeddings.

    Args:
        image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
            All vision vectors of all images in the batch
        feature_lens (`torch.LongTensor` of shape `(num_images)`):
            The length of visual embeddings of each image as stacked in `image_features`
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
            Token embeddings before merging with visual embeddings
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Input_ids of tokens, possibly filled with image token
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Mask to avoid performing attention on padding token indices.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Accepted for interface compatibility but never read: the returned `position_ids` are
            recomputed from the merged attention mask.
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
            Labels need to be recalculated to support training (if provided)
        image_token_index (`int`, *optional*)
            Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
        ignore_index (`int`, *optional*)
            Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
            If `None` is passed explicitly, `config.ignore_index` is used.

    Returns:
        A 5-tuple:
        - final_embedding (`torch.Tensor` of shape `(batch_size, max_embed_dim, embed_dim)`)
        - final_attention_mask (shape `(batch_size, max_embed_dim)`)
        - position_ids (shape `(batch_size, max_embed_dim)`; padded positions are filled with 1)
        - final_labels (shape `(batch_size, max_embed_dim)`, or `None` when `labels` is None)
        - image_indices (`List[torch.Tensor]` of length batch_size): for each row, the positions in
          `final_embedding` that were filled with image features.

    Raises:
        ValueError: if both ends of `attention_mask` contain padding (batch_size > 1), if the number of
            image tokens differs from `len(feature_lens)`, or if the image slots do not match
            `image_features.shape[0]`.

    Explanation:
        each image has variable length embeddings, with length specified by feature_lens
        image_features is concatenation of all visual embed vectors
        task: fill each <image> with the correct number of visual embeddings
        Example:
            X (5 patches), Y (3 patches), Z (8)
            X, Y are in the same sequence (in-context learning)
        if right padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                o p q r Z s t u v _ _ _ _ _ _
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
            ]
        elif left padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                _ _ _ _ _ _ o p q r Z s t u v
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
            ]
        Edge cases:
            * If tokens are same but image token sizes are different, then cannot infer left or right padding
            ```python
            cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
            chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
            prompts = [
                "[INST] <image>\nWhat is shown in this image? [/INST]",
                "[INST] <image>\nWhat is shown in this image? [/INST]",
            ]
            inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
                chart_img has 2634 tokens, while cat_img has 2340 tokens
            ```

            input_ids: [
                a b c d X g h
                i j Y k l m n
            ]
            where X is 3 tokens while Y is 5, this mean after merge
            if left-padding (batched generation)
                input_ids should be: [
                    _ _ a b c d X X X g h
                    i j Y Y Y Y Y k l m n
                ]
            elif (right padding) (training)
                input_ids should be: [
                    a b c d X X X g h _ _
                    i j Y Y Y Y Y k l m n
                ]
    """
    image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
    ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index

    # Index bookkeeping, computed with autograd disabled.
    with torch.no_grad():
        # ! in llava 1.6, number of patches is variable
        num_images = feature_lens.size(0)
        num_image_features, embed_dim = image_features.shape
        batch_size = input_ids.shape[0]
        # True when at least one row has padding in its first / last column.
        _left_padding = torch.any(attention_mask[:, 0] == 0)
        _right_padding = torch.any(attention_mask[:, -1] == 0)

        # Defaults to left padding; the mask is only inspected when batch_size > 1.
        left_padding = True
        if batch_size > 1:
            if _left_padding and not _right_padding:
                left_padding = True
            elif not _left_padding and _right_padding:
                left_padding = False
            elif not _left_padding and not _right_padding:
                # both side is 1, so cannot tell
                left_padding = self.padding_side == "left"
            else:
                # invalid attention_mask
                raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")

        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == image_token_index
        # special_image_token_mask: [bsz, seqlen]
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # num_special_image_tokens: [bsz]
        # Reserve for padding of num_images
        total_num_special_image_tokens = torch.sum(special_image_token_mask)
        if total_num_special_image_tokens != num_images:
            raise ValueError(
                f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
            )
        # Compute the maximum embed dimension
        # max_image_feature_lens is max_feature_lens per batch
        # Group feature_lens by batch row (rows contribute images in order) and sum per row.
        feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
        feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
        # Merged length per row = attended tokens - image placeholders + image feature vectors.
        embed_sequence_lengths = (
            (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
        )
        max_embed_dim = embed_sequence_lengths.max()

        # (row, column) of every attended, non-image token.
        batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        # ! instead of special_image_token_mask * (num_image_patches - 1)
        #   special_image_token_mask * (num_feature_len - 1)
        # Assumes feature_lens is ordered like the image tokens in row-major (row, then column) order.
        special_image_token_mask = special_image_token_mask.long()
        special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
        new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
        if left_padding:
            # shift right token positions so that they are ending at the same number
            # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
            new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]

        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    final_labels = None
    if labels is not None:
        final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    with torch.no_grad():
        image_to_overwrite = torch.full(
            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
        )
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
        embed_indices = embed_indices.expand(batch_size, max_embed_dim)
        embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)

        if left_padding:
            # exclude padding on the left
            val = (max_embed_dim - embed_indices) <= embed_seq_lens
        else:
            # exclude padding on the right
            val = embed_indices < embed_seq_lens
        image_to_overwrite &= val

        if image_to_overwrite.sum() != num_image_features:
            # NOTE(review): special_image_token_mask was rewritten in step 2 to hold feature_lens - 1 at
            # image positions, so the "number of image tokens" reported below is not actually that count.
            raise ValueError(
                f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
                f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of image given to the model is {num_images}. "
                f"This prevents correct indexing and breaks batch generation."
            )

    # Outside no_grad so gradients can flow into image_features.
    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    # Positions count attended slots from 0; padded slots are set to 1.
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # Calculate the image feature indices in the final embedding
    # (one 1-D index tensor per batch row, listing the slots filled from image_features).
    image_indices = []
    for i in range(batch_size):
        image_positions = torch.where(image_to_overwrite[i])[0]
        image_indices.append(image_positions)

    return final_embedding, final_attention_mask, position_ids, final_labels, image_indices
def pack_image_features(self, image_features, image_sizes, image_newline=None, image_features_masks=None):
    """
    Reshape, unpad and then pack each image_feature into a single image_features tensor containing all
    visual vectors, dropping the tokens switched off by `image_features_masks`.

    Args:
        image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`):
            List of image feature tensors; each holds the visual features of all patches of one image.
            Patch 0 is handled as the base view, patches 1.. as the any-resolution grid crops.
        image_sizes (`torch.Tensor` of shape `(num_images, 2)`):
            Actual image size of each image (H, W).
        image_newline (`torch.Tensor` of shape `(embed_dim)`, *optional*):
            New line embedding vector appended after every row of grid tokens.
        image_features_masks (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length)`, *optional*):
            Token keep-masks (non-zero = keep). Newline positions are always kept. When `None`, every
            token is kept (no pruning). Masks are only applied to images with more than one patch.

    Returns:
        image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`):
            Concatenated image features.
        feature_lens (`torch.Tensor` of dtype `torch.long` and shape `(num_images,)`):
            Token length of each image in image_features.
    """
    if image_features_masks is None:
        # No pruning requested: an all-ones mask per image keeps every token (and lets zip() work).
        image_features_masks = [
            torch.ones(feature.shape[:2], dtype=torch.bool, device=feature.device) for feature in image_features
        ]

    new_image_features = []
    feature_lens = []
    for image_idx, (image_feature, image_features_mask) in enumerate(zip(image_features, image_features_masks)):
        if image_feature.shape[0] > 1:
            base_image_feature = image_feature[0]  # [image_length, embed_dim], e.g. [576, 4096]
            base_image_mask = image_features_mask[0]  # [image_length]
            reduced_base_image_feature = base_image_feature[base_image_mask.bool()]  # [kept_base, embed_dim]
            image_feature = image_feature[1:]  # grid crops, e.g. [4, 576, 4096]
            other_images_mask = image_features_mask[1:]  # e.g. [4, 576]

            height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
            if height * width != base_image_feature.shape[0]:
                raise ValueError("The number of patches is not consistent with the image size.")
            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                image_sizes[image_idx],
                self.config.image_grid_pinpoints,
                self.config.vision_config.image_size,
            )

            # Lay the crops out as one (rows, cols) grid. The mask goes through exactly the same
            # view/permute/flatten sequence so feature and mask positions stay aligned.
            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
            other_images_mask = other_images_mask.view(num_patch_height, num_patch_width, height, width)
            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()  # [D, ph, h, pw, w]
            other_images_mask = other_images_mask.permute(0, 2, 1, 3).contiguous()  # [ph, h, pw, w]
            image_feature = image_feature.flatten(1, 2).flatten(2, 3)  # [D, ph*h, pw*w]
            other_images_mask = other_images_mask.flatten(0, 1).flatten(1, 2)  # [ph*h, pw*w]
            image_feature = unpad_image(image_feature, image_sizes[image_idx])
            # NOTE(review): unlike `image_feature`, the mask has no leading channel dim here;
            # confirm `unpad_image` handles a 2-D input the same way it handles the 3-D feature map.
            other_images_mask = unpad_image(other_images_mask, image_sizes[image_idx])

            if image_newline is not None:
                image_feature = torch.cat(
                    (
                        image_feature,
                        image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
                    ),
                    dim=-1,
                )  # [D, rows, cols + 1]
                # The appended newline column is always kept.
                newline_keep = torch.ones(
                    other_images_mask.shape[0], 1, dtype=other_images_mask.dtype, device=other_images_mask.device
                )  # [rows, 1]
                other_images_mask = torch.cat((other_images_mask, newline_keep), dim=-1)  # [rows, cols + 1]

            # Apply the mask on the row-major flattened grid.
            image_feature = image_feature.flatten(1, 2)  # [D, rows * cols']
            other_images_mask = other_images_mask.flatten()  # [rows * cols']
            reduced_image_feature = image_feature[:, other_images_mask.bool()].transpose(0, 1)  # [kept, D]
            image_feature = torch.cat((reduced_base_image_feature, reduced_image_feature), dim=0)
        else:
            # Single patch: no grid layout and no masking.
            image_feature = image_feature[0]
            if image_newline is not None:
                image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
        new_image_features.append(image_feature)
        feature_lens.append(image_feature.size(0))
    image_features = torch.cat(new_image_features, dim=0)
    feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
    return image_features, feature_lens
# ========================= PruMerge+ =========================
def token_prune_merge_advanced_plus(self, images, if_adaptive=True, reduction_ratio=0.2):
    '''
    PruMerge+ (version 24/03/2024): prune visual tokens by CLS attention, supplement the kept
    tokens with spatially sampled ones, then merge the remaining tokens into the kept ones.

    Args:
        images: pixel values passed to ``self.vision_tower`` (moved to ``self.device``/``self.dtype``).
        if_adaptive (bool): when True, ``reduction_ratio`` is replaced by
            ``outlier_dectection(cls_attn)``, and only spatial samples that are not already selected
            are appended. A single sample row is appended, so this path assumes batch size 1
            (TODO confirm with callers). When False, a fixed-stride sample is appended to every row
            without de-duplication.
        reduction_ratio (float): fraction of the N patch tokens kept by the top-k CLS-attention step.

    Returns:
        torch.Tensor of shape ``[B, left_tokens + 1, C]``. Each selected token has the CLS-attention
        weighted sum of its (up to 32) most key-similar other tokens added to it. The extra last
        token is the CLS-attention weighted sum of all unselected tokens.
    '''
    # Capture q/k projections of vision-encoder layer 23 with forward hooks (presumably hook_k /
    # hook_q write into the module-level `outputs` dict read below). Remove the hooks in `finally`
    # so a failing forward pass cannot leave them registered on the vision tower.
    attn_module = self.vision_tower.vision_model.encoder.layers[23].self_attn
    hook_handle_k = attn_module.k_proj.register_forward_hook(hook_k)
    hook_handle_q = attn_module.q_proj.register_forward_hook(hook_q)
    try:
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
    finally:
        hook_handle_k.remove()
        hook_handle_q.remove()
    image_features = image_forward_outs.hidden_states[self.config.vision_feature_layer][:, 1:]  # drop CLS
    B, N, C = image_features.shape
    desired_layer_k = outputs["desired_k"]
    desired_layer_q = outputs["desired_q"]

    # Attention over the full projection width (no per-head split), scaled by C ** -0.5.
    attn = (desired_layer_q @ desired_layer_k.transpose(-2, -1)) * C ** -0.5
    attn = F.softmax(attn, dim=-1)
    cls_attn = attn[:, 0, 1:]  # [B, N]: CLS query attending to each patch token

    if if_adaptive:
        reduction_ratio = outlier_dectection(cls_attn)
    _, idx = torch.topk(cls_attn, int(N * reduction_ratio), dim=1, largest=True)  # [B, k]

    # Supplement the top-k indices with spatially sampled positions in [0, N - 1).
    step_length = int(1 / reduction_ratio)
    if if_adaptive:
        sample_step = max(step_length // 3, 1)  # a zero step would make arange fail
        sampled = torch.arange(0, N - 1, sample_step, device=idx.device)
        # Drop samples already selected (vectorised membership test; keeps int64 even when empty).
        already_kept = (sampled.unsqueeze(1) == idx.flatten().unsqueeze(0)).any(dim=1)
        idx = torch.cat((idx, sampled[~already_kept].unsqueeze(0)), dim=1)
    else:
        # Training path: the same fixed-stride sample is appended to every row; duplicates with the
        # top-k indices are not removed.
        sampled = torch.arange(step_length // 2, N - 1, step_length, device=idx.device)
        idx = torch.cat((idx, sampled.unsqueeze(0).expand(idx.size(0), -1)), dim=1)

    index = idx.unsqueeze(-1).expand(-1, -1, C)  # [B, L, C]
    Key_wo_cls = desired_layer_k[:, 1:]  # [B, N, C]: keys without the CLS position
    x_others = torch.gather(image_features, dim=1, index=index)  # [B, L, C] selected tokens
    x_others_attn = torch.gather(cls_attn, dim=1, index=idx)  # [B, L]
    Key_others = torch.gather(Key_wo_cls, dim=1, index=index)  # [B, L, C]
    compl = complement_idx(idx, N)  # [B, M] unselected positions
    compl_index = compl.unsqueeze(-1).expand(-1, -1, C)
    non_topk = torch.gather(image_features, dim=1, index=compl_index)  # [B, M, C]
    non_topk_Key = torch.gather(Key_wo_cls, dim=1, index=compl_index)  # [B, M, C]
    non_topk_attn = torch.gather(cls_attn, dim=1, index=compl)  # [B, M]
    Key_others_norm = F.normalize(Key_others, p=2, dim=-1)
    non_topk_Key_norm = F.normalize(non_topk_Key, p=2, dim=-1)

    # Merge step, vectorised over batch and selected tokens. For selected token i the candidates are
    # all other selected tokens plus all unselected tokens, ranked by key cosine similarity.
    left_tokens = x_others.size(1)
    candidate_tokens = torch.cat([x_others, non_topk], dim=1)  # [B, L+M, C]
    candidate_attn = torch.cat([x_others_attn, non_topk_attn], dim=1)  # [B, L+M]
    candidate_keys = torch.cat([Key_others_norm, non_topk_Key_norm], dim=1)  # [B, L+M, C]
    cos_sim = torch.bmm(Key_others_norm, candidate_keys.transpose(1, 2))  # [B, L, L+M]
    # Candidate slot i (< L) is token i itself: exclude it from its own neighbourhood.
    self_mask = torch.eye(left_tokens, candidate_keys.size(1), dtype=torch.bool, device=cos_sim.device)
    cos_sim = cos_sim.masked_fill(self_mask, float("-inf"))
    num_neighbours = min(32, candidate_keys.size(1) - 1)
    _, cluster_indices = torch.topk(cos_sim, k=num_neighbours, dim=2, largest=True)  # [B, L, K]
    batch_index = torch.arange(B, device=cluster_indices.device)[:, None, None]
    cluster_tokens = candidate_tokens[batch_index, cluster_indices]  # [B, L, K, C]
    weights = candidate_attn[batch_index, cluster_indices].unsqueeze(-1)  # [B, L, K, 1]
    # Un-normalised CLS-attention weighted sum, added onto each selected token.
    updated_x_others = x_others + torch.sum(cluster_tokens * weights, dim=2)

    extra_one_token = torch.sum(non_topk * non_topk_attn.unsqueeze(-1), dim=1, keepdim=True)  # [B, 1, C]
    return torch.cat([updated_x_others, extra_one_token], dim=1)
# ========================= PruMerge =========================
def token_prune_merge_advanced(self, images, if_adaptive=True, reduction_ratio=0.2):
    '''
    PruMerge (version 10/03/2024): keep the visual tokens with the highest CLS attention and merge
    the other tokens into them, using key-key cosine similarity to choose merge partners.

    Args:
        images: pixel values passed to ``self.vision_tower`` (moved to ``self.device``/``self.dtype``).
        if_adaptive (bool): when True, ``reduction_ratio`` is replaced by
            ``outlier_dectection(cls_attn)``.
        reduction_ratio (float): fraction of the N patch tokens kept by the top-k CLS-attention step.

    Returns:
        torch.Tensor of shape ``[B, left_tokens + 1, C]``. Each kept token has the CLS-attention
        weighted sum of its (up to 32) most key-similar other tokens added to it. The extra last
        token is the CLS-attention weighted sum of all pruned tokens.
    '''
    # Capture q/k projections of vision-encoder layer 23 with forward hooks (presumably hook_k /
    # hook_q write into the module-level `outputs` dict read below). Remove the hooks in `finally`
    # so a failing forward pass cannot leave them registered on the vision tower.
    attn_module = self.vision_tower.vision_model.encoder.layers[23].self_attn
    hook_handle_k = attn_module.k_proj.register_forward_hook(hook_k)
    hook_handle_q = attn_module.q_proj.register_forward_hook(hook_q)
    try:
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
    finally:
        hook_handle_k.remove()
        hook_handle_q.remove()
    image_features = image_forward_outs.hidden_states[self.config.vision_feature_layer][:, 1:]  # drop CLS
    B, N, C = image_features.shape
    desired_layer_k = outputs["desired_k"]
    desired_layer_q = outputs["desired_q"]

    # Attention over the full projection width (no per-head split), scaled by C ** -0.5.
    attn = (desired_layer_q @ desired_layer_k.transpose(-2, -1)) * C ** -0.5
    attn = F.softmax(attn, dim=-1)
    cls_attn = attn[:, 0, 1:]  # [B, N]: CLS query attending to each patch token

    if if_adaptive:
        reduction_ratio = outlier_dectection(cls_attn)
    _, idx = torch.topk(cls_attn, int(N * reduction_ratio), dim=1, largest=True)  # [B, L]

    index = idx.unsqueeze(-1).expand(-1, -1, C)  # [B, L, C]
    Key_wo_cls = desired_layer_k[:, 1:]  # [B, N, C]: keys without the CLS position
    x_others = torch.gather(image_features, dim=1, index=index)  # [B, L, C] kept tokens
    x_others_attn = torch.gather(cls_attn, dim=1, index=idx)  # [B, L]
    Key_others = torch.gather(Key_wo_cls, dim=1, index=index)  # [B, L, C]
    compl = complement_idx(idx, N)  # [B, M] pruned positions
    compl_index = compl.unsqueeze(-1).expand(-1, -1, C)
    non_topk = torch.gather(image_features, dim=1, index=compl_index)  # [B, M, C]
    non_topk_Key = torch.gather(Key_wo_cls, dim=1, index=compl_index)  # [B, M, C]
    non_topk_attn = torch.gather(cls_attn, dim=1, index=compl)  # [B, M]
    Key_others_norm = F.normalize(Key_others, p=2, dim=-1)
    non_topk_Key_norm = F.normalize(non_topk_Key, p=2, dim=-1)

    # Merge step, vectorised over batch and kept tokens. For kept token i the candidates are all
    # other kept tokens plus all pruned tokens, ranked by key cosine similarity.
    left_tokens = x_others.size(1)
    candidate_tokens = torch.cat([x_others, non_topk], dim=1)  # [B, L+M, C]
    candidate_attn = torch.cat([x_others_attn, non_topk_attn], dim=1)  # [B, L+M]
    candidate_keys = torch.cat([Key_others_norm, non_topk_Key_norm], dim=1)  # [B, L+M, C]
    cos_sim = torch.bmm(Key_others_norm, candidate_keys.transpose(1, 2))  # [B, L, L+M]
    # Candidate slot i (< L) is token i itself: exclude it from its own neighbourhood.
    self_mask = torch.eye(left_tokens, candidate_keys.size(1), dtype=torch.bool, device=cos_sim.device)
    cos_sim = cos_sim.masked_fill(self_mask, float("-inf"))
    num_neighbours = min(32, candidate_keys.size(1) - 1)
    _, cluster_indices = torch.topk(cos_sim, k=num_neighbours, dim=2, largest=True)  # [B, L, K]
    batch_index = torch.arange(B, device=cluster_indices.device)[:, None, None]
    cluster_tokens = candidate_tokens[batch_index, cluster_indices]  # [B, L, K, C]
    weights = candidate_attn[batch_index, cluster_indices].unsqueeze(-1)  # [B, L, K, 1]
    # Un-normalised CLS-attention weighted sum, added onto each kept token.
    updated_x_others = x_others + torch.sum(cluster_tokens * weights, dim=2)

    extra_one_token = torch.sum(non_topk * non_topk_attn.unsqueeze(-1), dim=1, keepdim=True)  # [B, 1, C]
    return torch.cat([updated_x_others, extra_one_token], dim=1)
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image