Commit
Merge pull request #67 from gushiqiao/main
Roll back tp
helloyongyang authored Sep 5, 2024
2 parents 8b66f07 + 7ad20c0 commit a20a657
Showing 5 changed files with 61 additions and 202 deletions.
58 changes: 17 additions & 41 deletions llmc/compression/quantization/awq.py
@@ -26,9 +26,18 @@ def __init__(self, model, quant_config, input, config):
self.save_scale = special_config.get('save_scale', False)

@torch.no_grad()
def get_weight_scale_(self, weights, wquantizer):
def get_weight_scale(self, layers_dict):
layers = list(layers_dict.values())
weights = self.collect_layers_weights(layers)
weights = torch.cat(weights, dim=0)
org_shape = weights.shape
wquantizer = get_wquantizer(
self.block_idx,
list(layers_dict.keys())[0],
self.mix_bits_map,
self.quantizer_mix_bits,
self.wquantizer,
)
weights = wquantizer.reshape_tensor(weights)
scale = weights.abs() / weights.abs().amax(dim=1, keepdim=True)
try:
@@ -41,24 +50,6 @@ def get_weight_scale_(self, weights, wquantizer):
torch.cuda.empty_cache()
return scale

@torch.no_grad()
def get_weight_scale(self, layers_dict, tensor_parallelize_style=None):
layers = list(layers_dict.values())
weights = self.collect_layers_weights(layers, tensor_parallelize_style)
wquantizer = get_wquantizer(
self.block_idx,
list(layers_dict.keys())[0],
self.mix_bits_map,
self.quantizer_mix_bits,
self.wquantizer,
)
if tensor_parallelize_style is None:
return self.get_weight_scale_(weights, wquantizer)
scales = []
for weight_i in weights:
scales.append(self.get_weight_scale_(weight_i, wquantizer))
return scales
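
For reference, the restored `get_weight_scale` collapses a subset's weights into one per-input-channel magnitude profile. A minimal standalone sketch of that computation, assuming `reshape_tensor` groups columns into blocks of `group_size` and that the lines folded out of this hunk restore the original shape and average over rows (both assumptions):

```python
import torch

def weight_scale_sketch(weight: torch.Tensor, group_size: int) -> torch.Tensor:
    # Normalize absolute weights within each group of `group_size` columns,
    # then average over output channels for a per-input-channel profile.
    org_shape = weight.shape
    w = weight.reshape(-1, group_size)                  # what reshape_tensor is assumed to do
    scale = w.abs() / w.abs().amax(dim=1, keepdim=True)
    return scale.view(org_shape).mean(0)                # assumed restore-and-reduce step
```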

@torch.no_grad()
def get_act_scale(self, x):
return x.abs().view(-1, x.shape[-1]).mean(0)
@@ -72,9 +63,8 @@ def get_original_out(self, x, inspect_module, subset_kwargs):
return org_out

@torch.no_grad()
def search_scale_subset(self, layers_dict, input, inspect_module,
subset_kwargs, tensor_parallelize_style=None):
w_max = self.get_weight_scale(layers_dict, tensor_parallelize_style)
def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs):
w_max = self.get_weight_scale(layers_dict)
# grid search for ratio
best_error = float('inf')
best_scales = None
@@ -101,26 +91,14 @@ def search_scale_subset(self, layers_dict, input,

ratio = n * 1 / n_grid
if self.trans_version == 'v1':
if tensor_parallelize_style is not None:
raise NotImplementedError(
'tp not yet supported trans_version v1.'
)
scales = (
(x_max.pow(ratio) / w_max.pow(1 - ratio))
.clamp(min=1e-4)
.view(-1)
)
elif self.trans_version == 'v2':
scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
if tensor_parallelize_style == 'rowwise':
split_scales = torch.chunk(scales, self.tp, dim=0)
processed_splits = []
for split in split_scales:
split = split / (split.max() * split.min()).sqrt()
processed_splits.append(split)
scales = torch.cat(processed_splits, dim=0)
else:
scales = scales / (scales.max() * scales.min()).sqrt()
scales = scales / (scales.max() * scales.min()).sqrt()
for layer_name in layers_dict:
fc = layers_dict[layer_name]
fc.weight.mul_(scales.view(1, -1))
@@ -131,8 +109,8 @@ def search_scale_subset(self, layers_dict, input,
self.mix_bits_map,
self.quantizer_mix_bits,
self.wquantizer,
).fake_quant_weight_dynamic(fc.weight.data,
tensor_parallelize_style)
).fake_quant_weight_dynamic(fc.weight.data)

x_tmp = x / scales.view(1, -1)
if not check_w_only(
self.block_idx,
@@ -147,7 +125,7 @@ def search_scale_subset(self, layers_dict, input,
self.mix_bits_map,
self.quantizer_mix_bits,
self.aquantizer,
).fake_quant_act_dynamic(x_tmp, tensor_parallelize_style)
).fake_quant_act_dynamic(x_tmp)
out = inspect_module(x_tmp, **kwargs)

if isinstance(out, tuple):
@@ -198,7 +176,6 @@ def subset_transform(
input_name,
inspect_module,
subset_kwargs,
tensor_parallelize_style,
):
if not check_do_quant(
self.block_idx,
@@ -239,8 +216,7 @@ def subset_transform(
return

scale = self.search_scale_subset(
layers_dict, input_feat[input_name], inspect_module,
subset_kwargs, tensor_parallelize_style
layers_dict, input_feat[input_name], inspect_module, subset_kwargs
)

self.apply_scale(scale, prev_op, layers)
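Taken together, the awq.py changes collapse the grid search back to a single, unsharded scale per subset. A minimal sketch of the two scale rules the restored loop searches over, following the hunk above; `x_max` and `w_max` are the activation and weight magnitude profiles, and `ratio` is the grid point:

```python
import torch

def awq_candidate_scales(x_max: torch.Tensor, w_max: torch.Tensor,
                         ratio: float, trans_version: str = 'v2') -> torch.Tensor:
    # Candidate AWQ scales for one grid-search ratio.
    if trans_version == 'v1':
        # v1 balances activation magnitudes against weight magnitudes.
        scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
    else:
        # v2 uses activation magnitudes only.
        scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
    # Normalize so the geometric mean of the extremes is 1, keeping the
    # scales centered around unity (the single path this commit restores).
    return scales / (scales.max() * scales.min()).sqrt()
```

Each candidate is then applied as `fc.weight.mul_(scales.view(1, -1))` against `x / scales.view(1, -1)`, and the ratio with the lowest reconstruction error against the original output wins.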
110 changes: 17 additions & 93 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -306,17 +306,13 @@ def block_transform(self, block, input_feat, block_kwargs):
inspect_module = subset['inspect']
inspect_has_kwargs = subset['has_kwargs']
subset_kwargs = block_kwargs if inspect_has_kwargs else {}
tensor_parallelize_style = (
subset['tensor_parallelize_style'] if self.tp > 1 else None
)
self.subset_transform(
layers_dict,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
tensor_parallelize_style,
)
logger.info(f'End transform the {self.block_idx}-th block')

@@ -328,20 +324,9 @@ def filter_subset(self, subset):
return True

def collect_layers_weights(self, layers, tensor_parallelize_style=None):
if tensor_parallelize_style is None:
weights = []
for _m in layers:
weights.append(_m.weight)
return weights
weights = [[] for _ in range(self.tp)]
weights = []
for _m in layers:
weight = _m.weight
if tensor_parallelize_style == 'colwise':
split_weights = torch.chunk(weight, self.tp, dim=0)
elif tensor_parallelize_style == 'rowwise':
split_weights = torch.chunk(weight, self.tp, dim=1)
for i in range(self.tp):
weights[i].append(split_weights[i])
weights.append(_m.weight)
return weights
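
For contrast, a condensed sketch of the sharding that this rollback removes from `collect_layers_weights`: colwise layers were chunked along the output dimension, rowwise layers along the input dimension (names and dims taken from the deleted branch above):

```python
import torch

def split_weight_tp(weight: torch.Tensor, tp: int, style: str):
    # Shard one weight the way the removed tensor-parallel path did.
    if style == 'colwise':
        return torch.chunk(weight, tp, dim=0)   # split output channels
    if style == 'rowwise':
        return torch.chunk(weight, tp, dim=1)   # split input features
    return (weight,)                            # tp disabled: keep whole
```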

@torch.no_grad()
@@ -481,13 +466,6 @@ def scale_ln_fcs(self, ln, fcs, scales):
def auto_clip(self, block, input_feat, n_sample_token):
# auto clip
for n, m in block.named_modules():
subsets = self.model.get_subsets_in_block(block)
tensor_parallelize_style = None
if self.tp > 1:
for subset in subsets:
if n in subset['layers']:
tensor_parallelize_style = subset['tensor_parallelize_style']
break
if not check_do_quant(
self.block_idx, n, self.mix_bits_map, self.quantizer_mix_bits
):
@@ -515,7 +493,6 @@ def auto_clip(self, block, input_feat, n_sample_token):
m.weight,
inputs,
n_sample_token=n_sample_token,
tensor_parallelize_style=tensor_parallelize_style,
)

dist.all_reduce(max_val, op=dist.ReduceOp.SUM)
@@ -595,18 +572,30 @@ def get_clip_factor(self, layer, min_val, max_val, layer_name):
return up_factor, low_factor

@torch.no_grad()
def auto_clip_layer_origin(
def auto_clip_layer(
self,
layer_name,
w,
input,
wquantizer,
group_size,
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
eps=0.0,
):
assert w.dim() == 2

wquantizer = get_wquantizer(
self.block_idx,
layer_name,
self.mix_bits_map,
self.quantizer_mix_bits,
self.wquantizer,
)
if wquantizer.granularity == 'per_group':
group_size = wquantizer.group_size
else:
group_size = w.shape[1]

try:
w = w.reshape(w.shape[0], 1, -1, group_size)
except RuntimeError:
@@ -735,71 +724,6 @@ def auto_clip_layer_origin(
torch.cuda.empty_cache()
return best_max_val.squeeze(1), best_min_val.squeeze(1)

@torch.no_grad()
def auto_clip_layer(
self,
layer_name,
w,
input,
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
eps=0.0,
tensor_parallelize_style=None,
):
assert w.dim() == 2

wquantizer = get_wquantizer(
self.block_idx,
layer_name,
self.mix_bits_map,
self.quantizer_mix_bits,
self.wquantizer,
)
if wquantizer.granularity == 'per_group':
group_size = wquantizer.group_size
else:
group_size = w.shape[1]

if tensor_parallelize_style == 'colwise':
split_weights = torch.chunk(w, self.tp, dim=0)
max_val, min_val = [], []
for w in split_weights:
max_val_i, min_val_i = self.auto_clip_layer_origin(
layer_name, w, input,
wquantizer, group_size,
n_grid, max_shrink, n_sample_token, eps,
)
max_val.append(max_val_i)
min_val.append(min_val_i)
max_val = torch.cat(max_val, dim=0)
min_val = torch.cat(min_val, dim=0)
elif tensor_parallelize_style == 'rowwise':
split_weights = torch.chunk(w, self.tp, dim=1)
split_inputs = [[] for _ in range(self.tp)]
for tensor in input:
chunks = torch.chunk(tensor, self.tp, dim=-1)
for i in range(self.tp):
split_inputs[i].append(chunks[i])
max_val, min_val = [], []
for w, input in zip(split_weights, split_inputs):
max_val_i, min_val_i = self.auto_clip_layer_origin(
layer_name, w, input,
wquantizer, group_size,
n_grid, max_shrink, n_sample_token, eps,
)
max_val.append(max_val_i)
min_val.append(min_val_i)
max_val = torch.cat(max_val, dim=-2)
min_val = torch.cat(min_val, dim=-2)
else:
max_val, min_val = self.auto_clip_layer_origin(
layer_name, w, input,
wquantizer, group_size,
n_grid, max_shrink, n_sample_token, eps,
)
return max_val, min_val
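
The deleted tp-aware wrapper above clipped each shard independently and then stitched the per-shard ranges back together; a sketch of that recombination step (dims taken from the deleted code):

```python
import torch

def merge_shard_clip_ranges(max_vals, min_vals, style: str):
    # colwise shards split output channels, so ranges concatenate on dim 0;
    # rowwise shards split input groups, so they concatenate on dim -2.
    dim = 0 if style == 'colwise' else -2
    return torch.cat(max_vals, dim=dim), torch.cat(min_vals, dim=dim)
```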

def rotate_pre_layers(self, pre_layers, Q):
for layer in pre_layers:
dtype = layer.weight.dtype
(Diffs for the 3 remaining changed files are not shown.)
