From dd267fca729621cec18b6199b31671ed9513a82c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=A4=80=EC=9E=AC?= <55151385+junejae@users.noreply.github.com>
Date: Thu, 24 Oct 2024 22:10:59 +0900
Subject: [PATCH] Add T5 GGUF loading support (#33389)

* add: GGUFT5Converter

* add: tensormapping for t5

* add: test code for t5

* fix: Remove whitespace from blank line

* add: t5 fp16 tests

* fix: whitespace formatting

* fix: minor formatting

* fix: testing every weights

---
 docs/source/en/gguf.md                  |   1 +
 src/transformers/integrations/ggml.py   | 128 +++++++++++++++++-
 .../modeling_gguf_pytorch_utils.py      |  17 ++-
 .../models/t5/tokenization_t5_fast.py   |   2 +-
 tests/quantization/ggml/test_ggml.py    |  56 +++++++-
 5 files changed, 197 insertions(+), 7 deletions(-)

diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 01583cedbf4110..20531b990bc341 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -85,6 +85,7 @@ For now the supported model architectures are the architectures that have been v
 - StableLM
 - GPT2
 - Starcoder2
+- T5
 
 ## Example usage
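With this patch, T5 joins the supported list above and follows the same loading pattern as the other GGUF architectures. A minimal sketch of the new usage (the model repo and GGUF filename are the ones used by the tests added at the end of this patch):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "repetitio/flan-t5-small"
gguf_file = "flan-t5-small-q8_0.gguf"

# Both the tokenizer and the model can be built directly from a GGUF file.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```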
"encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", + "enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm", + "enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", + "enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", + "enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo", + "enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm", + "enc.output_norm": "encoder.final_layer_norm", + "output.weight": "lm_head.weight", + }, + "t5encoder": { + "token_embd": "shared", + "enc.blk.{bid}.attn_q": "encoder.block.{bid}.layer.0.SelfAttention.q", + "enc.blk.{bid}.attn_k": "encoder.block.{bid}.layer.0.SelfAttention.k", + "enc.blk.{bid}.attn_v": "encoder.block.{bid}.layer.0.SelfAttention.v", + "enc.blk.{bid}.attn_o": "encoder.block.{bid}.layer.0.SelfAttention.o", + "enc.blk.{bid}.attn_rel_b": "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", + "enc.blk.{bid}.attn_norm": "encoder.block.{bid}.layer.0.layer_norm", + "enc.blk.{bid}.ffn_gate": "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", + "enc.blk.{bid}.ffn_up": "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", + "enc.blk.{bid}.ffn_down": "encoder.block.{bid}.layer.1.DenseReluDense.wo", + "enc.blk.{bid}.ffn_norm": "encoder.block.{bid}.layer.1.layer_norm", + "enc.output_norm": "encoder.final_layer_norm", + }, "stablelm": { "token_embd": "model.embed_tokens", "blk": "model.layers", @@ -287,6 +332,19 @@ "vocab_size": "vocab_size", "attention.layer_norm_epsilon": "layer_norm_epsilon", }, + "t5": { + "context_length": "n_positions", + "block_count": "num_layers", + "feed_forward_length": "d_ff", + "embedding_length": "d_model", + "attention.key_length": "d_kv", + "attention.head_count": "num_heads", + "attention.head_count_kv": "num_key_value_heads", + "attention.layer_norm_epsilon": "layer_norm_epsilon", + "attention.relative_buckets_count": "relative_attention_num_buckets", + "decoder_start_token_id": "decoder_start_token_id", + "vocab_size": "vocab_size", + }, "stablelm": { "context_length": "max_position_embeddings", "block_count": "num_hidden_layers", @@ -636,6 +694,69 @@ def converted(self) -> Tokenizer: return tokenizer +class GGUFT5Converter(T5Converter): + def __init__(self, tokenizer_dict): + # set dummy data to avoid unnecessary merges calculation + tokenizer_dict["merges"] = ["dummy text"] + + self.proto = GGUFTokenizerSkeleton(tokenizer_dict) + self.token2id = {k: v for v, k in enumerate(self.proto.tokens)} + self.original_tokenizer = self.proto + self.additional_kwargs = {} + + def vocab(self, proto): + return list(zip(proto.tokens, proto.scores)) + + def normalizer(self, proto): + if getattr(self.original_tokenizer, "legacy", True): + sequence = [] + if getattr(self.original_tokenizer, "add_prefix_space", True): + sequence += [normalizers.Prepend(prepend="▁")] + sequence += [normalizers.Replace(pattern=" ", content="▁")] + return normalizers.Sequence(sequence) + return None # non-legacy, no normalizer + + def post_processor(self): + return processors.TemplateProcessing( + single=["$A", ""], + pair=["$A", "", "$B", ""], + special_tokens=[ + ("", self.token2id[""]), + ], + ) + + def converted(self) -> Tokenizer: + vocab_scores = self.vocab(self.proto) + tokenizer = Tokenizer( + Unigram( + vocab_scores, + unk_id=self.proto.unk_token_id, + byte_fallback=False, + ) + ) + + # Tokenizer assemble + normalizer = self.normalizer(self.proto) + if normalizer is not None: + tokenizer.normalizer = normalizer + + replacement = 
"▁" + add_prefix_space = True + if hasattr(self.original_tokenizer, "add_prefix_space"): + add_prefix_space = self.original_tokenizer.add_prefix_space + + pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space) + if pre_tokenizer is not None: + tokenizer.pre_tokenizer = pre_tokenizer + + tokenizer.decoder = self.decoder(replacement, add_prefix_space) + post_processor = self.post_processor() + if post_processor: + tokenizer.post_processor = post_processor + + return tokenizer + + GGUF_TO_FAST_CONVERTERS = { "llama": GGUFLlamaConverter, "qwen2": GGUFQwen2Converter, @@ -646,6 +767,7 @@ def converted(self) -> Tokenizer: "stablelm": GGUFGPTConverter, "gpt2": GGUFGPTConverter, "starcoder2": GGUFGPTConverter, + "t5": GGUFT5Converter, } diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index b1d7b896085476..171b2f4d15b122 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -94,6 +94,12 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): # to add this patch to ensure things work correctly on our side. if "llama" in architecture and "mistral" in model_name: updated_architecture = "mistral" + # FIXME: Currnetly this implementation is only for flan-t5 architecture. + # It needs to be developed for supporting legacy t5. + elif "t5" in architecture or "t5encoder" in architecture: + parsed_parameters["config"]["tie_word_embeddings"] = False + parsed_parameters["config"]["is_gated_act"] = True + updated_architecture = "t5" else: updated_architecture = architecture @@ -191,6 +197,13 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): else: weights = reverse_reshape_bias(weights, num_heads, n_embed) + bid = None + if architecture in ("t5", "t5encoder"): + for chunk in name.split("."): + if chunk.isdigit(): + bid = int(chunk) + break + if architecture == "gpt2": if ( "attn_qkv.weight" in name @@ -209,8 +222,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): continue for tensor_name in tensor_key_mapping: - if tensor_name in name: - name = name.replace(tensor_name, tensor_key_mapping[tensor_name]) + if tensor_name.format(bid=bid) in name: + name = name.replace(tensor_name.format(bid=bid), tensor_key_mapping[tensor_name].format(bid=bid)) # Use copy to avoid errors with numpy and pytorch parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights)) diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py index 0a92803f165846..4c3fa950559637 100644 --- a/src/transformers/models/t5/tokenization_t5_fast.py +++ b/src/transformers/models/t5/tokenization_t5_fast.py @@ -117,7 +117,7 @@ def __init__( kwargs["from_slow"] = True super().__init__( - vocab_file, + vocab_file=vocab_file, tokenizer_file=tokenizer_file, eos_token=eos_token, unk_token=unk_token, diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index 6e47d46f07c47e..ddc791e96a6489 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -15,7 +15,7 @@ import tempfile import unittest -from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer +from transformers import AddedToken, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer from transformers.testing_utils import ( require_gguf, require_torch_gpu, @@ -48,6 +48,8 @@ class GgufIntegrationTests(unittest.TestCase): falcon7b_model_id = 
"xaviviro/falcon-7b-quantized-gguf" falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf" original_flacon7b_model_id = "tiiuae/falcon-7b" + t5_model_id = "repetitio/flan-t5-small" + original_t5_model_id = "google/flan-t5-small" stablelm_model_id = "afrideva/stablelm-3b-4e1t-GGUF" stablelm2_model_id = "afrideva/stablelm-2-1_6b-GGUF" original_stablelm2_model_id = "stabilityai/stablelm-2-1_6b" @@ -92,6 +94,8 @@ class GgufIntegrationTests(unittest.TestCase): q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf" fp16_falcon7b_model_id = "falcon-7b-fp16.gguf" q2_k_falcon40b_model_id = "tiiuae-falcon-40b-Q2_K.gguf" + fp16_t5_model_id = "flan-t5-small-f16.gguf" + q8_0_t5_model_id = "flan-t5-small-q8_0.gguf" fp16_qwen2moe_model_id = "Qwen1.5-MoE-A2.7B.gguf" fp16_gpt2_model_id = "gpt2.f16.gguf" q8_gpt2_model_id = "gpt2.Q8_0.gguf" @@ -487,6 +491,56 @@ def test_bloom_weights_conversion_fp16(self): self.assertTrue(quantized_param.shape == original_param.shape) torch.testing.assert_close(quantized_param, original_param) + def test_t5_f16(self): + tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.fp16_t5_model_id) + model = AutoModelForSeq2SeqLM.from_pretrained( + self.t5_model_id, gguf_file=self.fp16_t5_model_id, device_map="auto", torch_dtype=torch.float16 + ) + + T5_EXAMPLE_TEXT = "translate English to German: How old are you?" + + text = tokenizer(T5_EXAMPLE_TEXT, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Wie ich er?" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_t5_q8_0(self): + tokenizer = AutoTokenizer.from_pretrained(self.t5_model_id, gguf_file=self.q8_0_t5_model_id) + model = AutoModelForSeq2SeqLM.from_pretrained( + self.t5_model_id, gguf_file=self.q8_0_t5_model_id, device_map="auto", torch_dtype=torch.float16 + ) + + T5_EXAMPLE_TEXT = "translate English to German: How old are you?" + + text = tokenizer(T5_EXAMPLE_TEXT, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Wie ich er?" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_t5_weights_conversion_fp16(self): + quantized_model = AutoModelForSeq2SeqLM.from_pretrained( + self.t5_model_id, + gguf_file=self.fp16_t5_model_id, + device_map="auto", + torch_dtype=torch.float16, + ) + original_model = AutoModelForSeq2SeqLM.from_pretrained( + self.original_t5_model_id, + device_map="auto", + torch_dtype=torch.float16, + ) + + quantized_state_dict = quantized_model.state_dict() + original_state_dict = original_model.state_dict() + + for (quantized_name, quantized_param), (original_name, original_param) in zip( + quantized_state_dict.items(), original_state_dict.items() + ): + self.assertTrue(quantized_param.shape == original_param.shape) + torch.testing.assert_close(quantized_param, original_param, rtol=5e-04, atol=5e-04) + def test_gpt2_q8(self): tokenizer = AutoTokenizer.from_pretrained(self.gpt2_model_id, gguf_file=self.q8_gpt2_model_id) model = AutoModelForCausalLM.from_pretrained(