forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
weight.py
1015 lines (907 loc) · 45 KB
/
weight.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import configparser
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from safetensors import safe_open
import tensorrt_llm
import tensorrt_llm.logger as logger
from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import BaichuanForCausalLM
from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
from tensorrt_llm.quantization import QuantMode
def get_scaling_factors(
model_path: Union[str, Path],
num_layers: int,
quant_mode: Optional[QuantMode] = None,
) -> Optional[Dict[str, List[int]]]:
""" Get the scaling factors for Baichuan model
Returns a dictionary of scaling factors for the selected layers of the
Baichuan model.
Args:
model_path (str): Path to the quantized Baichuan model
layers (list): List of layers to get the scaling factors for. If None,
all layers are selected.
Returns:
dict: Dictionary of scaling factors for the selected layers of the
Baichuan model.
example:
{
'qkv_act': qkv_act_scale,
'qkv_weights': qkv_weights_scale,
'qkv_output' : qkv_outputs_scale,
'dense_act': dense_act_scale,
'dense_weights': dense_weights_scale,
'fc_act': fc_act_scale,
'fc_weights': fc_weights_scale,
'gate_act': gate_act_scale,
'gate_weights': gate_weights_scale,
'proj_act': proj_act_scale,
'proj_weights': proj_weights_scale,
}
"""
if model_path is None:
logger.warning(f"--quantized_fp8_model_path not specified. "
f"Initialize quantization scales automatically.")
return get_dummy_quant_scales(num_layers)
weight_dict = np.load(model_path)
# yapf: disable
scaling_factor = {
'qkv_act': [],
'qkv_weights': [],
'qkv_output': [],
'dense_act': [],
'dense_weights': [],
'fc_act': [],
'fc_weights': [],
'gate_act': [],
'gate_weights': [],
'proj_act': [],
'proj_weights': [],
}
for layer in range(num_layers):
scaling_factor['qkv_act'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
))
scaling_factor['qkv_weights'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
))
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
# Not calibrarting KV cache.
scaling_factor['qkv_output'].append(1.0)
scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item())
scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item())
scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
# yapf: enable
for k, v in scaling_factor.items():
assert len(v) == num_layers, \
f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
return scaling_factor
def extract_layer_idx(name):
ss = name.split('.')
for s in ss:
if s.isdigit():
return s
return None
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
if len(v.shape) == 1:
return np.ascontiguousarray(np.split(v, tp_size)[idx].copy())
else:
return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx].copy())
def load_from_hf_baichuan(tensorrt_llm_baichuan,
hf_baichuan,
model_version,
rank=0,
tensor_parallel=1,
dtype="float32"):
assert model_version is not None
tensorrt_llm.logger.info(
f'Loading weights from HF Baichuan {model_version}...')
tik = time.time()
quant_mode = getattr(tensorrt_llm_baichuan, 'quant_mode', QuantMode(0))
if quant_mode.is_int8_weight_only():
plugin_weight_only_quant_type = torch.int8
elif quant_mode.is_int4_weight_only():
plugin_weight_only_quant_type = torch.quint4x2
use_weight_only = quant_mode.is_weight_only()
model_params = dict(hf_baichuan.named_parameters())
for k, v in model_params.items():
torch_dtype = str_dtype_to_torch(dtype)
v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
if 'model.embed_tokens.weight' in k:
tensorrt_llm_baichuan.vocab_embedding.weight.value = v
elif 'model.norm.weight' in k:
tensorrt_llm_baichuan.ln_f.weight.value = v
elif 'lm_head.weight' in k:
if model_version.startswith('v2'):
# baichuan v2 models use NormHead
tensorrt_llm.logger.info(
f'Normalizing lm_head.weight for {model_version}')
original_v = model_params[k]
v = torch_to_numpy(
torch.nn.functional.normalize(original_v).to(
torch_dtype).detach().cpu())
tensorrt_llm_baichuan.lm_head.weight.value = np.ascontiguousarray(
split(v, tensor_parallel, rank))
else:
layer_idx = extract_layer_idx(k)
if layer_idx is None:
continue
idx = int(layer_idx)
if idx >= tensorrt_llm_baichuan.num_layers:
continue
if 'input_layernorm.weight' in k:
tensorrt_llm_baichuan.layers[
idx].input_layernorm.weight.value = v
elif 'post_attention_layernorm.weight' in k:
dst = tensorrt_llm_baichuan.layers[idx].post_layernorm.weight
dst.value = v
elif 'self_attn.W_pack.weight' in k:
dst = tensorrt_llm_baichuan.layers[idx].attention.qkv.weight
q_emb = v.shape[0] // 3
model_emb = v.shape[1]
v = v.reshape(3, q_emb, model_emb)
split_v = split(v, tensor_parallel, rank, dim=1)
split_v = split_v.reshape(3 * (q_emb // tensor_parallel),
model_emb)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[
idx].attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'self_attn.o_proj.weight' in k:
dst = tensorrt_llm_baichuan.layers[idx].attention.dense.weight
split_v = split(v, tensor_parallel, rank, dim=1)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[
idx].attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'mlp.up_proj.weight' in k:
dst = tensorrt_llm_baichuan.layers[idx].mlp.gate.weight
split_v = split(v, tensor_parallel, rank, dim=0)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[
idx].mlp.gate.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'mlp.down_proj.weight' in k:
dst = tensorrt_llm_baichuan.layers[idx].mlp.proj.weight
split_v = split(v, tensor_parallel, rank, dim=1)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[
idx].mlp.proj.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
elif 'mlp.gate_proj.weight' in k:
dst = tensorrt_llm_baichuan.layers[idx].mlp.fc.weight
split_v = split(v, tensor_parallel, rank, dim=0)
if use_weight_only:
v = np.ascontiguousarray(split_v.transpose())
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(v), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[
idx].mlp.fc.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(split_v)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
def parse_bin_config(ini_file):
baichuan_config = configparser.ConfigParser()
baichuan_config.read(ini_file)
n_embd = baichuan_config.getint('baichuan', 'hidden_size')
n_head = baichuan_config.getint('baichuan', 'num_attention_heads')
n_kv_head = n_head
n_layer = baichuan_config.getint('baichuan', 'num_hidden_layers')
if baichuan_config.has_option('baichuan', 'max_position_embeddings'):
n_positions = baichuan_config.getint('baichuan',
'max_position_embeddings')
else:
n_positions = baichuan_config.getint('baichuan', 'model_max_length')
vocab_size = baichuan_config.getint('baichuan', 'vocab_size')
hidden_act = baichuan_config.get('baichuan', 'hidden_act')
inter_size = baichuan_config.getint('baichuan',
'intermediate_size',
fallback=None)
if inter_size is None:
inter_size = 4 * n_embd
return n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
suffix = f"{rank}.bin"
if use_smooth_quant:
sq_prefix = "int8."
if quant_per_channel:
sq_prefix += "col."
suffix = sq_prefix + suffix
return suffix
def load_from_binary(tensorrt_llm_baichuan: BaichuanForCausalLM,
dir_path,
model_version,
mapping=Mapping(),
fp16=False,
multi_query_mode=False):
tensorrt_llm.logger.info('Loading weights from binary...')
tik = time.time()
quant_mode = getattr(tensorrt_llm_baichuan, 'quant_mode', QuantMode(0))
n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head = parse_bin_config(
Path(dir_path) / 'config.ini')
np_dtype = np.float16 if fp16 else np.float32
def fromfile(dir_path, name, shape=None, dtype=None):
dtype = np_dtype if dtype is None else dtype
p = dir_path + '/' + name
if Path(p).exists():
t = np.fromfile(p, dtype=dtype)
if shape is not None:
t = t.reshape(shape)
return t
return None
def set_smoothquant_scale_factors(module,
pre_scale_weight,
dir_path,
basename,
shape,
per_tok_dyn,
per_channel,
is_qkv=False,
rank=None):
suffix = "bin"
if per_channel:
if rank is not None:
suffix = f"{rank}." + suffix
suffix = "col." + suffix
col_shape = shape if (per_channel or is_qkv) else [1, 1]
if per_tok_dyn:
if pre_scale_weight is not None:
pre_scale_weight.value = np.array([1.0], dtype=np.float32)
if is_qkv and not per_channel:
t = fromfile(dir_path,
f"{basename}scale_w_quant_orig.{rank}.{suffix}",
col_shape, np.float32)
else:
t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}",
col_shape, np.float32)
module.per_channel_scale.value = t
else:
t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1],
np.float32)
pre_scale_weight.value = t
if is_qkv:
t = fromfile(dir_path,
f"{basename}scale_y_accum_quant.{rank}.{suffix}",
col_shape, np.float32)
else:
t = fromfile(dir_path,
f"{basename}scale_y_accum_quant.{suffix}",
col_shape, np.float32)
module.per_channel_scale.value = t
t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1],
np.float32)
module.act_scale.value = t
def set_smoother(module, dir_path, base_name, shape, rank):
suffix = f"{rank}.bin"
t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape,
np.float32)
module.smoother.value = t
# Determine the quantization mode.
quant_mode = getattr(tensorrt_llm_baichuan, "quant_mode", QuantMode(0))
if quant_mode.is_int8_weight_only():
plugin_weight_only_quant_type = torch.int8
elif quant_mode.is_int4_weight_only():
plugin_weight_only_quant_type = torch.quint4x2
# Do we use SmoothQuant?
use_smooth_quant = quant_mode.has_act_and_weight_quant()
# Do we use quantization per token?
quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
# Do we use quantization per channel?
quant_per_channel = quant_mode.has_per_channel_scaling()
# Do we use INT4/INT8 weight-only?
use_weight_only = quant_mode.is_weight_only()
# Int8 KV cache
use_int8_kv_cache = quant_mode.has_int8_kv_cache()
# Debug
suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel)
# The type of weights.
w_type = np_dtype if not use_smooth_quant else np.int8
if mapping.is_first_pp_rank():
tensorrt_llm_baichuan.vocab_embedding.weight.value = (fromfile(
dir_path, 'vocab_embedding.weight.bin', [vocab_size, n_embd]))
if mapping.is_last_pp_rank():
tensorrt_llm_baichuan.ln_f.weight.value = (fromfile(
dir_path, 'ln_f.weight.bin'))
# share input embedding
lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin',
[vocab_size, n_embd])
if model_version.startswith('v2'):
# baichuan v2 models use NormHead
tensorrt_llm.logger.info(
f'Normalizing lm_head.weight for {model_version}')
lm_head_weight = lm_head_weight / np.linalg.norm(
lm_head_weight, axis=1, keepdims=True)
if vocab_size % mapping.tp_size != 0:
# padding
vocab_size_padded = tensorrt_llm_baichuan.lm_head.out_features * mapping.tp_size
pad_width = vocab_size_padded - vocab_size
lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)),
'constant',
constant_values=0)
if mapping.is_last_pp_rank():
tensorrt_llm_baichuan.lm_head.weight.value = np.ascontiguousarray(
split(lm_head_weight, mapping.tp_size, mapping.tp_rank))
layers_range = list(
range(mapping.pp_rank * tensorrt_llm_baichuan.num_layers,
(mapping.pp_rank + 1) * tensorrt_llm_baichuan.num_layers, 1))
for i in layers_range:
n_groups = n_head // n_kv_head
c_attn_out_dim = (
3 * n_embd // mapping.tp_size) if not multi_query_mode else (
n_embd // mapping.tp_size +
(n_embd // n_head * n_groups) // mapping.tp_size * 2)
idx = i - mapping.pp_rank * tensorrt_llm_baichuan.num_layers
tensorrt_llm_baichuan.layers[idx].input_layernorm.weight.value = (
fromfile(dir_path,
'model.layers.' + str(i) + '.input_layernorm.weight.bin'))
t = fromfile(
dir_path, 'model.layers.' + str(i) +
'.attention.query_key_value.weight.' + suffix,
[n_embd, c_attn_out_dim], w_type)
if t is not None:
dst = tensorrt_llm_baichuan.layers[idx].attention.qkv.weight
if use_smooth_quant:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
set_smoothquant_scale_factors(
tensorrt_llm_baichuan.layers[idx].attention.qkv,
tensorrt_llm_baichuan.layers[idx].input_layernorm.
scale_to_int,
dir_path,
'model.layers.' + str(i) + '.attention.query_key_value.',
[1, c_attn_out_dim],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank,
is_qkv=True)
elif use_weight_only:
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[
i].attention.qkv.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
dst = tensorrt_llm_baichuan.layers[idx].attention.dense.weight
t = fromfile(
dir_path,
'model.layers.' + str(i) + '.attention.dense.weight.' + suffix,
[n_embd // mapping.tp_size, n_embd], w_type)
if use_smooth_quant:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
dense_scale = getattr(tensorrt_llm_baichuan.layers[idx].attention,
"quantization_scaling_factor", None)
set_smoothquant_scale_factors(
tensorrt_llm_baichuan.layers[idx].attention.dense, dense_scale,
dir_path, 'model.layers.' + str(i) + '.attention.dense.',
[1, n_embd], quant_per_token_dyn, quant_per_channel)
set_smoother(tensorrt_llm_baichuan.layers[idx].attention.dense,
dir_path,
'model.layers.' + str(i) + '.attention.dense',
[1, n_embd // mapping.tp_size], mapping.tp_rank)
elif use_weight_only:
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[
i].attention.dense.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
dst = tensorrt_llm_baichuan.layers[idx].post_layernorm.weight
dst.value = fromfile(
dir_path, 'model.layers.' + str(i) + '.post_layernorm.weight.bin')
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.fc.weight.' + suffix,
[n_embd, inter_size // mapping.tp_size], w_type)
if use_smooth_quant:
tensorrt_llm_baichuan.layers[
idx].mlp.fc.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
set_smoothquant_scale_factors(
tensorrt_llm_baichuan.layers[idx].mlp.fc,
tensorrt_llm_baichuan.layers[idx].post_layernorm.scale_to_int,
dir_path,
'model.layers.' + str(i) + '.mlp.fc.',
[1, inter_size // mapping.tp_size],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank)
elif use_weight_only:
dst = tensorrt_llm_baichuan.layers[i].mlp.fc.weight
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[i].mlp.fc.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_baichuan.layers[
idx].mlp.fc.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.gate.weight.' + suffix,
[n_embd, inter_size // mapping.tp_size], w_type)
if use_smooth_quant:
tensorrt_llm_baichuan.layers[
idx].mlp.gate.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
set_smoothquant_scale_factors(
tensorrt_llm_baichuan.layers[idx].mlp.gate,
tensorrt_llm_baichuan.layers[idx].post_layernorm.scale_to_int,
dir_path,
'model.layers.' + str(i) + '.mlp.gate.',
[1, inter_size // mapping.tp_size],
quant_per_token_dyn,
quant_per_channel,
rank=mapping.tp_rank)
elif use_weight_only:
dst = tensorrt_llm_baichuan.layers[i].mlp.gate.weight
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[i].mlp.gate.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_baichuan.layers[
idx].mlp.gate.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
t = fromfile(dir_path,
'model.layers.' + str(i) + '.mlp.proj.weight.' + suffix,
[inter_size // mapping.tp_size, n_embd], w_type)
if use_smooth_quant:
tensorrt_llm_baichuan.layers[
idx].mlp.proj.weight.value = np.ascontiguousarray(
np.transpose(t, [1, 0]))
proj_scale = getattr(tensorrt_llm_baichuan.layers[idx].mlp,
"quantization_scaling_factor", None)
set_smoothquant_scale_factors(
tensorrt_llm_baichuan.layers[idx].mlp.proj, proj_scale,
dir_path, 'model.layers.' + str(i) + '.mlp.proj.', [1, n_embd],
quant_per_token_dyn, quant_per_channel)
set_smoother(tensorrt_llm_baichuan.layers[idx].mlp.proj, dir_path,
'model.layers.' + str(i) + '.mlp.proj',
[1, inter_size // mapping.tp_size], mapping.tp_rank)
elif use_weight_only:
dst = tensorrt_llm_baichuan.layers[i].mlp.proj.weight
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
torch.tensor(t), plugin_weight_only_quant_type)
dst.value = processed_torch_weights.numpy()
scales = tensorrt_llm_baichuan.layers[i].mlp.proj.per_channel_scale
scales.value = torch_weight_scales.numpy()
else:
tensorrt_llm_baichuan.layers[idx].mlp.proj.weight.value = (
np.ascontiguousarray(np.transpose(t, [1, 0])))
if use_int8_kv_cache:
t = fromfile(
dir_path, 'model.layers.' + str(i) +
'.attention.query_key_value.scale_y_quant_orig.bin', [1],
np.float32)
tensorrt_llm_baichuan.layers[
idx].attention.kv_orig_quant_scale.value = 1.0 / t
tensorrt_llm_baichuan.layers[
idx].attention.kv_quant_orig_scale.value = t
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
def load_from_awq_baichuan(tensorrt_llm_baichuan: BaichuanForCausalLM,
quant_ckpt_path,
model_version,
quantize_lm_head,
mapping=Mapping(),
dtype="float16",
bin_model_dir=None):
tensorrt_llm.logger.info(
'Loading weights from groupwise AWQ Baichuan checkpoint...')
tik = time.time()
if quant_ckpt_path.endswith(".npz"):
awq_baichuan = np.load(quant_ckpt_path)
awq_prefix = "_np:"
awq_suffix_list = [
":weight",
":weights_scaling_factor",
":prequant_scaling_factor",
]
awq_key_list = [
"vocab_embedding:weight", # vocab_embedding
"lm_head", # lm_head
"final_layernorm:weight", # ln_f
"attention:qkv:", # attention.qkv
"", # qkv suffix
"attention:dense", # attention.dense
"mlp:gate", # mlp.gate
"mlp:proj", # mlp.proj
"mlp:fc", # mlp.fc
"input_layernorm:weight", # input_layernorm
"post_layernorm:weight", # post_layernorm
]
split_sym = ":"
def load(key):
v = torch.from_numpy(awq_baichuan[awq_prefix + key])
if "weights_scaling_factor" in key:
v *= 7 # For AMMO *.npz checkpoints
return v
group_size = load("layers:0:attention:dense:weight").numel() // load(
"layers:0:attention:dense:weights_scaling_factor").numel()
else:
assert False, "Unsupported AWQ quantized checkpoint format"
quant_mode = getattr(tensorrt_llm_baichuan, 'quant_mode', QuantMode(0))
# Int8 KV cache
use_int8_kv_cache = quant_mode.has_int8_kv_cache()
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
torch_dtype = str_dtype_to_torch(dtype)
def fromfile(dir_path, name, shape=None, dtype=None):
p = dir_path + '/' + name
if Path(p).exists():
t = np.fromfile(p, dtype=dtype)
if shape is not None:
t = t.reshape(shape)
return t
return None
def torch_split(v, dim):
if v.shape[dim] % mapping.tp_size != 0:
tensorrt_llm.logger.error(
"Current weight shape is invalid for mapping.tp_size=" +
str(mapping.tp_size))
assert False, "Invalid TP size"
return v.split(v.shape[dim] // mapping.tp_size,
dim=dim)[mapping.tp_rank]
def AWQ_quantize_pack_preprocess(weight, scale):
weight /= scale.repeat_interleave(group_size, dim=0)
qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7)
int4_weight = preprocessor(packer(qweight_int8.cpu()), torch.quint4x2)
return int4_weight.view(torch.int8).cpu().numpy()
def process_and_assign_weight(mOp, v, tp_dim=0):
weight = v[0].T.contiguous()
[k, n] = weight.shape
weight = torch_split(weight, tp_dim)
amax = v[1].reshape((n, k // group_size)).T.contiguous()
amax = torch_split(amax, tp_dim)
pre_quant_scale = v[2].reshape((1, k))
if tp_dim == 0:
pre_quant_scale = torch_split(pre_quant_scale, 1)
scale = amax / 8.0
mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale)
mOp.scale.value = scale.to(torch_dtype).cpu().numpy()
mOp.pre_quant_scale.value = pre_quant_scale.to(
torch_dtype).cpu().numpy()
def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale):
# deSmooth and reSmooth
[k, n] = weight.shape
if quant_ckpt_path.endswith("pt"):
# NPZ files are already re-smoothed
weight *= pre_quant_scale.repeat((n, 1)).transpose(1,
0).contiguous()
weight /= avg_pre_quant_scale.repeat(
(n, 1)).transpose(1, 0).contiguous()
# Get scale
weight_t = weight.T.contiguous()
weight_t = weight_t.reshape(n, k // group_size, group_size)
weight_t = torch.abs(weight_t.reshape(-1, group_size))
amax, idx = weight_t.max(1)
amax = amax.reshape(n, k // group_size).T.contiguous()
scale = amax / 8
return weight, scale
def process_and_assign_qkv_weight(prefix, mOp):
q_weight = load(prefix + "q" + awq_key_list[4] +
awq_suffix_list[0]).T.contiguous()
k_weight = load(prefix + "k" + awq_key_list[4] +
awq_suffix_list[0]).T.contiguous()
v_weight = load(prefix + "v" + awq_key_list[4] +
awq_suffix_list[0]).T.contiguous()
dim_k = q_weight.shape[0]
q_weight = torch_split(q_weight, 1)
k_weight = torch_split(k_weight, 1)
v_weight = torch_split(v_weight, 1)
q_pre_quant_scale = load(prefix + "q" + awq_key_list[4] +
awq_suffix_list[2]).reshape((1, dim_k))
k_pre_quant_scale = load(prefix + "k" + awq_key_list[4] +
awq_suffix_list[2]).reshape((1, dim_k))
v_pre_quant_scale = load(prefix + "v" + awq_key_list[4] +
awq_suffix_list[2]).reshape((1, dim_k))
qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale +
v_pre_quant_scale) / 3.0
q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale,
qkv_pre_quant_scale)
k_weight, k_scale = reSmooth_and_get_scale(k_weight, k_pre_quant_scale,
qkv_pre_quant_scale)
v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale,
qkv_pre_quant_scale)
qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1)
mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(
torch_dtype).cpu().numpy()
mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale)
mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy()
# Load weights from AWQ checkpoint into TRT-LLM module
# 1. vocab_embedding
v = load(awq_key_list[0])
# TRT-LLM requires vocab_size to be multiple of 64 for successful GEMM
if v.shape[0] % 64 != 0:
v = torch.nn.functional.pad(v, [0, 0, 0, 64 - v.shape[0] % 64])
if mapping.is_first_pp_rank():
tensorrt_llm_baichuan.vocab_embedding.weight.value = v.to(
torch_dtype).cpu().numpy()
# 2. lm_head
if not quantize_lm_head:
assert (
awq_key_list[1] + awq_suffix_list[1]
) not in awq_baichuan, "lm_head is quantized in checkpoint, please re-generate it"
original_v = load(awq_key_list[1] + awq_suffix_list[0])
if model_version.startswith('v2'):
# baichuan v2 models use NormHead which is not quantized
tensorrt_llm.logger.info(
f'Normalizing lm_head.weight for {model_version}')
v = torch_split(torch.nn.functional.normalize(original_v), 0)
else:
v = torch_split(original_v, 0)
if mapping.is_last_pp_rank():
tensorrt_llm_baichuan.lm_head.weight.value = v.to(
torch_dtype).cpu().numpy()
else:
assert not model_version.startswith(
'v2'), f"Do not support quantizing lm_head for {model_version}"
v = [load(awq_key_list[1] + suf) for suf in awq_suffix_list]
if v[0].shape[0] % 64 != 0:
v[0] = torch.nn.functional.pad(v[0],
[0, 0, 0, 64 - v[0].shape[0] % 64])
scale_align = 64 * (v[0].shape[1] // group_size)
v[1] = v[1].reshape(-1)
v[1] = torch.nn.functional.pad(
v[1], [0, scale_align - v[1].shape[0] % scale_align], value=1)
if mapping.is_last_pp_rank():
process_and_assign_weight(tensorrt_llm_baichuan.lm_head, v, 1)
# 3. ln_f
v = load(awq_key_list[2])
if mapping.is_last_pp_rank():
tensorrt_llm_baichuan.ln_f.weight.value = v.to(
torch_dtype).cpu().numpy()
# 4. Weights inside each layer
num_hidden_layers = tensorrt_llm_baichuan.num_layers
layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
layers_range = list(
range(mapping.pp_rank * layers_per_pipeline_stage,
(mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))
for l in layers_range:
layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage
prefix = "layers" + split_sym + str(layer_idx) + split_sym
tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
layer = tensorrt_llm_baichuan.layers[layer_idx]
# 4.1 attention.qkv
process_and_assign_qkv_weight(prefix + awq_key_list[3],
layer.attention.qkv)
# 4.2 attention.dense
v = [load(prefix + awq_key_list[5] + suf) for suf in awq_suffix_list]
process_and_assign_weight(layer.attention.dense, v, 0)
# 4.3 mlp.gate
v = [load(prefix + awq_key_list[6] + suf) for suf in awq_suffix_list]
process_and_assign_weight(layer.mlp.gate, v, 1)
# 4.4 mlp.proj
v = [load(prefix + awq_key_list[7] + suf) for suf in awq_suffix_list]
process_and_assign_weight(layer.mlp.proj, v, 0)
# 4.5 mlp.fc
v = [load(prefix + awq_key_list[8] + suf) for suf in awq_suffix_list]
process_and_assign_weight(layer.mlp.fc, v, 1)
# 4.6 input_layernorm
v = load(prefix + awq_key_list[9])
layer.input_layernorm.weight.value = v.to(torch_dtype).cpu().numpy()
# 4.7 post_layernorm
v = load(prefix + awq_key_list[10])
layer.post_layernorm.weight.value = v.to(torch_dtype).cpu().numpy()
# 4.8 attention.kv_quant_orig_scale / kv_quant_orig_scale
if use_int8_kv_cache:
assert bin_model_dir, "You must pass --bin_model_dir to tell TRT-LLM where to look for scales of INT8 kv cache."
t = fromfile(
bin_model_dir, 'model.layers.' + str(layer_idx) +
'.attention.query_key_value.scale_y_quant_orig.bin', [1],
np.float32)
assert t is not None, f"{bin_model_dir} does not contain model.layers.{layer_idx}.attention.query_key_value.scale_y_quant_orig.bin"
layer.attention.kv_orig_quant_scale.value = 1.0 / t
layer.attention.kv_quant_orig_scale.value = t
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
tensorrt_llm.logger.info(f'Weights loaded. Total time: {t}')
def load_from_gptq_baichuan(tensorrt_llm_baichuan,
quant_ckpt_path,
model_version,
mapping=Mapping(),
dtype="float16",
bin_model_dir=None):
tensorrt_llm.logger.info(
'Loading weights from groupwise GPTQ Baichuan safetensors...')
tik = time.time()
gptq_baichuan = safe_open(quant_ckpt_path, framework="pt", device=0)
gptq_prefix = "model."
gptq_suffix_list = [".qweight", ".qzeros", ".scales"]
gptq_key_list = [
"embed_tokens.weight", # vocab_embedding
"lm_head.weight", # lm_head
"norm.weight", # ln_f
"self_attn.W_pack", # attention.qkv
"_proj", #
"self_attn.o_proj", # attention.dense
"mlp.up_proj", # mlp.gate
"mlp.down_proj", # mlp.proj
"mlp.gate_proj", # mlp.fc
"input_layernorm.weight", # input_layernorm
"post_attention_layernorm.weight", # post_layernorm
]
split_sym = "."
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
torch_dtype = str_dtype_to_torch(dtype)
def load(key, no_prefix=0):
if no_prefix:
return gptq_baichuan.get_tensor(key)
else:
return gptq_baichuan.get_tensor(gptq_prefix + key)
def torch_split(v, dim):
if v.shape[dim] % mapping.tp_size != 0:
tensorrt_llm.logger.error(
"Current weight shape is invalid for mapping.tp_size=" +
str(mapping.tp_size))
assert False, "Invalid TP size"
return v.split(v.shape[dim] // mapping.tp_size,
dim=dim)[mapping.tp_rank]
def unpack_int32_into_int8(w_packed):
# Unpack inputs packed in int32/float32 into uint4 and store them in int8 format
w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
w_packed_int4x2.shape[1] * 2,
dtype=torch.int8)
w_unpacked[:, ::2] = w_packed_int4x2 % 16
w_unpacked[:, 1::2] = w_packed_int4x2 // 16
return w_unpacked.contiguous()
def process_and_assign_weight(mOp, v, tp_dim=-1):
if tp_dim == -1:
qweight_int32, qzeros_int32, scales_fp16 = [
item.cpu() for item in v
]
else:
qweight_int32, qzeros_int32, scales_fp16 = [
torch_split(item, tp_dim).cpu() for item in v
]
USE_UINT4_INPUT = 1 # Set to true if checkpoint store UINT4 weights
USE_GPTQ_FOR_LLAMA = 1 # GPTQ-for-LLaMA added 1 to zeros
qweight_unpacked_int8 = unpack_int32_into_int8(
qweight_int32.T).T.contiguous() - 8
qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
torch.quint4x2).view(torch.int8)
# zeros = zeros * scales
qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32)
if not USE_UINT4_INPUT:
# Correcting UINT4 values back to INT4 order
mask_negative = qzeros_unpacked_int32[qzeros_unpacked_int32 < 0]
mask_positive = qzeros_unpacked_int32[qzeros_unpacked_int32 >= 0]
qzeros_unpacked_int32 = qzeros_unpacked_int32 + 16 * mask_negative - 16 * mask_positive
zeros_x_scales_fp16 = (-qzeros_unpacked_int32 + 8 * USE_UINT4_INPUT -
USE_GPTQ_FOR_LLAMA) * scales_fp16
zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
# return processed interleaved weight, original scales and zeros * scales
mOp.qweight.value = qweight_interleaved.cpu().numpy()
mOp.scale.value = scales_fp16.cpu().numpy()
mOp.zero.value = zeros_x_scales_fp16.cpu().numpy()
# Load weights from GPTQ checkpoint into TRT-LLM module
# 1. vocab_embedding
v = load(gptq_key_list[0])
if mapping.is_first_pp_rank():
tensorrt_llm_baichuan.vocab_embedding.weight.value = v.to(
torch_dtype).cpu().numpy()
# 2. lm_head
original_v = load(gptq_key_list[1], "no_prefix")
if model_version.startswith('v2'):
# baichuan v2 models use NormHead
tensorrt_llm.logger.info(
f'Normalizing lm_head.weight for {model_version}')
v = torch_split(torch.nn.functional.normalize(original_v), 0)
else:
v = torch_split(original_v, 0)
if mapping.is_last_pp_rank():
tensorrt_llm_baichuan.lm_head.weight.value = v.to(
torch_dtype).cpu().numpy()
# 3. ln_f
v = load(gptq_key_list[2])
if mapping.is_last_pp_rank():
tensorrt_llm_baichuan.ln_f.weight.value = v.to(
torch_dtype).cpu().numpy()
# 4. Weights inside each layer
num_hidden_layers = tensorrt_llm_baichuan.num_layers
layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
layers_range = list(
range(mapping.pp_rank * layers_per_pipeline_stage,
(mapping.pp_rank + 1) * layers_per_pipeline_stage, 1))
for l in layers_range:
layer_idx = l - mapping.pp_rank * layers_per_pipeline_stage
prefix = "layers" + split_sym + str(layer_idx) + split_sym
tensorrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
layer = tensorrt_llm_baichuan.layers[layer_idx]
# 4.1 attention.qkv
qkv_weight_list = []
for suf in gptq_suffix_list:
qkv_list = []
comp_part = load(prefix + gptq_key_list[3] + suf)
qkv = torch.chunk(comp_part, 3, 1)
for i in range(3):
comp_part = qkv[i]
comp_part = torch_split(comp_part, 1)
qkv_list.append(comp_part)
qkv_weight_list.append(torch.cat(qkv_list, dim=1))
process_and_assign_weight(layer.attention.qkv, qkv_weight_list)
# 4.2 attention.dense
v = [load(prefix + gptq_key_list[5] + suf) for suf in gptq_suffix_list]
process_and_assign_weight(layer.attention.dense, v, 0)
# 4.3 mlp.gate
v = [load(prefix + gptq_key_list[6] + suf) for suf in gptq_suffix_list]
process_and_assign_weight(layer.mlp.gate, v, 1)
# 4.4 mlp.proj
v = [load(prefix + gptq_key_list[7] + suf) for suf in gptq_suffix_list]
process_and_assign_weight(layer.mlp.proj, v, 0)
# 4.5 mlp.fc