Skip to content

Commit

Permalink
Merge pull request #45 from ModelTC/save_fix
Browse files Browse the repository at this point in the history
save_fix_json
  • Loading branch information
Harahan authored Aug 24, 2024
2 parents 34db26c + 6e01ac2 commit 6938996
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions llmc/compression/quantization/quarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ def preprocess(self):
self.model.get_embed_layers()[0].weight,
):
logger.info('Tie weight! Copy embed_layer for head_layer!')
path = os.path.join(self.config.model.path, 'config.json')
with open(path, 'r') as f:
config = json.load(f)
config['tie_word_embeddings'] = False
with open(path, 'w') as f:
json.dump(config, f, indent=4)
del self.model.get_head_layers()[0].weight
w = self.model.get_embed_layers()[0].weight.clone()
self.model.get_head_layers()[0].weight = nn.Parameter(w)
Expand Down Expand Up @@ -124,3 +118,14 @@ def subset_transform(self, block, subset):
prev_op[0], had_dim=self.head_dim, output=True
)
apply_exact_had_to_linear(layers[0], had_dim=-1, output=False)

@torch.no_grad()
def save_model(self, path):
super().save_model(path)
path = os.path.join(path, 'config.json')
with open(path, 'r') as f:
config = json.load(f)
if 'tie_word_embeddings' in config:
config['tie_word_embeddings'] = False
with open(path, 'w') as f:
json.dump(config, f, indent=4)

0 comments on commit 6938996

Please sign in to comment.