Skip to content

Commit

Permalink
Merge branch 'main' into setup-pep660
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas authored Apr 16, 2024
2 parents da2dc69 + 127788a commit a0ae460
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 228 deletions.
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848

# Reformat with ruff-format
5a4263f4dc05fe8f78f4111beab9f68a81deeab1

# CHANGELOG: to reverse chron order + mdformat
4743ff0d43e04e4cc3e5d8b9e7cd016c0defa36d
511 changes: 289 additions & 222 deletions CHANGELOG.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
"optim.optimizer.MockArgs": False,
}

__version__ = "0.44.0.dev"
__version__ = "0.43.2.dev"
29 changes: 25 additions & 4 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer
from bitsandbytes.utils import (
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
)

T = TypeVar("T", bound="torch.nn.Module")

Expand Down Expand Up @@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
return
weight_format = state_dict.pop(f"{prefix}weight_format", "row")

if isinstance(weight_format, torch.Tensor):
weight_format = weight_format.item()

# For new weights format storage type, we explicitly check
# if weights_format is on the mapping
if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
raise ValueError(f"Expected supported weight format - got {weight_format}")
elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]

if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device)
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
Expand Down Expand Up @@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights:
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
destination[format_name] = "row"
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None and not layout_reordered:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = "row"
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = self.state.formatB
weights_format = self.state.formatB
# At this point `weights_format` is an str
if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
raise ValueError(f"Unrecognized weights format {weights_format}")

weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]

destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)

def _load_from_state_dict(
self,
Expand Down
4 changes: 4 additions & 0 deletions bitsandbytes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,7 @@ def unpack_tensor_to_dict(tensor_data):
unpacked_dict = json.loads(json_str)

return unpacked_dict


# Maps each supported 8-bit weight layout name to a small integer code.
# The diff in bitsandbytes/nn/modules.py (this same commit) serializes the code
# into the state dict as a torch.uint8 tensor, so checkpoints store an int
# instead of a raw string for the weight format.
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
# Inverse lookup (int code -> layout name), used when loading a state dict to
# translate a stored integer code back to its layout-name string.
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ def has_ext_modules(self):
return True


setup(version="0.44.0.dev0", packages=find_packages(), distclass=BinaryDistribution)
setup(version="0.43.2.dev0", packages=find_packages(), distclass=BinaryDistribution)

0 comments on commit a0ae460

Please sign in to comment.