Skip to content

Commit

Permalink
ExLlamav2_HF: Convert logits to FP32 (#4310)
Browse files · Browse the repository at this point in the history
  • Loading branch information
turboderp authored Oct 19, 2023
1 parent c0ffb77 commit ae8cd44
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions in modules/exllamav2_hf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -108,10 +108,10 @@ def __call__(self, *args, **kwargs):
if len(seq_tensor) > 1:
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)

logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
else:
ex_cache.current_seq_len = 0
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()

if is_negative:
self.past_seq_negative = seq_tensor
Expand Down

0 comments on commit ae8cd44

Please sign in to comment.