
Commit

Merge pull request #37 from ModelTC/ttt
remove redundant code
llmc-reviewer authored Aug 22, 2024
2 parents a657809 + 6fb8773 commit 98b8ec2
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions llmc/compression/quantization/quarot.py
@@ -92,15 +92,6 @@ def block_transform(self, block):
         logger.info(f'block:{block}')
         logger.info(f'End transform the {self.block_idx+1}-th block')
 
-    def bake_mean_into_linear(self, linear):
-        linear_dtype = linear.weight.dtype
-        W_ = linear.weight.data.double()
-        linear.weight.data = W_ - W_.mean(dim=-2, keepdim=True)
-        linear.weight.data = linear.weight.data.to(linear_dtype)
-        if linear.bias is not None:
-            b_ = linear.bias.data.double()
-            linear.bias.data = b_ - b_.mean()
-            linear.bias.data = linear.bias.data.to(linear_dtype)
 
     @torch.no_grad()
     def subset_transform(self, block, subset):
@@ -117,7 +108,7 @@ def subset_transform(self, block, subset):
             self.rotate_pre_layers(layers, self.Q)
         else:
             if self.config['model']['type'] in ['Opt', 'StableLm']:
-                self.bake_mean_into_linear(layers[0])
+                self.bake_mean_into_fc(layers[0])
 
             if 'is_mlp' in subset and subset['is_mlp']:
                 self.rotate_post_layers(
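For context, the deleted helper folds an output-mean subtraction into a linear layer's weights and bias (the trick used when converting LayerNorm-based models such as OPT and StableLM to an RMSNorm-style residual stream): subtracting the per-input-column mean from W, and the mean from the bias, makes the layer's output mean-centered over its feature dimension. The call site in the second hunk above is switched to bake_mean_into_fc, presumably an equivalent helper already available to the class elsewhere in llmc, which is why the local copy was removed as redundant. The standalone sketch below is not taken from the repository; the free function, toy shapes, and tolerance are illustrative assumptions, used only to reproduce the deleted computation and check the centering equivalence numerically.

    # Minimal sketch (assumed re-implementation, not the repository's code):
    # fold output-mean subtraction into an nn.Linear, as the deleted helper did.
    import torch
    import torch.nn as nn

    @torch.no_grad()
    def bake_mean_into_fc(linear: nn.Linear) -> None:
        # Subtract the per-input-column mean of W (averaged over output rows) and
        # the mean of the bias, so y = W x + b becomes y - mean(y) over outputs.
        dtype = linear.weight.dtype
        W = linear.weight.data.double()                        # shape [out, in]
        linear.weight.data = (W - W.mean(dim=-2, keepdim=True)).to(dtype)
        if linear.bias is not None:
            b = linear.bias.data.double()
            linear.bias.data = (b - b.mean()).to(dtype)

    # Numerical check of the equivalence on assumed toy shapes.
    torch.manual_seed(0)
    fc = nn.Linear(16, 8)
    x = torch.randn(4, 16)
    y_ref = fc(x)
    y_centered_ref = y_ref - y_ref.mean(dim=-1, keepdim=True)  # explicit centering

    bake_mean_into_fc(fc)                                      # centering baked into the layer
    y_baked = fc(x)
    assert torch.allclose(y_baked, y_centered_ref, atol=1e-5)

Under that assumption the two code paths produce identical outputs, so keeping only the shared bake_mean_into_fc helper matches the "remove redundant code" intent of the commit.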
