From 3e2c9024531afc38043e3d58029d9dc517fa9a09 Mon Sep 17 00:00:00 2001 From: irfan sharif Date: Tue, 5 Dec 2023 21:00:54 +0000 Subject: [PATCH] Scaffolding for `gpt-fast` example --- 06_gpu_and_ml/gpt-fast/GPTQ.py | 512 +++++++++++++ 06_gpu_and_ml/gpt-fast/LICENSE | 11 + 06_gpu_and_ml/gpt-fast/README.md | 77 ++ 06_gpu_and_ml/gpt-fast/__init__.py | 0 06_gpu_and_ml/gpt-fast/generate.py | 568 ++++++++++++++ 06_gpu_and_ml/gpt-fast/modal.py | 519 +++++++++++++ 06_gpu_and_ml/gpt-fast/model.py | 317 ++++++++ 06_gpu_and_ml/gpt-fast/quantize.py | 824 +++++++++++++++++++++ 06_gpu_and_ml/gpt-fast/tp.py | 177 +++++ 06_gpu_and_ml/llm-frontend/index.html | 16 +- 06_gpu_and_ml/text_generation_inference.py | 4 + 10_integrations/streamlit/app.py | 1 - 12 files changed, 3021 insertions(+), 5 deletions(-) create mode 100644 06_gpu_and_ml/gpt-fast/GPTQ.py create mode 100644 06_gpu_and_ml/gpt-fast/LICENSE create mode 100644 06_gpu_and_ml/gpt-fast/README.md create mode 100644 06_gpu_and_ml/gpt-fast/__init__.py create mode 100644 06_gpu_and_ml/gpt-fast/generate.py create mode 100644 06_gpu_and_ml/gpt-fast/modal.py create mode 100644 06_gpu_and_ml/gpt-fast/model.py create mode 100644 06_gpu_and_ml/gpt-fast/quantize.py create mode 100644 06_gpu_and_ml/gpt-fast/tp.py diff --git a/06_gpu_and_ml/gpt-fast/GPTQ.py b/06_gpu_and_ml/gpt-fast/GPTQ.py new file mode 100644 index 000000000..4b41273ca --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/GPTQ.py @@ -0,0 +1,512 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os +import sys + +import torch + +lm_evaluation_harness_path = "/".join( + os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"] +) +sys.path.insert(0, lm_evaluation_harness_path) + +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F +from eval import setup_cache_padded_seq_input_pos_max_seq_length_for_prefill +from generate import encode_tokens +from torch.utils._pytree import tree_flatten, tree_unflatten + +aten = torch.ops.aten + +try: + import lm_eval + + class InputRecorder(lm_eval.base.BaseLM): + """ + This is a fake evaluation wrapper that just records the inputs + so that they can be used in calibration. + + If pad_calibration_inputs is enabled, the input recorder will take + each input and pad/truncate it down to the calibration_seq_length. + It will also edit the model embeddings to be zero for the 0 token used + in padding and avoid any inputs with the 0 token. + + If not, it will only truncate inputs to the desired length. + """ + + def __init__( + self, + model, + tokenizer, + calibration_seq_length, + pad_calibration_inputs=False, + ): + super().__init__() + self._model = model + self._tokenizer = tokenizer + self._device = torch.device("cpu") + self.vocab_size = model.config.vocab_size + self.calibration_seq_length = calibration_seq_length + self.pad_calibration_inputs = pad_calibration_inputs + self.inputs = None + + if self.pad_calibration_inputs: + # This is needed for the pad_calibration_inputs option + # to work properly, the 0 token's embeddings are set to 0 so that + # the padded inputs will not affect the model numerics. This token isn't used + # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs + # where it appears + try: + if isinstance(self._model.transformer.wte, nn.Embedding): + self.mod.transformer.wte.weight.data[0, :] *= 0 + except Exception: + print( + "Did not find embeddings in model.transformer.wte, disabling padding" + ) + self.pad_calibration_inputs = False + + @property + def eot_token_id(self): + return self._tokenizer.eos_id() + + @property + def max_length(self): + return self.calibration_seq_length + + @property + def max_gen_toks(self): + return 50 + + @property + def batch_size(self): + return 1 + + @property + def device(self): + return self._device + + def tok_encode(self, string: str): + encoded = encode_tokens( + self._tokenizer, + string, + bos=True, + eos=False, + device=self._device, + ) + # encoded is a pytorch tensor, but some internal logic in the + # eval harness expects it to be a list instead + # TODO: verify this for multi-batch as well + encoded = encoded.tolist() + return encoded + + def tok_decode(self, tokens): + decoded = self._tokenizer.decode(tokens) + return decoded + + def add_input(self, args): + if self.inputs is None: + self.inputs = [MultiInput([arg]) for arg in args] + else: + self.inputs = [ + multi.add_input(arg) + for (multi, arg) in zip(self.inputs, args) + ] + + def get_recorded_inputs(self): + return self.inputs + + def _model_call(self, inps): + inps = inps.squeeze(0) + T = len(inps) + if ( + # can't use inputs that are too short when padding disabled + ( + T < self.calibration_seq_length + and not self.pad_calibration_inputs + ) + or + # can't use inputs that actually use token we use for padding + (self.pad_calibration_inputs and 0 in inps) + ): + # give random output + return torch.randn( + (1, T, self.vocab_size), + dtype=torch.bfloat16, + device=self._device, + ) + + # pad or truncate to the right size + if T >= self.calibration_seq_length: + inps = inps[: self.calibration_seq_length] + else: + inps = F.pad(inps, (0, self.calibration_seq_length - T)) + + max_new_tokens = 1 + ( + seq, + input_pos, + max_seq_length, + ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( + self._model, inps, max_new_tokens, self.max_length + ) + x = seq.index_select(0, input_pos).view(1, -1) + self.add_input((x, input_pos)) + + # output `something` with correct shape to keep eval going + return torch.randn( + (1, T, self.vocab_size), + dtype=torch.bfloat16, + device=self._device, + ) + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception("unimplemented") + +except ImportError: + pass + + +class MultiInput: + def __init__(self, inputs): + self.values = list(inputs) + + def add_input(self, input): + self.values.append(input) + return self + + def __getitem__(self, slice): + return MultiInput(self.values[slice]) + + def cuda(self): + self.values = [ + val.cuda() if isinstance(val, torch.Tensor) else val + for val in self.values + ] + + +class GenericGPTQRunner(fx.Interpreter): + """ + This is a generic GPTQ runner that takes an existing model and applies GPTQ. + It uses torch._dynamo.export to obtain a graph of the model and then hooks + into function calls and when it detects a linear, it applies GPTQ to the weight + given the calibration of inputs passed in at initialization. It puts the results + into the state_dict so that the quantized model weights/qparams can be loaded + directly into the model. + + This class is expected to work in concert with a GPTQSimpleQuantizer + class to define the specific type of quantization being done. + """ + + def __init__( + self, + model, + inputs: MultiInput, + blocksize=128, + percdamp=0.01, + groupsize=128, + ): + self.id_to_name = { + id(value): name + for name, value in dict(model.named_parameters()).items() + } + + # trace model for one input + one_input = [multi.values[0] for multi in inputs] + exported_model = torch._dynamo.export( + model, aten_graph=True, pre_dispatch=True, tracing_mode="fake" + )(*one_input) + super().__init__(exported_model.graph_module) + self.new_state_dict = model.state_dict() + self.blocksize = blocksize + self.percdamp = percdamp + self.groupsize = groupsize + self.inputs = inputs + self.gptq_done = False + self.debug = False + + def configure_quantization_mode( + self, + get_qparams_func, + quantize_func, + dequantize_func, + combine_qparams_list_func, + make_names_and_values_dict_func, + skip_layer_func, + ): + # these functions need to already be curried with all inputs other than weight, qparams + self.get_qparams_func = ( + get_qparams_func # accepts [2d weight tensor], outputs qparams. + ) + + self.quantize_func = quantize_func # accepts [2d weight tensor], [qparams], outputs a 2d quantized tensor of desired dtype + + self.dequantize_func = dequantize_func + # accepts [quantized] tensor and [qparams], outputs a 2d dequantized tensor of type float, + # assumes this output .to(w_orig_dtype) is ~eventual desired dequant behavior + + self.combine_qparams_list_func = combine_qparams_list_func + # accepts [`list` of qparams] from quantizing one group at a time, + # outputs a qparams object that could be passed into quant/dequantize_func + + self.skip_layer_func = skip_layer_func # accepts [weight tensor], outputs a bool on whether or not to apply gptq to this layer + + self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict + # note any final packing for storage should happen here + return self + + def run(self): + assert ( + self.get_qparams_func is not None + ), "need to configure quantization mode before running" + self.gptq_done = True + super().run(*self.inputs) + + def get_quantized_state_dict(self): + assert ( + self.gptq_done + ), "need to run GPTQRunner before you can get_quantized_state_dict" + quantized_state_dict = self.new_state_dict + # Don't want to store/load the kv_cache so remove it from the state_dict + del_list = [] + for param_fqn in quantized_state_dict: + if "kv_cache" in param_fqn: + del_list.append(param_fqn) + for param_fqn in del_list: + quantized_state_dict.pop(param_fqn) + return quantized_state_dict + + def call_function(self, target, args, kwargs, skip_quant=False): + def tensors_to_cuda(args): + new_args = [] + for x in args: + new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x) + return new_args + + # flatten args and kwargs together + flat_args, spec = tree_flatten((args, kwargs)) + # move all single tensors to cuda, will move MultiInputs to cuda one at a time + flat_args = tensors_to_cuda(flat_args) + + has_multi_input = MultiInput in [type(x) for x in flat_args] + if has_multi_input: + # Just some trickery to convert + # [MultiInput[a, a, a], MultiInput(b, b, b)] => [a, b], [a, b], [a, b] + multi_input_count = max( + [ + len(x.values) if isinstance(x, MultiInput) else 1 + for x in flat_args + ] + ) + transposed_args = list( + zip( + *[ + x.values + if isinstance(x, MultiInput) + else [x] * multi_input_count + for x in flat_args + ] + ) + ) + else: + transposed_args = [flat_args] + outputs = [] + + # check whether we apply GPTQ to this module + quantize_linear = ( + (target == aten.linear.default) # if its a linear + and id(args[1]) in self.id_to_name # and if we know the layer name + and not skip_quant # and if we weren't told to skip quantization + # and if the skip_layer_func doesn't say we should skip + and not ( + self.skip_layer_func is not None + and self.skip_layer_func(args[1]) + ) + ) # then we will quantize this linear layer/weight + + if quantize_linear: # instantiate variables for GPTQ + H = 0 + total_batches = 0 + + for inp in transposed_args: + inp = tensors_to_cuda(inp) + cur_args, cur_kwargs = tree_unflatten(inp, spec) + + if ( + quantize_linear + ): # calculate H instead of output (will run the linear eventually with updated weight) + x = cur_args[0].float() + shape = x.shape + n = 1 if len(shape) == 2 else shape[0] + H *= total_batches / (total_batches + n) + total_batches += n + x = ((2 / total_batches) ** (1 / 2)) * x.reshape( + -1, shape[-1] + ).t().float() + H += x.matmul(x.t()) + else: + # get output if its not a linear + out = super().call_function(target, cur_args, cur_kwargs) + + if isinstance(out, torch.Tensor): + outputs.append(out.cpu()) + else: + outputs.append(out) + + if quantize_linear: + mod_fqn = ".".join(self.id_to_name[id(args[1])].split(".")[:-1]) + W = args[1].to(H.device) + Q, DQ, qparams = self.faster_quant(H, W.detach()) + print(mod_fqn) + names_and_values_dict = self.make_names_and_values_dict_func( + Q, qparams + ) + + # delete old weight + if mod_fqn + ".weight" in self.new_state_dict: + self.new_state_dict.pop(mod_fqn + ".weight") + if len(args) > 2: + self.new_state_dict[mod_fqn + ".bias"] = args[2] + for name, value in names_and_values_dict.items(): + self.new_state_dict[mod_fqn + "." + name] = value + + # run linear with new weight to get corrected output + new_out = self.call_function( + target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True + ) + + if self.debug: + old_out = self.call_function( + target, + (args[0][:2], args[1], *args[2:]), + kwargs, + skip_quant=True, + ) + + def SQNR(x, y): + return 20 * torch.log10(torch.norm(x) / torch.norm(x - y)) + + DQ_after = self.dequantize_func(Q, qparams).to(W.dtype) + print( + "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after) + ) # matches + + print( + "SQNR for weight (can be low)", SQNR(W, DQ.cuda()) + ) # fine to not match + print( + "SQNR for output with GPTQ (hopefully 35+)", + torch.cat( + [ + SQNR(old.cpu(), new.cpu()).unsqueeze(0) + for (old, new) in zip( + old_out.values, new_out.values[:2] + ) + ] + ).mean(), + ) + + qparams2 = self.get_qparams_func(W) + Q2 = self.quantize_func(W, qparams2) + DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype) + old_q_out = self.call_function( + target, + (args[0][:2], DQ2, *args[2:]), + kwargs, + skip_quant=True, + ) + + print( + "SQNR for output without GPTQ (should be less than above)", + torch.cat( + [ + SQNR(old.cpu(), old_q.cpu()).unsqueeze(0) + for (old, old_q) in zip( + old_out.values, old_q_out.values + ) + ] + ).mean(), + ) + return new_out + + return MultiInput(outputs) if has_multi_input else outputs[0] + + def faster_quant(self, H, W): + percdamp = self.percdamp + blocksize = self.blocksize + groupsize = self.groupsize + orig_dtype = W.dtype + W = W.detach().float() + _rows, columns = W.shape[0], W.shape[1] + device = W.device + + if groupsize == -1: + cur_qparams = self.get_qparams_func(W) + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros_like(W) + DQ = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(columns, device=device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + all_qparams = [] + for i1 in range(0, columns, blocksize): + i2 = min(i1 + blocksize, columns) + count = i2 - i1 + W1 = W[:, i1:i2].clone() + DQ1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if ( + groupsize != -1 and (i1 + i) % groupsize == 0 + ): # start of new group + cur_qparams = self.get_qparams_func( + W[:, (i1 + i) : (i1 + i + groupsize)] + ) + all_qparams.append(cur_qparams) + + q = self.quantize_func(w.unsqueeze(1), cur_qparams).flatten() + dq = self.dequantize_func(q.unsqueeze(1), cur_qparams).flatten() + + DQ1[:, i] = dq + Losses1[:, i] = (w - dq) ** 2 / d**2 + + err1 = (w - dq) / d + W1[:, i:] -= ( + err1.to(Hinv1.dtype) + .unsqueeze(1) + .matmul(Hinv1[i, i:].unsqueeze(0)) + ) + Err1[:, i] = err1 + + DQ[:, i1:i2] = DQ1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + + if all_qparams == []: + all_qparams.append(cur_qparams) + + # convert a list of qparams objects into a single one. enerally by + # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor + all_qparams = self.combine_qparams_list_func(all_qparams) + Q = self.quantize_func(DQ, all_qparams) + return Q, DQ.to(orig_dtype), all_qparams diff --git a/06_gpu_and_ml/gpt-fast/LICENSE b/06_gpu_and_ml/gpt-fast/LICENSE new file mode 100644 index 000000000..56f4d62a4 --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/LICENSE @@ -0,0 +1,11 @@ +Copyright 2023 Meta + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/06_gpu_and_ml/gpt-fast/README.md b/06_gpu_and_ml/gpt-fast/README.md new file mode 100644 index 000000000..2387dbdd8 --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/README.md @@ -0,0 +1,77 @@ +# gpt-fast on Modal + +This is a demo of https://github.com/pytorch-labs/gpt-fast running on +[Modal](https://modal.com). It demonstrates how to use speculative sampling, +quantized models, and pytorch compilation to achieve upwards of 125 tokens/s +with batch sizes of 1 (i.e. no vLLM-style continuous batching), on 7B models +running on individual A100 80GB GPUs. It's a multi-file Modal app that +integrates into an existing codebase (files other than `modal.py` were mostly +taken as-is from `pytorch-labs/gpt-fast`), makes of container-lifecyle +primitives, streams responses, and is also able to invoke already-deployed +functions. + +TODO: +- [ ] Make use of GPU checkpointing to avoid long cold starts. +- [ ] Doc-ify modal.py, publish to website. +- [ ] Make use of draft models for speculative sampling. + - [ ] Run them on secondary GPUs? +- [ ] Make use of tensor parallelism. +- [ ] Fix (gpt-fast?) bug where subsequent generations end up generating + using the prompt used to compile the model itself, or earlier prompt. + Maybe some internal tensor getting recycled? + +To run one-off inference: +``` + ۩ modal run gpt-fast.modal::main --prompt "Implement fibonacci in python" + \ --no-compile-model + ... + Loading model weights ... + Using int8 weight-only quantization! + Loading model weights took 11.08 seconds + Starting inference for prompt = 'Implement fibonacci in python' + with memoization. + + The time complexity should be O(n) + The space complexity should be O(n) + """ + + def fibonacci(n, mem=dict()): + if n == 0: + return 0 + if n == 1: + return 1 + if n in mem: + return mem[n] + Time for inference 1: 13.24 sec total, 7.55 tokens/sec + Bandwidth achieved: 51.91 GB/s + ... +``` + +Compile the model for faster inference, at the cost of much longer cold-starts: +``` + ۩ modal run gpt-fast.modal::main --prompt "Implement fibonacci in python" \ + --compile-model + ... + Running warmup inference ... + Model compilation time: 298.49 seconds + Starting inference for prompt = 'Implement fibonacci in python' + ... + Time for inference 1: 0.81 sec total, 123.54 tokens/sec + Bandwidth achieved: 856.83 GB/s +``` + +Deploy the model and run inference against a container that's already compiled +the pytorch model: +``` + ۩ modal deploy gpt-fast.modal + + # Should happen instantaneously once deployed model is fully compiled, at + # upwards of 125 tokens/sec. + ۩ modal run gpt-fast.modal::main --lookup-existing \ + --prompt "Add two numbers in python" --num-samples 10 +``` + +Run a web-version of the app using: +``` + ۩ modal serve gpt-fast.modal::app +``` diff --git a/06_gpu_and_ml/gpt-fast/__init__.py b/06_gpu_and_ml/gpt-fast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/06_gpu_and_ml/gpt-fast/generate.py b/06_gpu_and_ml/gpt-fast/generate.py new file mode 100644 index 000000000..48f7aa3fa --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/generate.py @@ -0,0 +1,568 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from sentencepiece import SentencePieceProcessor + +from .model import Transformer +from .tp import maybe_init_dist + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to( + dtype=torch.int + ) + + +def logits_to_probs( + logits, temperature: float = 1.0, top_k: Optional[int] = None +): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, + x: torch.Tensor, + input_pos: torch.Tensor, + **sampling_kwargs, +) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + + +def decode_one_token( + model: Transformer, + x: torch.Tensor, + input_pos: torch.Tensor, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + + +def speculative_decode( + model: Transformer, + draft_model: Transformer, + cur_token: torch.Tensor, + input_pos: int, + speculate_k: int, + **sampling_kwargs, +) -> torch.Tensor: + # draft model inference sequentially + device = cur_token.device + orig_input_pos = torch.tensor( + [input_pos], dtype=torch.int64, device=cur_token.device + ) + draft_tokens, draft_probs = decode_n_tokens( + draft_model, + cur_token.view(1, -1), + orig_input_pos.clone(), + speculate_k, + **sampling_kwargs, + ) + + draft_tokens = torch.cat(draft_tokens) + # parallel inference on target model using draft tokens + target_logits = model_forward( + model, + torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), + torch.arange( + input_pos, input_pos + speculate_k + 1, device=cur_token.device + ), + ) + target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) + draft_probs = torch.stack(draft_probs) + # q: target prob, p: draft prob + # q >= p: always accept draft token + # q < p: q/p prob to accept draft token + p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p) + rejected_locations = ( + torch.rand_like(accept_draft_prob) > accept_draft_prob + ).nonzero() + + if rejected_locations.shape[0] == 0: # All draft tokens have been accepted + accept_length = speculate_k + 1 + last_token = multinomial_sample_one_no_sync(target_probs[-1]) + # fill last token into draft model + model_forward( + draft_model, + draft_tokens[-1].view(1, -1), + orig_input_pos + speculate_k, + ) + return torch.cat([draft_tokens, last_token]) + else: + accept_length = rejected_locations[0].item() + p = draft_probs[accept_length] + q = target_probs[accept_length] + new = q - p + new = torch.where(new > 0, new, 0.0) + new = new / new.sum() + next_token = multinomial_sample_one_no_sync(new) + return torch.cat([draft_tokens[:accept_length], next_token]) + + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + interactive: bool, + draft_model: Transformer, + speculate_k: Optional[int] = 8, + callback=lambda x: x, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + + is_speculative = draft_model is not None + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + T_new = T + max_new_tokens + if interactive: + max_seq_length = 350 + else: + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + max_seq_length = ( + max_seq_length + speculate_k + 1 if is_speculative else max_seq_length + ) + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if is_speculative and draft_model is not model: + draft_model.setup_caches( + max_batch_size=1, max_seq_length=max_seq_length + ) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill( + model, prompt.view(1, -1), input_pos, **sampling_kwargs + ) + if is_speculative: + prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + accept_counts = [0] * (speculate_k + 1) + + if is_speculative: + input_pos = ( + input_pos.item() + ) # for speculative decoding easier to keep on host + while input_pos < T_new - 1: + cur_token = next_token.view(()) + + next_tokens = speculative_decode( + model, + draft_model, + cur_token, + input_pos, + speculate_k, + **sampling_kwargs, + ) + + accept_counts[len(next_tokens) - 1] += 1 + num_added = min(T_new - input_pos - 1, len(next_tokens)) + seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[ + :num_added + ] + for i in next_tokens[:num_added,]: + callback(i) + input_pos = input_pos + num_added + next_token = next_tokens[-1] + else: + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(1, -1), + input_pos, + max_new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq[T + 1 :] = torch.cat(generated_tokens) + + generate_stats = {"accept_counts": accept_counts} + return seq, generate_stats + + +def encode_tokens(tokenizer, string, bos=True, device="cuda"): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def _load_model(checkpoint_path, device, precision, use_tp): + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + + if "int8" in str(checkpoint_path): + print("Using int8 weight-only quantization!") + from .quantize import WeightOnlyInt8QuantHandler + + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(checkpoint_path): + print("Using int4 quantization!") + path_comps = checkpoint_path.name.split(".") + assert path_comps[-2].startswith("g") + groupsize = int(path_comps[-2][1:]) + from .quantize import WeightOnlyInt4QuantHandler + + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + + if use_tp: + from .tp import apply_tp + + print("Applying tensor parallel to model ...") + apply_tp(model) + + model = model.to(device=device, dtype=precision) + return model.eval() + + +B_INST, E_INST = "[INST]", "[/INST]" + + +def generate_main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" + ), + compile: bool = True, + compile_prefill: bool = False, + profile: Optional[Path] = None, + draft_checkpoint_path: Optional[Path] = None, + speculate_k: int = 5, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + + global print + rank = maybe_init_dist() + use_tp = rank is not None + if use_tp: + torch.cuda.set_device(rank) + if rank != 0: + # only print on rank 0 + def print(*args, **kwargs): + return None + + device = "cuda" + precision = torch.bfloat16 + is_speculative = draft_checkpoint_path is not None + is_chat = "chat" in str(checkpoint_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision, use_tp) + + if is_speculative: + draft_model = _load_model( + draft_checkpoint_path, device, precision, use_tp + ) + else: + draft_model = None + + torch.cuda.synchronize() + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(model.parameters(), model.buffers()) + ] + ) + if compile: + if is_speculative and use_tp: + torch._inductor.config.triton.cudagraph_trees = ( + False # Bug with cudagraph trees in this case + ) + + if is_speculative: + global model_forward, logits_to_probs + model_forward = torch.compile( + model_forward, mode="reduce-overhead", fullgraph=True + ) + + global decode_one_token, prefill + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) + + # Uncomment to squeeze more perf out of prefill + if compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + aggregate_metrics = { + "tokens_per_sec": [], + "accept_counts": [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + torch.cuda.synchronize() + if i >= 0 and interactive: + prompt = input("What is your prompt? ") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode(".")[0] + done_generating = False + + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print("".join(buffer), end="", flush=True) + buffer.clear() + # print(, end='', flush=True) + + else: + + def callback(x): + return x + + t0 = time.perf_counter() + import contextlib + + if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y, metrics = generate( + model, + None, + encoded, + max_new_tokens, + draft_model=draft_model, + speculate_k=speculate_k, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + aggregate_metrics["accept_counts"].append(metrics["accept_counts"]) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + if use_tp: + prof.export_chrome_trace(f"{profile}_rank_{rank}.json") + else: + prof.export_chrome_trace(f"{profile}.json") + torch.cuda.synchronize() + t = time.perf_counter() - t0 + + if not interactive: + print(tokenizer.decode(y.tolist())) + else: + print() + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + ) + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + print("==========") + if is_speculative: + counts_aggregated = [ + sum(i) for i in zip(*aggregate_metrics["accept_counts"]) + ] + acceptance_probs = [ + i / sum(counts_aggregated) for i in counts_aggregated + ] + print(f"Acceptance probs: {acceptance_probs}") + print( + f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}" + ) + + print( + f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}" + ) + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Your CLI description.") + + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Whether to launch in interactive mode", + ) + parser.add_argument( + "--num_samples", type=int, default=5, help="Number of samples." + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=200, + help="Maximum number of new tokens.", + ) + parser.add_argument( + "--top_k", type=int, default=200, help="Top-k for sampling." + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Temperature for sampling.", + ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" + ), + help="Model checkpoint path.", + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." + ) + parser.add_argument( + "--compile_prefill", + action="store_true", + help="Whether to compile the prefill (improves prefill perf, but higher compile times)", + ) + parser.add_argument( + "--profile", type=Path, default=None, help="Profile path." + ) + parser.add_argument( + "--speculate_k", + type=int, + default=5, + help="Speculative execution depth.", + ) + parser.add_argument( + "--draft_checkpoint_path", + type=Path, + default=None, + help="Draft checkpoint path.", + ) + + args = parser.parse_args() + generate_main( + args.prompt, + args.interactive, + args.num_samples, + args.max_new_tokens, + args.top_k, + args.temperature, + args.checkpoint_path, + args.compile, + args.compile_prefill, + args.profile, + args.draft_checkpoint_path, + args.speculate_k, + ) diff --git a/06_gpu_and_ml/gpt-fast/modal.py b/06_gpu_and_ml/gpt-fast/modal.py new file mode 100644 index 000000000..060ec74b7 --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/modal.py @@ -0,0 +1,519 @@ +# --- +# lambda-test: false +# --- + +import itertools +import queue +import subprocess +import threading +import time +from pathlib import Path +from typing import Optional + +from modal import Function, Image, Mount, Secret, Stub, asgi_app, gpu, method + +model = "meta-llama/Llama-2-7b-chat-hf" + + +def prepare_int4_quantized(): + # TODO(irfansharif): Replace with run_command once we fix + # https://linear.app/modal-labs/issue/MOD-1955/respect-gpu-specs-for-run-command. + subprocess.run( + [ + "python", + "quantize.py", + "--checkpoint_path", + f"checkpoints/{model}/model.pth", + "--mode", + "int4", + "--groupsize", + "32", + ], + check=True, + cwd="/gpt-fast", + ) + + +image = ( + Image.from_registry( + "nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04", add_python="3.11" + ) + .pip_install( + "torch", + pre=True, + index_url="https://download.pytorch.org/whl/nightly/cu121", + ) + .pip_install( + # Use the barebones hf-transfer package for maximum download speeds. No + # progress bar, but expect 700MB/s. This combines with the + # HF_HUB_ENABLE_HF_TRANSFER env var below, see: + # https://huggingface.co/docs/huggingface_hub/guides/download#faster-downloads. + "hf-transfer~=0.1", + "huggingface-hub", + "sentencepiece", + ) + .apt_install("git") + .run_commands( + "git clone https://github.com/pytorch-labs/gpt-fast && cd /gpt-fast && git checkout 3bcaaaf0" + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) + .run_commands( + f"cd /gpt-fast && ./scripts/prepare.sh {model}", + secrets=[Secret.from_name("huggingface")], + gpu=gpu.A100(memory=80), + ) + .run_function( + prepare_int4_quantized, + gpu=gpu.A100(memory=80), + ) +) + +stub = Stub("gpt-fast", image=image) + +with stub.image.run_inside(): + import torch + from sentencepiece import SentencePieceProcessor + + from . import generate + from .generate import ( + B_INST, + E_INST, + _load_model, + encode_tokens, + ) + from .tp import maybe_init_dist + + +@stub.cls( + gpu=gpu.A100(memory=80), + timeout=10 * 60, # 10m + keep_warm=1, + container_idle_timeout=20 * 60, # 20m +) +class Model: + def __init__( + self, + compile_model: bool = True, + compile_prefill: bool = False, + use_base_model: bool = False, + use_speculative_sampling: bool = False, # NB: takes >10m to initialize, tripping up runners + ): + checkpoint = "model.pth" if use_base_model else "model_int8.pth" + checkpoint_path: Path = Path( + f"/gpt-fast/checkpoints/{model}/{checkpoint}" + ) + draft_checkpoint_path: Optional[Path] = None + if use_speculative_sampling: + if use_base_model: + draft_checkpoint_path = Path( + f"/gpt-fast/checkpoints/{model}/model_int8.pth" + ) + else: + draft_checkpoint_path = Path( + f"/gpt-fast/checkpoints/{model}/model_int4.g32.pth" + ) + + self.compile_model = compile_model + self.compile_prefill = compile_prefill + self.checkpoint_path = checkpoint_path + self.draft_checkpoint_path = draft_checkpoint_path + + def __enter__(self): + assert self.checkpoint_path.is_file(), self.checkpoint_path + if self.draft_checkpoint_path is not None: + assert ( + self.draft_checkpoint_path.is_file() + ), self.draft_checkpoint_path + + global print + rank = maybe_init_dist() + use_tp = rank is not None + if use_tp: + torch.cuda.set_device(rank) + if rank != 0: + # only print on rank 0 + def print(*args, **kwargs): + return None + + self.device = "cuda" + precision = torch.bfloat16 + is_speculative = self.draft_checkpoint_path is not None + + t0 = time.time() + print("Loading model weights ...") + model = _load_model( + self.checkpoint_path, self.device, precision, use_tp + ) + + if is_speculative: + draft_model = _load_model( + self.draft_checkpoint_path, self.device, precision, use_tp + ) + else: + draft_model = None + + torch.cuda.synchronize() + print(f"Loading model weights took {time.time() - t0:.02f} seconds") + + if self.compile_model: + if is_speculative and use_tp: + torch._inductor.config.triton.cudagraph_trees = ( + False # Bug with cudagraph trees in this case + ) + + if is_speculative: + self.model_forward = torch.compile( + generate.model_forward, + mode="reduce-overhead", + fullgraph=True, + ) + + self.decode_one_token = torch.compile( + generate.decode_one_token, + mode="reduce-overhead", + fullgraph=True, + ) + + if self.compile_prefill: + self.prefill = torch.compile( + generate.prefill, fullgraph=True, dynamic=True + ) + + self.model = model + self.draft_model = draft_model + + if self.compile_model: + print("Running warmup inference ...") + t0 = time.time() + self.binary_model = None + self.generate_inner( + "How to print 'hello world' in python?", + num_samples=1, + max_new_tokens=100, + speculate_k=5, + temperature=0.8, + top_k=200, + interactive=False, + q=queue.Queue(), + sentinel=object(), + ) + print(f"Warmup inference took {time.time() - t0:.02f} seconds") + + @method() + def generate( + self, + prompt: str, + num_samples: int = 1, + max_new_tokens: int = 100, + speculate_k: int = 5, + temperature: float = 0.8, + top_k: int = 200, + interactive: bool = True, + ): + q = queue.Queue() + sentinel = object() + + # Use a separate thread to generate responses in order to stream them + # back to the client. + generation_thread = threading.Thread( + target=self.generate_inner, + args=( + prompt, + num_samples, + max_new_tokens, + speculate_k, + temperature, + top_k, + interactive, + q, + sentinel, + ), + ) + generation_thread.start() + + # NB: There are bugs in either pytorch or the gpt-fast repo. Inference + # occasionally hangs and server-logs show things like: + # + # :198: _run_code: block: [5,0,0], thread: [62,0,0] Assertion `index out of bounds: 0 <= tmp84 < 120` failed. + # + # Use a timeout and poll frequently. We kill the entire container when + # an input times out, which is annoying given the large model + # compilation times. + print( + f"[{prompt=},{num_samples=}] Waiting for inference to complete (timeout=30s) ...", + ) + + start_time = time.time() + while True: + time.sleep(0.1) + + if not generation_thread.is_alive(): + return + + if time.time() - start_time > 30: # > 30s + print( + f"[{prompt=}] Timed out waiting for inference to complete" + ) + yield "" + return + + try: + data = q.get_nowait() + if data is sentinel: + break + yield data + except queue.Empty: + pass + + def generate_inner( + self, + prompt: str, + num_samples: int, # if 0, we compile the model + max_new_tokens: int, + speculate_k: int, + temperature: float, + top_k: int, + interactive: bool, + q: queue.Queue, + sentinel: object, + ): + tokenizer_path = self.checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + + is_chat = "chat" in str(self.checkpoint_path) + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + encoded = encode_tokens(tokenizer, prompt, bos=True, device=self.device) + prompt_length = encoded.size(0) + + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain( + self.model.parameters(), self.model.buffers() + ) + ] + ) + aggregate_metrics = { + "tokens_per_sec": [], + "accept_counts": [], + } + + is_speculative = self.draft_checkpoint_path is not None + if self.compile_model: + generate.decode_one_token = self.decode_one_token + + if is_speculative: + generate.model_forward = self.model_forward + + if self.compile_prefill: + generate.prefill = self.prefill + + start = -1 if self.compile_model else 0 + for i in range(start, num_samples): + torch.cuda.synchronize() + + if i == 0: + print(f"Starting inference for prompt = '{prompt}'") + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode(".")[0] + done_generating = False + + def callback(x): + nonlocal done_generating + if done_generating: + return + + xlist = [ + item + for sublist in [x.tolist()] + for item in ( + sublist if isinstance(sublist, list) else [sublist] + ) + ] + buffer.append(tokenizer.decode([period_id] + xlist)[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + + if len(buffer) == 4 or done_generating: + q.put("".join(buffer)) + buffer.clear() + + else: + + def callback(x): + return x + + t0 = time.perf_counter() + + try: + y, metrics = generate.generate( + self.model, + encoded, + max_new_tokens, + interactive=interactive, + draft_model=self.draft_model, + speculate_k=speculate_k, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + except Exception as e: + print("Exception encountered during inference", e) + break + + aggregate_metrics["accept_counts"].append(metrics["accept_counts"]) + + if i == -1: + print( + f"Model compilation time: {time.perf_counter() - t0:.2f} seconds" + ) + continue + + torch.cuda.synchronize() + t = time.perf_counter() - t0 + + if not interactive: + generated = tokenizer.decode(y.tolist()) + q.put(generated) + else: + q.put("\n") + + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + ) + print( + f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" + ) + + if is_speculative: + counts_aggregated = [ + sum(i) for i in zip(*aggregate_metrics["accept_counts"]) + ] + acceptance_probs = [ + i / sum(counts_aggregated) for i in counts_aggregated + ] + print(f"Acceptance probs: {acceptance_probs}") + print( + f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}" + ) + + if num_samples > 0: + print( + f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}" + ) + print( + f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB" + ) + + q.put(sentinel) + + +@stub.local_entrypoint() +def main( + # Lookup an already deployed model. If False, we'll deploy a new one + # constructed from the following args. + lookup_existing: bool = False, + # Model construction args. + compile_model: bool = True, # Compile the model through pytorch (makes for slower cold starts but much faster inference). + compile_prefill: bool = False, # Compile the prefill function (only used if compile_model is True, and ). + use_base_model: bool = False, # Use the base model (instead of the int8 quantized one). + use_speculative_sampling: bool = False, # Use speculative sampling. + # Inference args. + prompt: str = "", # Input prompt. + num_samples: int = 1, # How many responses to generate for each prompt. + max_new_tokens: int = 100, # Size of each generated response. + speculate_k: int = 5, # Speculative execution depth. + temperature: float = 0.8, # Temperature for sampling. + top_k: int = 200, # Top-k for sampling. + interactive: bool = True, # Whether to stream response. +): + if lookup_existing: + fn = Function.lookup("gpt-fast", "Model.generate") + else: + fn = Model( + compile_model=compile_model, + compile_prefill=compile_prefill, + use_base_model=use_base_model, + use_speculative_sampling=use_speculative_sampling, + ).generate + + prompts = [prompt] + if not prompt: + prompts = [ + "Implement fibonacci in python.", + "Write a Rust function that performs binary exponentiation.", + "How do I allocate memory in C?", + ] + + for prompt in prompts: + for generated in fn.remote_gen( + prompt=prompt, + num_samples=num_samples, + max_new_tokens=max_new_tokens, + speculate_k=speculate_k, + temperature=temperature, + top_k=top_k, + interactive=interactive, + ): + print(generated, end="") + + +app = Stub("gpt-fast-app", image=Image.debian_slim()) + + +@app.function( + mounts=[ + Mount.from_local_dir( + Path(__file__).parent.parent / "llm-frontend", + remote_path="/assets", + ), + ], + allow_concurrent_inputs=10, + timeout=10 * 60, +) +@asgi_app(label="gpt-fast-app") +def modal_app(): + import json + from urllib.parse import unquote + + import fastapi + import fastapi.staticfiles + from fastapi.responses import StreamingResponse + + web_app = fastapi.FastAPI() + + @web_app.get("/model") + async def model(): + return {"name": "Llama-2-7b-chat-hf"} + + @web_app.get("/stats") + async def stats(): + stats = await Function.lookup( + "gpt-fast", "Model.generate" + ).get_current_stats.aio() + return { + "backlog": stats.backlog, + "num_total_runners": stats.num_total_runners, + } + + @web_app.get("/completion/{question}") + async def completion(question: str): + async def generate(): + fn = Function.lookup("gpt-fast", "Model.generate") + for generated in fn.remote_gen(unquote(question)): + yield f"data: {json.dumps(dict(text=generated), ensure_ascii=False)}\n\n" + + return StreamingResponse(generate(), media_type="text/event-stream") + + web_app.mount( + "/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True) + ) + return web_app diff --git a/06_gpu_and_ml/gpt-fast/model.py b/06_gpu_and_ml/gpt-fast/model.py new file mode 100644 index 000000000..5bb9ee090 --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/model.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [ + config + for config in transformer_configs + if config in str(name).upper() or config in str(name) + ] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "CodeLlama-7b-Python-hf": dict( + block_size=16384, + vocab_size=32000, + n_layer=32, + dim=4096, + rope_base=1000000, + ), + "7B": dict(n_layer=32, n_head=32, dim=4096), + "13B": dict(n_layer=40, n_head=40, dim=5120), + "30B": dict(n_layer=60, n_head=52, dim=6656), + "34B": dict( + n_layer=48, + n_head=64, + dim=8192, + vocab_size=32000, + n_local_heads=8, + intermediate_size=22016, + rope_base=1000000, + ), # CodeLlama-34B-Python-hf + "70B": dict( + n_layer=80, + n_head=64, + dim=8192, + n_local_heads=8, + intermediate_size=28672, + ), +} + + +class KVCache(nn.Module): + def __init__( + self, + max_batch_size, + max_seq_length, + n_heads, + head_dim, + dtype=torch.bfloat16, + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_length, + self.config.n_local_heads, + head_dim, + ) + + self.freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + ) + self.causal_mask = torch.tril( + torch.ones( + self.max_seq_length, self.max_seq_length, dtype=torch.bool + ) + ) + + def forward( + self, idx: Tensor, input_pos: Optional[Tensor] = None + ) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + assert input_pos is not None, "Input positions must be provided" + + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + ) -> Tensor: + h = x + self.attention( + self.attention_norm(x), freqs_cis, mask, input_pos + ) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = ( + config.n_head + 2 * config.n_local_heads + ) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = self.apply_rotary_emb(q, freqs_cis) + k = self.apply_rotary_emb(k, freqs_cis) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if self.kv_cache is not None and input_pos is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + def apply_rotary_emb(self, x: Tensor, freqs_cis: Tensor) -> Tensor: + # xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + xshaped = x.float().reshape(x.size(0), x.size(1), x.size(2), -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] + - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + # return x_out2.type_as(x) + return x_out2.to(x.dtype) + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt( + torch.mean(x * x, dim=-1, keepdim=True) + self.eps + ) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) diff --git a/06_gpu_and_ml/gpt-fast/quantize.py b/06_gpu_and_ml/gpt-fast/quantize.py new file mode 100644 index 000000000..4e67987e6 --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/quantize.py @@ -0,0 +1,824 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import time +from pathlib import Path +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sentencepiece import SentencePieceProcessor + +if TYPE_CHECKING: + from torchsnapshot import StateDict + + from .GPTQ import MultiInput + +try: + from .GPTQ import GenericGPTQRunner, InputRecorder, lm_eval +except ImportError: + pass + +from .model import Transformer + +##### Quantization Primitives ###### + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros( + min_val_neg.size(), dtype=torch.int64, device=device + ) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + +def group_quantize_tensor_from_qparams( + w, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) + ) + + return w_int32 + + +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams( + w, scales, zeros, n_bit, groupsize + ) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int32_grouped.sub(2 ** (n_bit - 1)) + .mul(scales) + .add(zeros) + .reshape_as(w_int32) + ) + return w_dq + + +def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_scales_and_zeros(scales_and_zeros) + return group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit, groupsize + ) + + +class QuantHandler: + def __init__(self, mod): + self.mod = mod + + def create_quantized_state_dict(self) -> "StateDict": + pass + + def convert_for_runtime(self) -> "nn.Module": + pass + + +class GPTQQuantHandler(QuantHandler): + """ + This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. + Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement + __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + + The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and + create_quantized_state_dict. Here is a description of each function. + + get_qparams_func: + A function that calculates the quantization qparams for an input tensor. + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + qparams: it can have any format but will need to be handled by the other defined functions below. + + quantize_func: + A function that applies quantization to an input tensor. It should be noted + that this function needs to be able to handle quantizing the entire weight tensor, a single group, + or a single column. + Args: + weight: A 2d weight tensor with non-integer dtype. + qparams: the output from get_qparams_func + Returns: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + + + dequantize_func: + A function that dequantizes an input quantized weight tensor. It should be noted + that this function needs to be able to handle dequantizing the entire weight tensor, a single group, + or a single column. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + weight: A 2d weight tensor with non-integer dtype. + + combine_qparams_list_func: + A function that combines several qparams into one qparam. + Args: + qparams_list: a list of qparams objects, each obtained by calling get_qparams_func + on a single group from a weight tensor + Returns: + qparams: an object of the same format as the qparams above. + + skip_layer_func: + A function that determines which linear layers should be skipped during GPTQ + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + skip: boolean indicating whether layer should be skipped + + make_names_and_values_dict_func: + A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they + should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the + corresponding quantized weights and qparams. + """ + + def __init__(self): + assert self.mod is not None + assert self.get_qparams_func is not None + assert self.quantize_func is not None + assert self.dequantize_func is not None + assert self.combine_qparams_list_func is not None + assert self.make_names_and_values_dict_func is not None + + @staticmethod + def get_inputs( + model, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "MultiInput": + input_recorder = InputRecorder( + model, + tokenizer, + calibration_seq_length, + pad_calibration_inputs, + ) + task_dict = lm_eval.tasks.get_task_dict(calibration_tasks) + print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) + lm_eval.evaluator.evaluate( + input_recorder, + task_dict, + limit=calibration_limit, + ) + inputs = input_recorder.get_recorded_inputs() + print(f"Obtained {len(inputs[0].values)} calibration samples") + return inputs + + @torch.no_grad() + def create_quantized_state_dict( + self, + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "StateDict": + inputs = GPTQQuantHandler.get_inputs( + self.mod, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) + print("Tracing model for GPTQ") + GPTQ_runner = GenericGPTQRunner( + self.mod, + inputs, + blocksize, + percdamp, + groupsize, + ).configure_quantization_mode( + self.get_qparams_func, + self.quantize_func, + self.dequantize_func, + self.combine_qparams_list_func, + self.make_names_and_values_dict_func, + self.skip_layer_func, + ) + + print("Applying GPTQ to weights") + GPTQ_runner.run() + return GPTQ_runner.get_quantized_state_dict() + + def convert_for_runtime(self) -> "nn.Module": + pass + + +##### Weight-only int8 per-channel quantized code ###### + + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr( + module, + name, + WeightOnlyInt8Linear(child.in_features, child.out_features), + ) + else: + replace_linear_weight_only_int8_per_channel(child) + + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel( + mod.weight.float(), -128, 127, torch.int8 + ) + cur_state_dict[f"{fqn}.weight"] = int8_weight + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=torch.int8) + ) + self.register_buffer( + "scales", torch.ones(out_features, dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + + +##### weight only int4 per channel groupwise quantized code ###### + + +def prepare_int4_weight_and_scales_and_zeros( + weight_bf16, groupsize, inner_k_tiles +): + weight_int32, scales_and_zeros = group_quantize_tensor( + weight_bf16, n_bit=4, groupsize=groupsize + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weight_int32, inner_k_tiles + ) + return weight_int4pack, scales_and_zeros + + +def linear_forward_int4( + x, weight_int4pack, scales_and_zeros, out_features, groupsize +): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm( + x, weight_int4pack, groupsize, scales_and_zeros + ) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k( + child.in_features, groupsize, inner_k_tiles + ): + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=False, + ), + ) + elif padding: + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=True, + ), + ) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding: + import torch.nn.functional as F + from model import find_multiple + + print( + f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" + ) + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) + else: + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) + continue + ( + weight_int4pack, + scales_and_zeros, + ) = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16).to("cuda"), + self.groupsize, + self.inner_k_tiles, + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( + "cpu" + ) + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_int4( + self.mod, self.groupsize, self.inner_k_tiles, self.padding + ) + return self.mod + + +class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + from model import find_multiple + + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) + self.quantize_func = ( + lambda w, qparams: group_quantize_tensor_from_qparams( + w, qparams[0], qparams[1], 4, groupsize + ) + ) + self.dequantize_func = ( + lambda q, qparams: group_dequantize_tensor_from_qparams( + q, qparams[0], qparams[1], 4, groupsize + ).float() + ) + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] + # skip unless padding=True or its correctly sized + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k( + linear_weight.shape[-1], groupsize, inner_k_tiles + ) + or padding + ) + + # we need to do the padding here, both for q and the qparams if necessary + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1024) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) + scales_and_zeros = pack_scales_and_zeros(*qparams) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad( + scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1 + ) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + + self.make_names_and_values_dict_func = make_names_and_values_dict_func + super().__init__() + + def convert_for_runtime(self): + replace_linear_int4( + self.mod, self.groupsize, self.inner_k_tiles, self.padding + ) + return self.mod + + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, + padding: bool = True, + ) -> None: + super().__init__() + self.padding = padding + if padding: + from model import find_multiple + + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales_and_zeros", + torch.empty( + (in_features // groupsize, out_features, 2), + dtype=torch.bfloat16, + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + + input = F.pad( + input, pad=(0, self.in_features - self.origin_in_features) + ) + return linear_forward_int4( + input, + self.weight, + self.scales_and_zeros, + self.out_features, + self.groupsize, + ) + + +def quantize( + checkpoint_path: Path = Path( + "checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth" + ), + mode: str = "int8", + # following arguments only available when setting int4 quantization. + groupsize: int = 128, + # following arguments only used for GPTQ + calibration_tasks: list = ["hellaswag"], + calibration_limit: int = 1000, + calibration_seq_length: int = 100, + pad_calibration_inputs: bool = False, + percdamp: float = 0.01, + blocksize: int = 128, + label: str = "", +) -> None: + assert checkpoint_path.is_file(), checkpoint_path + + device = "cpu" + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + + if mode == "int8": + print( + "Quantizing model weights for int8 weight-only symmetric per-channel quantization" + ) + quant_handler = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace(".pth", f"{label}int8.pth") + + elif mode == "int4": + print( + "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization" + ) + quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace( + ".pth", f"{label}int4.g{groupsize}.pth" + ) + + elif mode == "int4-gptq": + print( + "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ..." + ) + quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) + + quantized_state_dict = quant_handler.create_quantized_state_dict( + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace( + ".pth", f"{label}int4-gptq.g{groupsize}.pth" + ) + else: + raise ValueError( + f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]" + ) + + quantize_path = dir_name / new_base_name + print(f"Writing quantized weights to {quantize_path}") + quantize_path.unlink( + missing_ok=True + ) # remove existing file if one already there + torch.save(quantized_state_dict, quantize_path) + print(f"Quantization complete took {time.time() - t0:.02f} seconds") + return + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Quantize a model.") + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + help="Path to the model checkpoint to be quantized.", + ) + parser.add_argument( + "--mode", + "-q", + type=str, + default="int8", + choices=["int8", "int4", "int4-gptq"], + help="type of quantization to perform", + ) + parser.add_argument( + "--groupsize", + type=int, + default=32, + help="Group size for int4 quantization.", + ) + parser.add_argument( + "--calibration_tasks", + type=str, + nargs="+", + default=["hellaswag"], + help="tasks to do gptq calibration on, if doing gptq", + ) + parser.add_argument( + "--calibration_limit", + type=int, + default=1000, + help="number of samples to use for gptq calibration", + ) + parser.add_argument( + "--calibration_seq_length", + type=int, + default=100, + help="length of sequences to use for gptq calibration", + ) + parser.add_argument( + "--pad_calibration_inputs", + type=bool, + default=False, + help="pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower", + ) + parser.add_argument( + "--percdamp", type=float, default=0.01, help="gptq percentage dampening" + ) + parser.add_argument( + "--blocksize", type=int, default=128, help="blocksize for gptq" + ) + parser.add_argument( + "--label", type=str, default="_", help="label to add to output filename" + ) + + args = parser.parse_args() + quantize( + args.checkpoint_path, + args.mode, + args.groupsize, + args.calibration_tasks, + args.calibration_limit, + args.calibration_seq_length, + args.pad_calibration_inputs, + args.percdamp, + args.blocksize, + args.label, + ) diff --git a/06_gpu_and_ml/gpt-fast/tp.py b/06_gpu_and_ml/gpt-fast/tp.py new file mode 100644 index 000000000..b7a187f2d --- /dev/null +++ b/06_gpu_and_ml/gpt-fast/tp.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import os +from typing import List, Optional + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed import _functional_collectives as funcol + +from .model import Attention, FeedForward, Transformer +from .quantize import WeightOnlyInt4Linear + + +def _get_rank() -> int: + return int(os.environ.get("LOCAL_RANK", "0")) + + +def is_local(): + return _get_rank() == 0 + + +def local_break(): + if is_local(): + breakpoint() + dist.barrier() + + +def _get_world_size() -> int: + return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) + + +def maybe_init_dist() -> Optional[int]: + try: + # provided by torchrun + rank = _get_rank() + world_size = _get_world_size() + + if world_size < 2: + # too few gpus to parallelize, tp is no-op + return None + except KeyError: + # not run via torchrun, no-op + return None + + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return rank + + +def _apply_tp_linear( + linear: nn.Linear, style: str, weight_splits: List[int] = [] +) -> None: + rank = _get_rank() + world_size = _get_world_size() + + # Linear's weight matrix is transposed, and is of shape + # (linear.out_features, linear.in_features) + dim_lookup = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")} + assert style in dim_lookup + shard_dim, size_attr = dim_lookup[style] + + # ensure we can shard evenly + assert getattr(linear, size_attr) % world_size == 0 + + def shard(x, dim): + assert x.size(dim=dim) % world_size == 0 + return torch.tensor_split(x, world_size, dim=dim)[rank] + + def shard_qkv(qkv, dim, weight_splits): + q, k, v = qkv.split(weight_splits, dim=dim) + q = shard(q, dim) + k = shard(k, dim) + v = shard(v, dim) + return torch.cat((q, k, v), dim=dim) + + # shard + if weight_splits: + # attention + assert len(weight_splits) == 3 + + if isinstance(linear, WeightOnlyInt4Linear): + sharded_weight = shard_qkv( + linear.weight, shard_dim, [i // 8 for i in weight_splits] + ) + linear.scales_and_zeros = shard_qkv( + linear.scales_and_zeros, 1 - shard_dim, weight_splits + ) + else: + sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits) + if hasattr(linear, "scales") and style == "colwise": + linear.scales = shard_qkv(linear.scales, 0, weight_splits) + else: + sharded_weight = shard(linear.weight, shard_dim) + if isinstance(linear, WeightOnlyInt4Linear): + linear.scales_and_zeros = shard( + linear.scales_and_zeros, 1 - shard_dim + ) + if style == "rowwise": + assert ( + linear.scales_and_zeros.shape[0] * 32 + == sharded_weight.shape[1] + * sharded_weight.shape[2] + * sharded_weight.shape[3] + ) + assert ( + linear.scales_and_zeros.shape[1] + == sharded_weight.shape[0] * 8 + ) + if hasattr(linear, "scales") and style == "colwise": + linear.scales = shard(linear.scales, 0) + + # local_break() + linear.weight = nn.Parameter(sharded_weight, requires_grad=False) + setattr(linear, size_attr, getattr(linear, size_attr) // world_size) + + # shape info should still be synced + # assert linear.weight.shape == (linear.out_features, linear.in_features) + + +def _apply_tp_ffn(mlp: FeedForward) -> None: + assert hasattr(mlp, "w1") + assert hasattr(mlp, "w3") + assert hasattr(mlp, "w2") + + _apply_tp_linear(mlp.w1, "colwise") + _apply_tp_linear(mlp.w3, "colwise") + _apply_tp_linear(mlp.w2, "rowwise") + + world_size = _get_world_size() + mlp.register_forward_hook( + lambda _module, _input, output: funcol.all_reduce( + output, "sum", list(range(world_size)) + ) + ) + + +def _apply_tp_attn(attn: Attention) -> None: + assert hasattr(attn, "wqkv") + assert hasattr(attn, "wo") + + kv_size = attn.n_local_heads * attn.head_dim + _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size]) + _apply_tp_linear(attn.wo, "rowwise") + + # overwrite + world_size = _get_world_size() + attn.n_head = attn.n_head // world_size + attn.dim = attn.dim // world_size + attn.head_dim = attn.dim // attn.n_head + attn.n_local_heads = attn.n_local_heads // world_size + + attn.register_forward_hook( + lambda _module, _input, output: funcol.all_reduce( + output[0], "sum", list(range(world_size)) + ) + ) + + +def _apply_tp_Transformer(Transformer: Transformer) -> None: + # overwrite config before Transformer.setup_cache is called + world_size = _get_world_size() + Transformer.config.n_head = Transformer.config.n_head // world_size + Transformer.config.dim = Transformer.config.dim // world_size + Transformer.config.n_local_heads = ( + Transformer.config.n_local_heads // world_size + ) + + +def apply_tp(model: Transformer) -> None: + _apply_tp_Transformer(model) + for block in model.layers: + # Apply to MLP + _apply_tp_ffn(block.feed_forward) + _apply_tp_attn(block.attention) diff --git a/06_gpu_and_ml/llm-frontend/index.html b/06_gpu_and_ml/llm-frontend/index.html index fffef7a91..20babdbbd 100644 --- a/06_gpu_and_ml/llm-frontend/index.html +++ b/06_gpu_and_ml/llm-frontend/index.html @@ -36,10 +36,9 @@ -
- LLaMA 2 70B +
+
-
console.log(error)); }, + getModelMeta() { + fetch("/model") + .then((response) => response.json()) + .then((data) => { + this.modelMeta = { ...data, loaded: true }; + }) + .catch((error) => console.log(error)); + }, }; } - \ No newline at end of file + diff --git a/06_gpu_and_ml/text_generation_inference.py b/06_gpu_and_ml/text_generation_inference.py index 8cc4b6eb9..2b67dc114 100644 --- a/06_gpu_and_ml/text_generation_inference.py +++ b/06_gpu_and_ml/text_generation_inference.py @@ -212,6 +212,10 @@ def app(): web_app = fastapi.FastAPI() + @web_app.get("/model") + async def model(): + return {"name": "Llama-2-70b-chat"} + @web_app.get("/stats") async def stats(): stats = await Model().generate_stream.get_current_stats.aio() diff --git a/10_integrations/streamlit/app.py b/10_integrations/streamlit/app.py index 91fb88281..4132270db 100644 --- a/10_integrations/streamlit/app.py +++ b/10_integrations/streamlit/app.py @@ -3,7 +3,6 @@ # --- import argparse from datetime import datetime, timedelta - from zoneinfo import ZoneInfo # ## Demo Streamlit application.