From 64e3a17b9f24315a97fa3a2465529dc253125cde Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 04:23:41 +0000 Subject: [PATCH 01/10] Supermask Tensor --- benchmark.py | 62 ++++--- requirements.txt | 1 + supermask_ts.py | 436 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 479 insertions(+), 20 deletions(-) create mode 100644 requirements.txt create mode 100644 supermask_ts.py diff --git a/benchmark.py b/benchmark.py index b299198..4c70558 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,4 +1,5 @@ import os +import functools import time import sys import warnings @@ -13,6 +14,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from supermask import apply_supermask, SupermaskLinear +from supermask_ts import apply_supermask_ts, SupermaskTensor def apply_sparsity(model): @@ -22,13 +24,18 @@ def apply_sparsity(model): def apply_bsr(model): - for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear) and "mlp" in name: - try: - module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) - print(f"Converted {name} to bsr format.") - except ValueError as e: - print(f"Unable to convert weight of {name} to bsr format: {e}") + for name, param in model.named_parameters(): + if isinstance(param, SupermaskTensor): + try: + setattr(model, name, to_bsr(param.data, args.bsr)) + print(f"Converted SupermaskTensor {name} to bsr format.") + except ValueError: + # Fall back to strided + setattr(model, name, param.data.to_strided()) + print(f"Converted SupermaskTensor {name} to strided format.") + # for name, module in model.named_modules(): + # if isinstance(module, torch.nn.Linear) and "mlp" in name: + # module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) def to_bsr(tensor, blocksize): @@ -75,31 +82,46 @@ def main(args): print("Creating model") model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) - apply_supermask( + # 
apply_supermask( + # model, + # linear_sparsity=args.sparsity_linear, + # linear_sp_tilesize=args.sp_linear_tile_size, + # conv1x1_sparsity=args.sparsity_conv1x1, + # conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, + # conv_sparsity=args.sparsity_conv, + # conv_sp_tilesize=args.sp_conv_tile_size, + # skip_last_layer_sparsity=args.skip_last_layer_sparsity, + # skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, + # device=device, + # verbose=True, + # ) + assert args.sparsity_conv1x1 == 0 + assert args.sparsity_conv == 0 + scaler = torch.cuda.amp.GradScaler() if args.amp else None + model_without_ddp = model + model.to(device) + if args.bfloat16: + print("Using bfloat16") + model = model.to(torch.bfloat16) + apply_supermask_ts( model, linear_sparsity=args.sparsity_linear, linear_sp_tilesize=args.sp_linear_tile_size, - conv1x1_sparsity=args.sparsity_conv1x1, - conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, - conv_sparsity=args.sparsity_conv, - conv_sp_tilesize=args.sp_conv_tile_size, skip_last_layer_sparsity=args.skip_last_layer_sparsity, skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, - device=device, verbose=True, ) - model.to(device) - scaler = torch.cuda.amp.GradScaler() if args.amp else None - model_without_ddp = model - if args.bfloat16: - print("Using bfloat16") - model = model.to(torch.bfloat16) if args.bsr and not args.sparsify_weights: raise ValueError("--bsr can only be used when --sparsify_weights is also specified.") + # if args.sparsify_weights: + # apply_sparsity(model) + # verify_sparsity(model) + # if args.bsr: + # apply_bsr(model) if args.sparsify_weights: apply_sparsity(model) - verify_sparsity(model) + # verify_sparsity(model) if args.bsr: apply_bsr(model) image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=torch.bfloat16 if args.bfloat16 else None, device=device) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9a635b9 --- /dev/null +++ 
b/requirements.txt @@ -0,0 +1 @@ +scipy diff --git a/supermask_ts.py b/supermask_ts.py new file mode 100644 index 0000000..3b40a1b --- /dev/null +++ b/supermask_ts.py @@ -0,0 +1,436 @@ +import torch.nn as nn +import math +import torch +from torch.autograd import Variable +import torch.nn.functional as F +from scipy.linalg import hadamard +import numpy as np + + +# original supermask +scores_min=None +scores_max=9e9 +uniform_init_01 = False + +# adjusted supermask, initialize scores with uniform distribution in [0,1], clamp scores in each step in [0,1] +# scores_min=0. +# scores_max=1. +# uniform_init_01 = True + +def percentile(t, q): + """Return the value that is larger than q% of t""" + k = 1 + round(.01 * float(q) * (t.numel() - 1)) + return t.view(-1).kthvalue(k).values.item() + + +def to_bsr(tensor, blocksize=256): + if tensor.ndim != 2: + print("Tensor is not 2D, skipping BSR conversion.") + return tensor + + if tensor.size(0) % blocksize or tensor.size(1) % blocksize: + print("Tensor dimensions are not divisible by blocksize, skipping BSR conversion.") + return tensor + + try: + converted_tensor = tensor.to_sparse_bsr(blocksize=blocksize) + print(f"Converted tensor to BSR format with blocksize: {blocksize}") + return converted_tensor + except ValueError as e: + print(f"Unable to convert tensor to BSR format: {e}") + return tensor + + +class GetSubnet(torch.autograd.Function): + """Supermask STE function""" + @staticmethod + def forward(ctx, scores, zeros, ones, sparsity): + scores.clamp_(min=scores_min,max=scores_max) + k_val = percentile(scores, sparsity*100) + return torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device)) + @staticmethod + def backward(ctx, g): + return g, None, None, None + +from typing import Dict, Tuple, Any +SUPERMASK_OPS_TABLE: Dict[Any, Any] = {} + +def implements(aten_ops): + """Use this decorator to implement a function for an aten op in __torch_dispatch__""" + + def decorator(func): + for op in aten_ops: + 
SUPERMASK_OPS_TABLE[op] = func + return func + + return decorator + +@implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) +def noop_detach(func, *args, **kwargs): + return args[0][0] + + +# weight, scores, shift, scale should be parameters +# that can be trained +class SupermaskTensor(torch.Tensor): + + def __new__( + cls, + weight: torch.Tensor, + scores: torch.Tensor, + sparsity: float, + scale: torch.Tensor, + shift: torch.Tensor, + tile_size: int): + supermask_tensor = torch.Tensor._make_wrapper_subclass( + cls, + weight.shape, + weight.stride(), + weight.storage_offset(), + dtype=weight.dtype, + device=weight.device, + requires_grad=weight.requires_grad, + ) + return supermask_tensor + + def __init__( + self, + weight: torch.Tensor, + scores: torch.Tensor, + sparsity: float, + scale: torch.Tensor, + shift: torch.Tensor, + tile_size: int): + self.weight = weight + self.scores = scores + self.sparsity = sparsity + self.scale = scale + self.shift = shift + self.tile_size = tile_size + + def get_mask(self): + subnet = GetSubnet.apply(self.scores, + torch.zeros_like(self.scores), + torch.ones_like(self.scores), + self.sparsity) + + if self.tile_size != 1: + for i, k in enumerate(self.weight.shape): + subnet = subnet.repeat_interleave(self.tile_size, dim=i) + subnet = torch.narrow(subnet, i, 0, k) + + return subnet + + def to_strided(self): + subnet = self.get_mask() + return (self.weight*self.scale+self.shift) * subnet + + def to_sparse_bsr(self, blocksize): + return self.to_strided().to_sparse_bsr(blocksize) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func in SUPERMASK_OPS_TABLE: + return SUPERMASK_OPS_TABLE[func](func, args, kwargs) + print("func: ", func) + return NotImplemented + +def to_supermask_tensor(weight, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, tile_size): + # initialize the scores + max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in 
weight.size()])) + if sparsity > max_sparsity: + print( + f"reducing sparsity from {sparsity} to {max_sparsity}", + f"(maximum sparsity for layer with shape {weight.size()} and tile size {tile_size})" + ) + sparsity = max_sparsity + scores = torch.empty([max(1, int(math.ceil(wn / tile_size))) for wn in weight.size()], device=weight.device) + nn.init.uniform_(scores) if uniform_init_01 else nn.init.kaiming_uniform_(scores, a=math.sqrt(5)) + + # the shift and the scale are transformation parameters + # the actually used weights = self.weight*self.scale+self.shift + # the transformation is activated only for quantized weights + shift = torch.tensor([0.], requires_grad=False, device=weight.device) + scale = torch.tensor([1.], requires_grad=False, device=weight.device) + + assert bitwidth is None + + # self.weight.requires_grad = not fixed_weight + + return SupermaskTensor(weight, + scores, + sparsity, + scale, + shift, + tile_size) + +def apply_supermask_ts( + model, + linear_sparsity=0.0, + linear_sp_tilesize=1, + skip_last_layer_sparsity=False, + skip_first_transformer_sparsity=False, + verbose=False, +): + for n, m in model.named_modules(): + if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): + m.weight = torch.nn.Parameter(to_supermask_tensor(m.weight, + linear_sparsity, + False, + False, + None, + None, + None, + linear_sp_tilesize)) + + +class SupermaskLinear(nn.Linear): + """Supermask class for Linear layer""" + def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): + tile_size = kwargs.pop("tile_size", 1) + super(SupermaskLinear, self).__init__(*args, **kwargs) + # initialize the scores + max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) + self.sparsity = sparsity + if self.sparsity > max_sparsity: + print( + f"reducing sparsity from {self.sparsity} to {max_sparsity}", + f"(maximum sparsity for layer with shape {self.weight.size()} and tile size 
{tile_size})" + ) + self.sparsity = max_sparsity + self.tile_size = tile_size + self.sparsify_weights = False + self.scores = nn.Parameter( + torch.empty( + [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] + ), + requires_grad=not fixed_mask, + ) + nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) + + # the shift and the scale are transformation parameters + # the actually used weights = self.weight*self.scale+self.shift + # the transformation is activated only for quantized weights + self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) + self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) + + with torch.no_grad(): + # if bitwidth is None, then use floating point values in self.weight + # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) + # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 + # these quantized values are uniformly distributed + if bitwidth is not None: + weights_max = torch.max(self.weight).item() + weights_min = torch.min(self.weight).item() + least_step = (weights_max-weights_min)/pow(2,bitwidth) + left_bound = weights_min-1e-6 + right_bound = weights_min+least_step+1e-6 + # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) + # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) + # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; + self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) + self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. 
if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) + for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): + self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i + left_bound = right_bound + right_bound += least_step + + self.weight.requires_grad = not fixed_weight + + def get_mask(self): + subnet = GetSubnet.apply(self.scores, + torch.zeros_like(self.scores), + torch.ones_like(self.scores), + self.sparsity) + + if self.tile_size != 1: + for i, k in enumerate(self.weight.shape): + subnet = subnet.repeat_interleave(self.tile_size, dim=i) + subnet = torch.narrow(subnet, i, 0, k) + + return subnet + + def sparsify_offline(self): + subnet = self.get_mask() + self.weight.data = (self.weight*self.scale+self.shift) * subnet + self.sparsify_weights = True + + def forward(self, x): + if not self.sparsify_weights: + subnet = self.get_mask() + w = (self.weight*self.scale+self.shift) * subnet + else: + w = self.weight + return F.linear(x, w, self.bias) + + +class SupermaskConv2d(nn.Conv2d): + """Supermask class for Conv2d layer""" + def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): + tile_size = kwargs.pop("tile_size", 1) + super(SupermaskConv2d, self).__init__(*args, **kwargs) + # initialize the scores + max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) + self.sparsity = sparsity + if self.sparsity > max_sparsity: + print( + f"reducing sparsity from {self.sparsity} to {max_sparsity}", + f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" + ) + self.sparsity = max_sparsity + self.tile_size = tile_size + self.scores = nn.Parameter( + torch.empty( + [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] + ), + requires_grad=not fixed_mask, + ) + nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) + 
+ # the shift and the scale are transformation parameters + # the actually used weights = self.weight*self.scale+self.shift + # the transformation is activated only for quantized weights + self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) + self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) + + with torch.no_grad(): + # if bitwidth is None, then use floating point values in self.weight + # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) + # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 + # these quantized values are uniformly distributed + if bitwidth is not None: + weights_max = torch.max(self.weight).item() + weights_min = torch.min(self.weight).item() + least_step = (weights_max-weights_min)/pow(2,bitwidth) + left_bound = weights_min-1e-6 + right_bound = weights_min+least_step+1e-6 + # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) + # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1]), requires_grad=not fixed_transform[1]) + # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; + self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) + self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. 
if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) + for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): + self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i + left_bound = right_bound + right_bound += least_step + + self.weight.requires_grad = not fixed_weight + + def forward(self, x): + subnet = GetSubnet.apply(self.scores, + torch.zeros_like(self.scores), + torch.ones_like(self.scores), + self.sparsity) + + if self.tile_size != 1: + for i, k in enumerate(self.weight.shape): + # if k == 1: continue + subnet = subnet.repeat_interleave(self.tile_size, dim=i) + subnet = torch.narrow(subnet, i, 0, k) + + w = (self.weight*self.scale+self.shift) * subnet + return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) + +@torch.no_grad() +def set_sparsity(modules, sparsity): + """Set the sparsity for supermask layers""" + sm_idx = 0 + for mod in modules: + if isinstance(mod, (SupermaskLinear, SupermaskConv2d)): + mod.sparsity=sparsity[sm_idx] + sm_idx += 1 + print(mod) + print('Sparsity: ', mod.sparsity) + + +def apply_supermask( + model, + linear_sparsity=0.0, + linear_sp_tilesize=1, + conv1x1_sparsity=0.0, + conv1x1_sp_tilesize=1, + conv_sparsity=0.0, + conv_sp_tilesize=1, + skip_last_layer_sparsity=False, + skip_first_transformer_sparsity=False, + device="cuda", + verbose=False, +): + sparsified_modules = {} + + for n, m in model.named_modules(): + # check conditions for skipping sparsity + if skip_last_layer_sparsity and n == "heads.head": + continue + if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: + continue + + # convert 1x1 convolutions + if conv1x1_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d) and m.kernel_size == (1, 1): + new_m = SupermaskConv2d( + conv1x1_sparsity, False, False, None, None, None, + m.in_channels, + m.out_channels, + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + 
groups=m.groups, + bias=m.bias is not None, + padding_mode=m.padding_mode, + device=device, + tile_size=conv1x1_sp_tilesize, + ) + new_m.weight.data.copy_(m.weight.data) + if m.bias is not None: + new_m.bias.data.copy_(m.bias.data) + sparsified_modules[n] = new_m + continue + + # convert all other convolutions (not tested!) + if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d): + new_m = SupermaskConv2d( + conv_sparsity, False, False, None, None, None, + m.in_channels, + m.out_channels, + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=m.bias is not None, + padding_mode=m.padding_mode, + device=device, + tile_size=conv_sp_tilesize, + ) + new_m.weight.data.copy_(m.weight.data) + if m.bias is not None: + new_m.bias.data.copy_(m.bias.data) + sparsified_modules[n] = new_m + continue + + if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): + new_m = SupermaskLinear( + linear_sparsity, False, False, None, None, None, + m.in_features, + m.out_features, + bias=m.bias is not None, + device=device, + tile_size=linear_sp_tilesize, + ) + new_m.weight.data.copy_(m.weight.data) + if m.bias is not None: + new_m.bias.data.copy_(m.bias.data) + sparsified_modules[n] = new_m + continue + + # add modules to model + for k, v in sparsified_modules.items(): + sm_name, ch_name = k.rsplit(".", 1) + sm = model.get_submodule(sm_name) + sm.add_module(ch_name, v) + + if verbose: + print(f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}') + + return model From 6702f6445ea378c633c4ca17f802c3eecb662e3c Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 04:54:22 +0000 Subject: [PATCH 02/10] Working but slower --- benchmark.py | 15 +++++++++-- supermask_ts.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/benchmark.py b/benchmark.py index 4c70558..2848d2f 100644 --- a/benchmark.py +++ b/benchmark.py @@ -22,16 
+22,23 @@ def apply_sparsity(model): if isinstance(module, SupermaskLinear) and "mlp" in name: module.sparsify_offline() +def set_parameter(model, name, param): + if '.' in name: + names = name.split('.') + set_parameter(getattr(model, names[0]), '.'.join(names[1:]), param) + else: + setattr(model, name, param) + def apply_bsr(model): for name, param in model.named_parameters(): if isinstance(param, SupermaskTensor): try: - setattr(model, name, to_bsr(param.data, args.bsr)) + set_parameter(model, name, torch.nn.Parameter(to_bsr(param.data, args.bsr))) print(f"Converted SupermaskTensor {name} to bsr format.") except ValueError: # Fall back to strided - setattr(model, name, param.data.to_strided()) + set_parameter(model, name, torch.nn.Parameter(param.data.to_strided())) print(f"Converted SupermaskTensor {name} to strided format.") # for name, module in model.named_modules(): # if isinstance(module, torch.nn.Linear) and "mlp" in name: @@ -123,7 +130,11 @@ def main(args): apply_sparsity(model) # verify_sparsity(model) if args.bsr: + print("0 ---") + apply_bsr(model) + print("1 ---") apply_bsr(model) + print("2 ---") image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=torch.bfloat16 if args.bfloat16 else None, device=device) # model = torch.compile(model, mode='max-autotune') print(benchmark_in_ms(10, 100, model, image), file=sys.stderr) diff --git a/supermask_ts.py b/supermask_ts.py index 3b40a1b..436ae96 100644 --- a/supermask_ts.py +++ b/supermask_ts.py @@ -6,6 +6,59 @@ from scipy.linalg import hadamard import numpy as np +def _replace_with_custom_fn_if_matches_filter( + model, + replacement_fn, + filter_fn, + cur_fqn="", +) -> None: + """ + For each `child` in `model`, replaces it with `replacement_fn(child)` + if `filter_fn(child)` is `True` + """ + if filter_fn(model, cur_fqn[:-1]): + model = replacement_fn(model) + return model + else: + for name, child in model.named_children(): + new_child = 
_replace_with_custom_fn_if_matches_filter( + child, replacement_fn, filter_fn, f"{cur_fqn}{name}." + ) + if new_child is not child: + setattr(model, name, new_child) + return model + +def swap_conv2d_1x1_to_linear(model, filter_fn=None): + """ + Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. + """ + + class PermuteSandwich(torch.nn.Module): + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, *args): + return self.mod(args[0].permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def replace_conv2d_1x1(conv): + assert conv.kernel_size == (1, 1) + lin = torch.nn.Linear( + conv.in_channels, conv.out_channels, bias=(conv.bias is not None) + ) + lin.weight = torch.nn.Parameter(conv.weight.squeeze(-1, -2)) + lin.bias = conv.bias + return PermuteSandwich(lin) + + if filter_fn is None: + filter_fn = lambda mod, *args: isinstance( + mod, torch.nn.Conv2d + ) and mod.kernel_size == (1, 1) + + _replace_with_custom_fn_if_matches_filter( + model, replace_conv2d_1x1, filter_fn=filter_fn + ) + # original supermask scores_min=None @@ -144,14 +197,14 @@ def to_supermask_tensor(weight, sparsity, fixed_mask, fixed_weight, bitwidth, tr f"(maximum sparsity for layer with shape {weight.size()} and tile size {tile_size})" ) sparsity = max_sparsity - scores = torch.empty([max(1, int(math.ceil(wn / tile_size))) for wn in weight.size()], device=weight.device) + scores = torch.empty([max(1, int(math.ceil(wn / tile_size))) for wn in weight.size()], device=weight.device, dtype=weight.dtype) nn.init.uniform_(scores) if uniform_init_01 else nn.init.kaiming_uniform_(scores, a=math.sqrt(5)) # the shift and the scale are transformation parameters # the actually used weights = self.weight*self.scale+self.shift # the transformation is activated only for quantized weights - shift = 
torch.tensor([0.], requires_grad=False, device=weight.device, dtype=weight.dtype) + scale = torch.tensor([1.], requires_grad=False, device=weight.device, dtype=weight.dtype) assert bitwidth is None @@ -172,6 +225,7 @@ def apply_supermask_ts( skip_first_transformer_sparsity=False, verbose=False, ): + swap_conv2d_1x1_to_linear(model) for n, m in model.named_modules(): if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): m.weight = torch.nn.Parameter(to_supermask_tensor(m.weight, @@ -182,6 +236,17 @@ def apply_supermask_ts( None, None, linear_sp_tilesize)) + if linear_sparsity != 0.0 and isinstance(m, torch.nn.MultiheadAttention): + assert m._qkv_same_embed_dim + m.in_proj_weight = torch.nn.Parameter(to_supermask_tensor(m.in_proj_weight, + linear_sparsity, + False, + False, + None, + None, + None, + linear_sp_tilesize)) + class SupermaskLinear(nn.Linear): From fabfa5b47eb0acb7424b673dca9ecec984fccf97 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 04:56:27 +0000 Subject: [PATCH 03/10] Supporting both --- benchmark.py | 50 +++++++++++++++++-------------------------------- supermask_ts.py | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/benchmark.py b/benchmark.py index 2848d2f..10f0916 100644 --- a/benchmark.py +++ b/benchmark.py @@ -22,27 +22,10 @@ def apply_sparsity(model): if isinstance(module, SupermaskLinear) and "mlp" in name: module.sparsify_offline() -def set_parameter(model, name, param): - if '.' 
in name: - names = name.split('.') - set_parameter(getattr(model, names[0]), '.'.join(names[1:]), param) - else: - setattr(model, name, param) - - def apply_bsr(model): - for name, param in model.named_parameters(): - if isinstance(param, SupermaskTensor): - try: - set_parameter(model, name, torch.nn.Parameter(to_bsr(param.data, args.bsr))) - print(f"Converted SupermaskTensor {name} to bsr format.") - except ValueError: - # Fall back to strided - set_parameter(model, name, torch.nn.Parameter(param.data.to_strided())) - print(f"Converted SupermaskTensor {name} to strided format.") - # for name, module in model.named_modules(): - # if isinstance(module, torch.nn.Linear) and "mlp" in name: - # module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and "mlp" in name: + module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) def to_bsr(tensor, blocksize): @@ -89,19 +72,20 @@ def main(args): print("Creating model") model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) - # apply_supermask( - # model, - # linear_sparsity=args.sparsity_linear, - # linear_sp_tilesize=args.sp_linear_tile_size, - # conv1x1_sparsity=args.sparsity_conv1x1, - # conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, - # conv_sparsity=args.sparsity_conv, - # conv_sp_tilesize=args.sp_conv_tile_size, - # skip_last_layer_sparsity=args.skip_last_layer_sparsity, - # skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, - # device=device, - # verbose=True, - # ) + if not args.use_ts: + apply_supermask( + model, + linear_sparsity=args.sparsity_linear, + linear_sp_tilesize=args.sp_linear_tile_size, + conv1x1_sparsity=args.sparsity_conv1x1, + conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, + conv_sparsity=args.sparsity_conv, + conv_sp_tilesize=args.sp_conv_tile_size, + skip_last_layer_sparsity=args.skip_last_layer_sparsity, + 
skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, + device=device, + verbose=True, + ) assert args.sparsity_conv1x1 == 0 assert args.sparsity_conv == 0 scaler = torch.cuda.amp.GradScaler() if args.amp else None diff --git a/supermask_ts.py b/supermask_ts.py index 436ae96..74770b5 100644 --- a/supermask_ts.py +++ b/supermask_ts.py @@ -6,6 +6,25 @@ from scipy.linalg import hadamard import numpy as np +def set_parameter(model, name, param): + if '.' in name: + names = name.split('.') + set_parameter(getattr(model, names[0]), '.'.join(names[1:]), param) + else: + setattr(model, name, param) + + +def apply_bsr(model): + for name, param in model.named_parameters(): + if isinstance(param, SupermaskTensor): + try: + set_parameter(model, name, torch.nn.Parameter(to_bsr(param.data, args.bsr))) + print(f"Converted SupermaskTensor {name} to bsr format.") + except ValueError: + # Fall back to strided + set_parameter(model, name, torch.nn.Parameter(param.data.to_strided())) + print(f"Converted SupermaskTensor {name} to strided format.") + def _replace_with_custom_fn_if_matches_filter( model, replacement_fn, From 806c3e1bac49c2772e4e367df83cc39d97c68cb4 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 04:59:46 +0000 Subject: [PATCH 04/10] Supporting both --- benchmark.py | 78 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/benchmark.py b/benchmark.py index 10f0916..3ca8cfe 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,5 +1,4 @@ import os -import functools import time import sys import warnings @@ -24,8 +23,12 @@ def apply_sparsity(model): def apply_bsr(model): for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear) and "mlp" in name: - module.weight = torch.nn.Parameter(to_bsr(module.weight.data, args.bsr)) + if isinstance(module, torch.nn.Linear) and "mlp" in name: + try: + module.weight = torch.nn.Parameter(to_bsr(module.weight.data, 
args.bsr)) + print(f"Converted {name} to bsr format.") + except ValueError as e: + print(f"Unable to convert weight of {name} to bsr format: {e}") def to_bsr(tensor, blocksize): @@ -86,39 +89,48 @@ def main(args): device=device, verbose=True, ) - assert args.sparsity_conv1x1 == 0 - assert args.sparsity_conv == 0 - scaler = torch.cuda.amp.GradScaler() if args.amp else None - model_without_ddp = model - model.to(device) - if args.bfloat16: - print("Using bfloat16") - model = model.to(torch.bfloat16) - apply_supermask_ts( - model, - linear_sparsity=args.sparsity_linear, - linear_sp_tilesize=args.sp_linear_tile_size, - skip_last_layer_sparsity=args.skip_last_layer_sparsity, - skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, - verbose=True, - ) + model.to(device) + scaler = torch.cuda.amp.GradScaler() if args.amp else None + model_without_ddp = model + if args.bfloat16: + print("Using bfloat16") + model = model.to(torch.bfloat16) + else: + model.to(device) + scaler = torch.cuda.amp.GradScaler() if args.amp else None + model_without_ddp = model + if args.bfloat16: + print("Using bfloat16") + model = model.to(torch.bfloat16) + assert args.sparsity_conv1x1 == 0 + assert args.sparsity_conv == 0 + apply_supermask_ts( + model, + linear_sparsity=args.sparsity_linear, + linear_sp_tilesize=args.sp_linear_tile_size, + skip_last_layer_sparsity=args.skip_last_layer_sparsity, + skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, + verbose=True, + ) if args.bsr and not args.sparsify_weights: raise ValueError("--bsr can only be used when --sparsify_weights is also specified.") - # if args.sparsify_weights: - # apply_sparsity(model) - # verify_sparsity(model) - # if args.bsr: - # apply_bsr(model) - if args.sparsify_weights: - apply_sparsity(model) - # verify_sparsity(model) - if args.bsr: - print("0 ---") - apply_bsr(model) - print("1 ---") - apply_bsr(model) - print("2 ---") + if not args.use_ts: + if args.sparsify_weights: + 
apply_sparsity(model) + verify_sparsity(model) + if args.bsr: + apply_bsr(model) + else: + if args.sparsify_weights: + apply_sparsity(model) + # verify_sparsity(model) + if args.bsr: + print("0 ---") + apply_bsr(model) + print("1 ---") + apply_bsr(model) + print("2 ---") image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=torch.bfloat16 if args.bfloat16 else None, device=device) # model = torch.compile(model, mode='max-autotune') print(benchmark_in_ms(10, 100, model, image), file=sys.stderr) From 05f4bfd3c3224c7f218b43082cd20d4002706f4f Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 05:01:02 +0000 Subject: [PATCH 05/10] Supporting both --- benchmark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark.py b/benchmark.py index 3ca8cfe..7ee4364 100644 --- a/benchmark.py +++ b/benchmark.py @@ -21,6 +21,7 @@ def apply_sparsity(model): if isinstance(module, SupermaskLinear) and "mlp" in name: module.sparsify_offline() + def apply_bsr(model): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and "mlp" in name: From 352c4943859dd87e8a22899036ad69df9d40b1a3 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 05:07:43 +0000 Subject: [PATCH 06/10] Supporting both --- benchmark.py | 8 ++++---- supermask_ts.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmark.py b/benchmark.py index 7ee4364..b46ef10 100644 --- a/benchmark.py +++ b/benchmark.py @@ -13,7 +13,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from supermask import apply_supermask, SupermaskLinear -from supermask_ts import apply_supermask_ts, SupermaskTensor +from supermask_ts import apply_supermask_ts, SupermaskTensor, apply_bsr_ts def apply_sparsity(model): @@ -124,13 +124,12 @@ def main(args): apply_bsr(model) else: if args.sparsify_weights: - apply_sparsity(model) # verify_sparsity(model) if args.bsr: print("0 ---") - apply_bsr(model) + 
apply_bsr_ts(model, args.bsr) print("1 ---") - apply_bsr(model) + apply_bsr_ts(model, args.bsr) print("2 ---") image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=torch.bfloat16 if args.bfloat16 else None, device=device) # model = torch.compile(model, mode='max-autotune') @@ -174,6 +173,7 @@ def get_args_parser(add_help=True): parser.add_argument('--sparsify-weights', action='store_true', help='Apply weight sparsification in evaluation mode') parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') parser.add_argument("--bfloat16", action="store_true", help="Use bfloat16") + parser.add_argument("--use-ts", action="store_true", help="Use Tensor subclass") return parser diff --git a/supermask_ts.py b/supermask_ts.py index 74770b5..0b3bc89 100644 --- a/supermask_ts.py +++ b/supermask_ts.py @@ -14,13 +14,13 @@ def set_parameter(model, name, param): setattr(model, name, param) -def apply_bsr(model): +def apply_bsr_ts(model, blocksize): for name, param in model.named_parameters(): if isinstance(param, SupermaskTensor): try: - set_parameter(model, name, torch.nn.Parameter(to_bsr(param.data, args.bsr))) + set_parameter(model, name, torch.nn.Parameter(param.data.to_sparse_bsr(blocksize))) print(f"Converted SupermaskTensor {name} to bsr format.") - except ValueError: + except RuntimeError: # Fall back to strided set_parameter(model, name, torch.nn.Parameter(param.data.to_strided())) print(f"Converted SupermaskTensor {name} to strided format.") From 3d5304f48b5419f5876d3de984d705641efbc57b Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 05:38:02 +0000 Subject: [PATCH 07/10] Supporting both --- benchmark.py | 18 +++++++++++++----- supermask_ts.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/benchmark.py b/benchmark.py index b46ef10..1c12128 100644 --- a/benchmark.py +++ 
b/benchmark.py @@ -13,7 +13,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from supermask import apply_supermask, SupermaskLinear -from supermask_ts import apply_supermask_ts, SupermaskTensor, apply_bsr_ts +from supermask_ts import apply_supermask_ts, SupermaskTensor, apply_bsr_ts, verify_sparsity_ts, verify_sparsity_ts_bsr def apply_sparsity(model): @@ -124,17 +124,25 @@ def main(args): apply_bsr(model) else: if args.sparsify_weights: - # verify_sparsity(model) + verify_sparsity_ts(model) if args.bsr: print("0 ---") apply_bsr_ts(model, args.bsr) print("1 ---") apply_bsr_ts(model, args.bsr) print("2 ---") + verify_sparsity_ts_bsr(model) + print("3 ---") image = torch.empty(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=torch.bfloat16 if args.bfloat16 else None, device=device) - # model = torch.compile(model, mode='max-autotune') - print(benchmark_in_ms(10, 100, model, image), file=sys.stderr) - return + with torch.no_grad(): + # model = torch.compile(model, mode='max-autotune') + ms = benchmark_in_ms(10, 100, model, image) + max_memory_allocated_bytes = torch.cuda.max_memory_allocated() + _, total_memory = torch.cuda.mem_get_info() + max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) + max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 + print(f"{ms}ms {max_memory_allocated_bytes}MB of RAM, {max_memory_allocated_percentage}% of RAM", file=sys.stderr) + return def get_args_parser(add_help=True): diff --git a/supermask_ts.py b/supermask_ts.py index 0b3bc89..ed1a6f7 100644 --- a/supermask_ts.py +++ b/supermask_ts.py @@ -25,6 +25,19 @@ def apply_bsr_ts(model, blocksize): set_parameter(model, name, torch.nn.Parameter(param.data.to_strided())) print(f"Converted SupermaskTensor {name} to strided format.") +def verify_sparsity_ts(model): + for name, param in model.named_parameters(): + if isinstance(param, SupermaskTensor): + total_weights = param.to_strided().numel() + sparse_weights 
= (param.to_strided() == 0).sum().item() + sparsity_percentage = (sparse_weights / total_weights) * 100 + print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") + +def verify_sparsity_ts_bsr(model): + for name, param in model.named_parameters(): + if param.layout == torch.sparse_bsr: + print(f"ratio: {param.values().numel() / param.numel()}") + def _replace_with_custom_fn_if_matches_filter( model, replacement_fn, @@ -246,6 +259,11 @@ def apply_supermask_ts( ): swap_conv2d_1x1_to_linear(model) for n, m in model.named_modules(): + # check conditions for skipping sparsity + if skip_last_layer_sparsity and n == "heads.head": + continue + if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: + continue if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): m.weight = torch.nn.Parameter(to_supermask_tensor(m.weight, linear_sparsity, From f68fb6585f3bfcbb20601ce80f7a977406815c5a Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 05:41:34 +0000 Subject: [PATCH 08/10] Less code --- supermask_ts.py | 253 ------------------------------------------------ 1 file changed, 253 deletions(-) diff --git a/supermask_ts.py b/supermask_ts.py index ed1a6f7..b0dd7f2 100644 --- a/supermask_ts.py +++ b/supermask_ts.py @@ -283,256 +283,3 @@ def apply_supermask_ts( None, None, linear_sp_tilesize)) - - - -class SupermaskLinear(nn.Linear): - """Supermask class for Linear layer""" - def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): - tile_size = kwargs.pop("tile_size", 1) - super(SupermaskLinear, self).__init__(*args, **kwargs) - # initialize the scores - max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) - self.sparsity = sparsity - if self.sparsity > max_sparsity: - print( - f"reducing sparsity from {self.sparsity} to {max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" 
- ) - self.sparsity = max_sparsity - self.tile_size = tile_size - self.sparsify_weights = False - self.scores = nn.Parameter( - torch.empty( - [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] - ), - requires_grad=not fixed_mask, - ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) - - # the shift and the scale are transformation parameters - # the actually used weights = self.weight*self.scale+self.shift - # the transformation is activated only for quantized weights - self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) - self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) - - with torch.no_grad(): - # if bitwidth is None, then use floating point values in self.weight - # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 - # these quantized values are uniformly distributed - if bitwidth is not None: - weights_max = torch.max(self.weight).item() - weights_min = torch.min(self.weight).item() - least_step = (weights_max-weights_min)/pow(2,bitwidth) - left_bound = weights_min-1e-6 - right_bound = weights_min+least_step+1e-6 - # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) - # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. 
if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) - for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): - self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i - left_bound = right_bound - right_bound += least_step - - self.weight.requires_grad = not fixed_weight - - def get_mask(self): - subnet = GetSubnet.apply(self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity) - - if self.tile_size != 1: - for i, k in enumerate(self.weight.shape): - subnet = subnet.repeat_interleave(self.tile_size, dim=i) - subnet = torch.narrow(subnet, i, 0, k) - - return subnet - - def sparsify_offline(self): - subnet = self.get_mask() - self.weight.data = (self.weight*self.scale+self.shift) * subnet - self.sparsify_weights = True - - def forward(self, x): - if not self.sparsify_weights: - subnet = self.get_mask() - w = (self.weight*self.scale+self.shift) * subnet - else: - w = self.weight - return F.linear(x, w, self.bias) - - -class SupermaskConv2d(nn.Conv2d): - """Supermask class for Conv2d layer""" - def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): - tile_size = kwargs.pop("tile_size", 1) - super(SupermaskConv2d, self).__init__(*args, **kwargs) - # initialize the scores - max_sparsity = 1 - (1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])) - self.sparsity = sparsity - if self.sparsity > max_sparsity: - print( - f"reducing sparsity from {self.sparsity} to {max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})" - ) - self.sparsity = max_sparsity - self.tile_size = tile_size - self.scores = nn.Parameter( - torch.empty( - [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] - ), - requires_grad=not fixed_mask, - ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) - 
- # the shift and the scale are transformation parameters - # the actually used weights = self.weight*self.scale+self.shift - # the transformation is activated only for quantized weights - self.shift=nn.Parameter(torch.Tensor(1).fill_(0.), requires_grad=False) - self.scale=nn.Parameter(torch.Tensor(1).fill_(1.), requires_grad=False) - - with torch.no_grad(): - # if bitwidth is None, then use floating point values in self.weight - # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 - # these quantized values are uniformly distributed - if bitwidth is not None: - weights_max = torch.max(self.weight).item() - weights_min = torch.min(self.weight).item() - least_step = (weights_max-weights_min)/pow(2,bitwidth) - left_bound = weights_min-1e-6 - right_bound = weights_min+least_step+1e-6 - # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1]), requires_grad=not fixed_transform[1]) - # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift=nn.Parameter(torch.Tensor(1).fill_( 0. if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - self.scale=nn.Parameter(torch.Tensor(1).fill_( 1. 
if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) - for i in range(-int(pow(2,bitwidth-1)),int(pow(2,bitwidth-1))): - self.weight[torch.logical_and(self.weight>left_bound, self.weight<=right_bound)] = i - left_bound = right_bound - right_bound += least_step - - self.weight.requires_grad = not fixed_weight - - def forward(self, x): - subnet = GetSubnet.apply(self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity) - - if self.tile_size != 1: - for i, k in enumerate(self.weight.shape): - # if k == 1: continue - subnet = subnet.repeat_interleave(self.tile_size, dim=i) - subnet = torch.narrow(subnet, i, 0, k) - - w = (self.weight*self.scale+self.shift) * subnet - return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) - -@torch.no_grad() -def set_sparsity(modules, sparsity): - """Set the sparsity for supermask layers""" - sm_idx = 0 - for mod in modules: - if isinstance(mod, (SupermaskLinear, SupermaskConv2d)): - mod.sparsity=sparsity[sm_idx] - sm_idx += 1 - print(mod) - print('Sparsity: ', mod.sparsity) - - -def apply_supermask( - model, - linear_sparsity=0.0, - linear_sp_tilesize=1, - conv1x1_sparsity=0.0, - conv1x1_sp_tilesize=1, - conv_sparsity=0.0, - conv_sp_tilesize=1, - skip_last_layer_sparsity=False, - skip_first_transformer_sparsity=False, - device="cuda", - verbose=False, -): - sparsified_modules = {} - - for n, m in model.named_modules(): - # check conditions for skipping sparsity - if skip_last_layer_sparsity and n == "heads.head": - continue - if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: - continue - - # convert 1x1 convolutions - if conv1x1_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d) and m.kernel_size == (1, 1): - new_m = SupermaskConv2d( - conv1x1_sparsity, False, False, None, None, None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - 
groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv1x1_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # convert all other convolutions (not tested!) - if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d): - new_m = SupermaskConv2d( - conv_sparsity, False, False, None, None, None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): - new_m = SupermaskLinear( - linear_sparsity, False, False, None, None, None, - m.in_features, - m.out_features, - bias=m.bias is not None, - device=device, - tile_size=linear_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # add modules to model - for k, v in sparsified_modules.items(): - sm_name, ch_name = k.rsplit(".", 1) - sm = model.get_submodule(sm_name) - sm.add_module(ch_name, v) - - if verbose: - print(f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}') - - return model From 18cddfb4fb3a30310d5bbe5956d9c07b738236f0 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 11 Apr 2024 05:45:54 +0000 Subject: [PATCH 09/10] Less code --- supermask_ts.py | 55 ------------------------------------------------- 1 file changed, 55 deletions(-) diff --git a/supermask_ts.py b/supermask_ts.py index b0dd7f2..3dc47f7 100644 --- a/supermask_ts.py +++ b/supermask_ts.py @@ -38,60 +38,6 @@ def verify_sparsity_ts_bsr(model): 
if param.layout == torch.sparse_bsr: print(f"ratio: {param.values().numel() / param.numel()}") -def _replace_with_custom_fn_if_matches_filter( - model, - replacement_fn, - filter_fn, - cur_fqn="", -) -> None: - """ - For each `child` in `model`, replaces it with `replacement_fn(child)` - if `filter_fn(child)` is `True` - """ - if filter_fn(model, cur_fqn[:-1]): - model = replacement_fn(model) - return model - else: - for name, child in model.named_children(): - new_child = _replace_with_custom_fn_if_matches_filter( - child, replacement_fn, filter_fn, f"{cur_fqn}{name}." - ) - if new_child is not child: - setattr(model, name, new_child) - return model - -def swap_conv2d_1x1_to_linear(model, filter_fn=None): - """ - Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. - """ - - class PermuteSandwich(torch.nn.Module): - def __init__(self, mod): - super().__init__() - self.mod = mod - - def forward(self, *args): - return self.mod(args[0].permute(0, 2, 3, 1)).permute(-0, 3, 1, 2) - - def replace_conv2d_1x1(conv): - assert conv.kernel_size == (1, 1) - lin = torch.nn.Linear( - conv.in_channels, conv.out_channels, bias=(conv.bias is None) - ) - lin.weight = torch.nn.Parameter(conv.weight.squeeze(-1, -2)) - lin.bias = conv.bias - return PermuteSandwich(lin) - - if filter_fn is None: - filter_fn = lambda mod, *args: isinstance( - mod, torch.nn.Conv2d - ) and mod.kernel_size == (1, 1) - - _replace_with_custom_fn_if_matches_filter( - model, replace_conv2d_1x1, filter_fn=filter_fn - ) - - # original supermask scores_min=None scores_max=9e9 @@ -257,7 +203,6 @@ def apply_supermask_ts( skip_first_transformer_sparsity=False, verbose=False, ): - swap_conv2d_1x1_to_linear(model) for n, m in model.named_modules(): # check conditions for skipping sparsity if skip_last_layer_sparsity and n == "heads.head": From e9c93358b92a9e043e11ff96cabc8691019dd493 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Fri, 12 Apr 2024 17:31:01 +0000 
Subject: [PATCH 10/10] Better printing --- supermask_ts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/supermask_ts.py b/supermask_ts.py index 3dc47f7..d47f533 100644 --- a/supermask_ts.py +++ b/supermask_ts.py @@ -36,7 +36,7 @@ def verify_sparsity_ts(model): def verify_sparsity_ts_bsr(model): for name, param in model.named_parameters(): if param.layout == torch.sparse_bsr: - print(f"ratio: {param.values().numel() / param.numel()}") + print(f"{name} ratio: {param.values().numel() / param.numel()}") # original supermask scores_min=None @@ -219,6 +219,7 @@ def apply_supermask_ts( None, linear_sp_tilesize)) if linear_sparsity != 0.0 and isinstance(m, torch.nn.MultiheadAttention): + # continue assert m._qkv_same_embed_dim m.in_proj_weight = torch.nn.Parameter(to_supermask_tensor(m.in_proj_weight, linear_sparsity,