From 7fed393aa8380f2d7f7c760bbd6a2f68b5caa9ea Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 23 Jul 2024 11:32:50 -0400 Subject: [PATCH] Fix restoration of quant_storage for CPU offloading (#1279) * Fix restoration of quant_storage for CPU offloading * Clarify comment on default quant_storage in Params4bit.from_prequantized() * fix to make quant_storage dynamic based on serialized dtype * delete obsolete comment --------- Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> --- bitsandbytes/nn/modules.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 05f7c04db..40766ad41 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -282,10 +282,13 @@ def from_prequantized( self.compress_statistics = self.quant_state.nested self.quant_type = self.quant_state.quant_type self.bnb_quantized = True + + self.quant_storage = data.dtype + return self def _quantize(self, device): - w = self.data.contiguous().cuda(device) + w = self.data.contiguous().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit( w, blocksize=self.blocksize, @@ -333,6 +336,7 @@ def to(self, *args, **kwargs): blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type, + quant_storage=self.quant_storage, ) return new_param @@ -450,7 +454,7 @@ def forward(self, x: torch.Tensor): # since we registered the module, we can recover the state here assert self.weight.shape[1] == 1 if not isinstance(self.weight, Params4bit): - self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) + self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True) self.weight.quant_state = self.quant_state else: print(