From 617fd2a849fe3d6f94b5ce0c153c7d76f2cf4438 Mon Sep 17 00:00:00 2001 From: turboderp Date: Sun, 9 Jul 2023 01:19:35 +0200 Subject: [PATCH] Fix potential Safetensors memory leak and limit system RAM usage while loading model --- model.py | 88 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/model.py b/model.py index efd7a2b7..c2ea5805 100644 --- a/model.py +++ b/model.py @@ -12,6 +12,7 @@ import cuda_ext import json import math +import gc from enum import Enum class ParsedEnum(Enum): @@ -50,7 +51,6 @@ def __init__(self, model_config_path): self.intermediate_size = read_config["intermediate_size"] self.num_attention_heads = read_config["num_attention_heads"] self.num_hidden_layers = read_config["num_hidden_layers"] - self.num_attention_heads = read_config["num_attention_heads"] self.rms_norm_eps = read_config["rms_norm_eps"] self.vocab_size = read_config["vocab_size"] @@ -75,6 +75,7 @@ def __init__(self, model_config_path): self.alpha_value = 1.0 # Alpha value for NTK RoPE scaling. Similar to compress_pos_emb, higher values increaste ctx but add Perplexity. self.gpu_peer_fix = False # Apparently Torch can have problems transferring tensors directly one GPU to another sometimes. Enable this to expliticly move tensors via system RAM instead, where needed self.auto_map = None # List of floats with memory allocation in GB, per CUDA device, overrides device_map + # Tuning self.matmul_recons_thd = 8 @@ -409,7 +410,7 @@ def forward(self, hidden_states, cache, buffer, lora): attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) attn_weights /= math.sqrt(self.config.head_dim) if buffer.attn_mask is not None: attn_weights = attn_weights + buffer.attn_mask - attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16).to(query_states.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2) @@ -577,7 +578,12 @@ def get_layers_devs(self): return sorted(list(set(self.layers))) - def map(self, key, loading = False): + def get_all_devs(self): + + return sorted(list(set(self.layers + [self.lm_head, self.norm, self.embed_tokens]))) + + + def map(self, key): if key.startswith("lm_head."): return self.lm_head if key.startswith("model.embed_tokens."): return self.embed_tokens @@ -629,6 +635,14 @@ def _move_tensor(tensor, new_device, name, config): tensor = tensor.to("cpu") return tensor.to(new_device) +def _layer_dtype_size(key): + if key.endswith(".weight"): return 2 + if key.endswith(".qweight"): return 4 + if key.endswith(".qzeros"): return 4 + if key.endswith(".scales"): return 2 + if key.endswith(".g_idx"): return 0 + raise ValueError("Unrecognized layer: " + key) + class ExLlama: @@ -643,7 +657,7 @@ def __init__(self, config): # Load model weights tensors = {} - with safe_open(self.config.model_path, framework="pt", device="cpu") as f: + with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f: # Begin auto mapping if enabled @@ -662,16 +676,22 @@ def __init__(self, config): if _skip_key(key): continue if key.startswith("model.layers.0."): - tensor = f.get_tensor(key) - decoder_size += tensor.numel() * tensor.element_size() + tensor_slice = f.get_slice(key) + shape = tensor_slice.get_shape() + decoder_size += math.prod(shape) * _layer_dtype_size(key) + del tensor_slice if key.startswith("model.norm."): - tensor = f.get_tensor(key) - norm_size += tensor.numel() * tensor.element_size() + tensor_slice = f.get_slice(key) + shape = tensor_slice.get_shape() + norm_size += math.prod(shape) * _layer_dtype_size(key) + del tensor_slice if key.startswith("lm_head."): - tensor = f.get_tensor(key) - head_size += tensor.numel() * tensor.element_size() + tensor_slice = f.get_slice(key) + shape = tensor_slice.get_shape() + head_size += math.prod(shape) * _layer_dtype_size(key) + del tensor_slice # Assign layers automatically @@ -701,29 +721,47 @@ def __init__(self, config): device_usage += this_layer_size layer_index_device += 1 - # Load tensors, move to device(s) - - max_dq_buffer_size = 0 + # Read tensor list from file + load_keys = [] + with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f: for key in f.keys(): + load_keys.append(key) + + # Load up to 1 GB of tensors at a time, closing and reopening the file in between each chunk + + max_dq_buffer_size = 0 + f = None + st_mem = 0 + MAX_ST_MEM = 1024**3 + + for key in load_keys: + + if _skip_key(key): continue + device = self.config.device_map.map(key) - if _skip_key(key): continue + if f is None or st_mem > MAX_ST_MEM: + if f is not None: del f + f = safe_open(self.config.model_path, framework = "pt", device = "cpu") + st_mem = 0 - device = self.config.device_map.map(key, loading = True) - tensor = f.get_tensor(key) + tensor = f.get_tensor(key) + size = tensor.numel() * tensor.element_size() + st_mem += size - if key.endswith(".scales"): tensor = tensor.half() - if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half() - if key == "model.norm.weight": tensor = tensor.half() - if key.endswith(".embed_tokens.weight"): tensor = tensor.half() - if key.endswith(".input_layernorm.weight"): tensor = tensor.half() - if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half() + if key.endswith(".scales"): tensor = tensor.half() + if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half() + if key == "model.norm.weight": tensor = tensor.half() + if key.endswith(".embed_tokens.weight"): tensor = tensor.half() + if key.endswith(".input_layernorm.weight"): tensor = tensor.half() + if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half() - tensor = tensor.to(device, non_blocking = True) + tensor = tensor.to(device, non_blocking = True) + if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8) - if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8) + tensors[key] = tensor - tensors[key] = tensor + del f # Head