
Commit

Merge branch 'turboderp:master' into master
CoffeeVampir3 authored Jul 9, 2023
2 parents 2caf806 + 617fd2a commit a140d04
Showing 1 changed file with 63 additions and 25 deletions.
88 changes: 63 additions & 25 deletions model.py
@@ -12,6 +12,7 @@
import cuda_ext
import json
import math
import gc
from enum import Enum

class ParsedEnum(Enum):
@@ -50,7 +51,6 @@ def __init__(self, model_config_path):
self.intermediate_size = read_config["intermediate_size"]
self.num_attention_heads = read_config["num_attention_heads"]
self.num_hidden_layers = read_config["num_hidden_layers"]
self.num_attention_heads = read_config["num_attention_heads"]
self.rms_norm_eps = read_config["rms_norm_eps"]
self.vocab_size = read_config["vocab_size"]

@@ -75,6 +75,7 @@ def __init__(self, model_config_path):
self.alpha_value = 1.0 # Alpha value for NTK RoPE scaling. Similar to compress_pos_emb, higher values increase ctx but add perplexity.
self.gpu_peer_fix = False # Apparently Torch can sometimes have problems transferring tensors directly from one GPU to another. Enable this to explicitly move tensors via system RAM instead, where needed
self.auto_map = None # List of floats with memory allocation in GB, per CUDA device, overrides device_map

# Tuning

self.matmul_recons_thd = 8
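As a usage sketch (not part of the diff): these are plain attributes set on the config object before the model is constructed. The class name ExLlamaConfig is assumed from the surrounding repository, and the paths and numbers below are placeholders.

from model import ExLlamaConfig, ExLlama  # assumes this model.py is importable; ExLlamaConfig name assumed

config = ExLlamaConfig("/path/to/config.json")       # placeholder path
config.model_path = "/path/to/model.safetensors"     # placeholder path
config.alpha_value = 2.0        # NTK RoPE scaling; values > 1.0 extend context at some perplexity cost
config.gpu_peer_fix = True      # route cross-GPU tensor moves through system RAM
config.auto_map = [18.0, 24.0]  # illustrative per-CUDA-device memory budget in GB; overrides device_map

model = ExLlama(config)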
@@ -409,7 +410,7 @@ def forward(self, hidden_states, cache, buffer, lora):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
attn_weights /= math.sqrt(self.config.head_dim)
if buffer.attn_mask is not None: attn_weights = attn_weights + buffer.attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16).to(query_states.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
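For context on the one-line change above: with dtype = torch.float16, softmax already returns a half-precision tensor, so the trailing .to(query_states.dtype) is redundant when the activations are fp16, as they are here. A minimal standalone check (not from the diff; assumes a CUDA device, as the surrounding code does):

import torch
import torch.nn as nn

# Illustrative attention-score tensor; shape and values are arbitrary, device assumed to be CUDA.
attn_weights = torch.randn(1, 8, 4, 4, dtype = torch.float16, device = "cuda")
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)

print(attn_weights.dtype)  # torch.float16, no extra cast needed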

@@ -577,7 +578,12 @@ def get_layers_devs(self):
return sorted(list(set(self.layers)))


def map(self, key, loading = False):
def get_all_devs(self):

return sorted(list(set(self.layers + [self.lm_head, self.norm, self.embed_tokens])))


def map(self, key):

if key.startswith("lm_head."): return self.lm_head
if key.startswith("model.embed_tokens."): return self.embed_tokens
@@ -629,6 +635,14 @@ def _move_tensor(tensor, new_device, name, config):
tensor = tensor.to("cpu")
return tensor.to(new_device)

def _layer_dtype_size(key):
if key.endswith(".weight"): return 2
if key.endswith(".qweight"): return 4
if key.endswith(".qzeros"): return 4
if key.endswith(".scales"): return 2
if key.endswith(".g_idx"): return 0
raise ValueError("Unrecognized layer: " + key)


class ExLlama:
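The _layer_dtype_size helper above maps each tensor kind to its per-element storage cost so sizes can be computed from shapes alone: fp16 weights and scales take 2 bytes per element, packed GPTQ qweight/qzeros are stored as int32 (4 bytes), and g_idx is counted as zero. A quick worked example (the shapes are illustrative):

import math

qweight_shape = [512, 4096]                      # 4-bit GPTQ packs 8 values per int32 element
scales_shape = [32, 4096]

qweight_bytes = math.prod(qweight_shape) * 4     # ".qweight" -> 4 bytes per element
scales_bytes = math.prod(scales_shape) * 2       # ".scales"  -> 2 bytes per element
print(qweight_bytes, scales_bytes)               # 8388608 262144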

@@ -643,7 +657,7 @@ def __init__(self, config):
# Load model weights

tensors = {}
with safe_open(self.config.model_path, framework="pt", device="cpu") as f:
with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f:

# Begin auto mapping if enabled

Expand All @@ -662,16 +676,22 @@ def __init__(self, config):
if _skip_key(key): continue

if key.startswith("model.layers.0."):
tensor = f.get_tensor(key)
decoder_size += tensor.numel() * tensor.element_size()
tensor_slice = f.get_slice(key)
shape = tensor_slice.get_shape()
decoder_size += math.prod(shape) * _layer_dtype_size(key)
del tensor_slice

if key.startswith("model.norm."):
tensor = f.get_tensor(key)
norm_size += tensor.numel() * tensor.element_size()
tensor_slice = f.get_slice(key)
shape = tensor_slice.get_shape()
norm_size += math.prod(shape) * _layer_dtype_size(key)
del tensor_slice

if key.startswith("lm_head."):
tensor = f.get_tensor(key)
head_size += tensor.numel() * tensor.element_size()
tensor_slice = f.get_slice(key)
shape = tensor_slice.get_shape()
head_size += math.prod(shape) * _layer_dtype_size(key)
del tensor_slice
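The measurements above query only tensor shapes via get_slice, so nothing is materialized in RAM just to compute sizes. A minimal self-contained sketch of the same pattern, assuming a GPTQ safetensors checkpoint at a placeholder path and that this model.py is importable so _layer_dtype_size can be reused:

import math
from safetensors import safe_open
from model import _layer_dtype_size   # helper added in this diff

MODEL_PATH = "/path/to/model.safetensors"   # placeholder

decoder_size = 0
with safe_open(MODEL_PATH, framework = "pt", device = "cpu") as f:
    for key in f.keys():
        if not key.startswith("model.layers.0."):
            continue
        try:
            per_elem = _layer_dtype_size(key)   # raises on keys the real loader filters out via _skip_key
        except ValueError:
            continue
        shape = f.get_slice(key).get_shape()    # header lookup only; no tensor data is read
        decoder_size += math.prod(shape) * per_elem

print(f"one decoder layer is roughly {decoder_size / 1024**2:.1f} MiB on disk")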

# Assign layers automatically

@@ -701,29 +721,47 @@ def __init__(self, config):
device_usage += this_layer_size
layer_index_device += 1

# Load tensors, move to device(s)

max_dq_buffer_size = 0
# Read tensor list from file

load_keys = []
with safe_open(self.config.model_path, framework = "pt", device = "cpu") as f:
for key in f.keys():
load_keys.append(key)

# Load up to 1 GB of tensors at a time, closing and reopening the file in between each chunk

max_dq_buffer_size = 0
f = None
st_mem = 0
MAX_ST_MEM = 1024**3

for key in load_keys:

if _skip_key(key): continue
device = self.config.device_map.map(key)

if _skip_key(key): continue
if f is None or st_mem > MAX_ST_MEM:
if f is not None: del f
f = safe_open(self.config.model_path, framework = "pt", device = "cpu")
st_mem = 0

device = self.config.device_map.map(key, loading = True)
tensor = f.get_tensor(key)
tensor = f.get_tensor(key)
size = tensor.numel() * tensor.element_size()
st_mem += size

if key.endswith(".scales"): tensor = tensor.half()
if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
if key == "model.norm.weight": tensor = tensor.half()
if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()
if key.endswith(".scales"): tensor = tensor.half()
if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
if key == "model.norm.weight": tensor = tensor.half()
if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()

tensor = tensor.to(device, non_blocking = True)
tensor = tensor.to(device, non_blocking = True)
if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8)

if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, tensor.numel() * 8)
tensors[key] = tensor

tensors[key] = tensor
del f
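The loop above keeps peak host RAM down by reopening the safetensors file once roughly 1 GB has been read through a single handle, rather than holding one handle open while every tensor is materialized. A stripped-down sketch of the same pattern (placeholder path; a single target device instead of the per-key device map, and the dtype conversions from the real loader are omitted):

from safetensors import safe_open

MODEL_PATH = "/path/to/model.safetensors"   # placeholder
DEVICE = "cuda:0"                           # simplified; the real loader asks the device map per key
MAX_ST_MEM = 1024**3                        # reopen the file after ~1 GB read through one handle

with safe_open(MODEL_PATH, framework = "pt", device = "cpu") as f:
    load_keys = list(f.keys())              # read the tensor list once up front

tensors = {}
f = None
st_mem = 0

for key in load_keys:
    if f is None or st_mem > MAX_ST_MEM:    # start a fresh handle for each ~1 GB chunk
        if f is not None: del f
        f = safe_open(MODEL_PATH, framework = "pt", device = "cpu")
        st_mem = 0

    tensor = f.get_tensor(key)
    st_mem += tensor.numel() * tensor.element_size()
    tensors[key] = tensor.to(DEVICE, non_blocking = True)

del f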

# Head

