From b8810c74e91c8c68211d9a1816bc878db7197c35 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 1 Feb 2024 23:41:01 -0500 Subject: [PATCH 01/21] .gitignore magic_wand dir --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index b5195629e5cf3..8e46368ad7df4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Dependency repos +magic_wand + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From d56b4c4093b211bc688bec5694bf1fd5eee7a039 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 2 Feb 2024 03:27:05 -0500 Subject: [PATCH 02/21] added 2:4 example (not actually using 2:4 yet\!) --- .../offline_inference_semi_structured_sparse.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 examples/offline_inference_semi_structured_sparse.py diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py new file mode 100644 index 0000000000000..1e21c213ca8b0 --- /dev/null +++ b/examples/offline_inference_semi_structured_sparse.py @@ -0,0 +1,14 @@ +from vllm import LLM, SamplingParams + +model = LLM( + "nm-testing/zephyr-50sparse-24", + sparsity="sparse_w16a16", # If left off, model will be loaded as dense + enforce_eager=True, # Does not work with cudagraphs yet + dtype="float16", + tensor_parallel_size=1, + max_model_len=1024 +) + +sampling_params = SamplingParams(max_tokens=100, temperature=0) +outputs = model.generate("Hello my name is", sampling_params=sampling_params) +outputs[0].outputs[0].text \ No newline at end of file From 1a8bc1c05067fc4fca70184c96f22eee22018c93 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 2 Feb 2024 04:17:36 -0500 Subject: [PATCH 03/21] use only cuda:0 --- examples/offline_inference_semi_structured_sparse.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py index 1e21c213ca8b0..fce47ae1d7f40 100644 --- a/examples/offline_inference_semi_structured_sparse.py +++ b/examples/offline_inference_semi_structured_sparse.py @@ -1,4 +1,7 @@ from vllm import LLM, SamplingParams +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use only cuda:0 + model = LLM( "nm-testing/zephyr-50sparse-24", @@ -11,4 +14,4 @@ sampling_params = SamplingParams(max_tokens=100, temperature=0) outputs = model.generate("Hello my name is", sampling_params=sampling_params) -outputs[0].outputs[0].text \ No newline at end of file +print(outputs[0].outputs[0].text) \ No newline at end of file From 2c6ff2675bf88497b428be6c8212ab8520124873 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 2 Feb 2024 04:34:02 -0500 Subject: [PATCH 04/21] wip semi_structured_sparse_w16a16 --- ...ffline_inference_semi_structured_sparse.py | 2 +- .../sparsity/semi_structured_sparse_w16a16.py | 97 +++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py index fce47ae1d7f40..87757e312f7d5 100644 --- a/examples/offline_inference_semi_structured_sparse.py +++ b/examples/offline_inference_semi_structured_sparse.py @@ -5,7 +5,7 @@ model = LLM( "nm-testing/zephyr-50sparse-24", - sparsity="sparse_w16a16", # If left off, model will be loaded as dense + sparsity="semi_structured_sparse_w16a16", # If left off, model will be 
loaded as dense enforce_eager=True, # Does not work with cudagraphs yet dtype="float16", tensor_parallel_size=1, diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py new file mode 100644 index 0000000000000..1a28fdb143cad --- /dev/null +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -0,0 +1,97 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.sparsity.base_config import SparsityConfig +from vllm.model_executor.layers.parameters import SparseParameter + + +class SemiStructuredSparseW16A16Config(SparsityConfig): + """Config class for SemiStructuredSparseW16A16. + """ + + def __init__(self) -> None: + # TODO: Add new configs here + pass + + def __repr__(self) -> str: + return "SemiStructuredSparseW16A16Config()" + + @classmethod + def get_name(cls) -> str: + return "semi_structured_sparse_w16a16" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # TODO: Update after checks on more GPUs + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["sparsity_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config": + return cls() + + def get_linear_method(self) -> "SemiStructuredSparseW16A16LinearMethod": + return SemiStructuredSparseW16A16LinearMethod(self) + + +class SemiStructuredSparseW16A16LinearMethod(LinearMethodBase): + """Linear method for Semi Structured Sparse W16A16. + + Args: + sparsity_config: The sparse config. 
+ """ + + def __init__(self, sparsity_config: SemiStructuredSparseW16A16Config): + self.sparsity_config = sparsity_config + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + weight = SparseParameter( + shape=torch.Size( + (output_size_per_partition, input_size_per_partition)), + dtype=params_dtype, + ) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + return {"weight": weight} + + def apply_weights( + self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + sparse_weight = weights["weight"] + + # Uncompress to dense + dense_weight = sparse_weight.to_dense() + + # # Uncomment to verify sparsity + # density = torch.count_nonzero( + # dense_weight).item() / dense_weight.numel() + # print(f"sparsity = {1.0 - density}") + + # Standard matrix multiply + if bias is not None: + output = F.linear(x, dense_weight, bias) + else: + output = F.linear(x, dense_weight) + + return output From 2856b91c61169356e9737a078f19306627595ed7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 3 Feb 2024 06:24:22 -0500 Subject: [PATCH 05/21] restructuring sparsity --- ...ffline_inference_semi_structured_sparse.py | 2 +- .../layers/parameters/sparsity.py | 6 +- .../layers/sparsity/base_config.py | 9 ++- .../sparsity/semi_structured_sparse_w16a16.py | 68 ++++--------------- .../layers/sparsity/sparse_w16a16.py | 65 +++--------------- .../sparsity/sparse_w16a16_linear_method.py | 66 ++++++++++++++++++ 6 files changed, 99 insertions(+), 117 deletions(-) create mode 100644 vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py index 87757e312f7d5..35063feed5bab 100644 --- a/examples/offline_inference_semi_structured_sparse.py +++ b/examples/offline_inference_semi_structured_sparse.py @@ -5,7 +5,7 @@ model = LLM( "nm-testing/zephyr-50sparse-24", - sparsity="semi_structured_sparse_w16a16", # If left off, model will be loaded as dense + sparsity="sparse_w16a16", # semi_structured_sparse_w16a16 # If left off, model will be loaded as dense enforce_eager=True, # Does not work with cudagraphs yet dtype="float16", tensor_parallel_size=1, diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py index 37ddd05d89636..0e6280e67605c 100644 --- a/vllm/model_executor/layers/parameters/sparsity.py +++ b/vllm/model_executor/layers/parameters/sparsity.py @@ -1,6 +1,7 @@ import torch -from magic_wand import SparseTensor, SparseBitmaskStorageFormat +from typing import Type +from magic_wand import SparseTensor, CompressedStorageFormat class SparseParameter(SparseTensor): @@ -10,6 +11,7 @@ def __new__( cls, shape: torch.Size, dtype: torch.dtype, + storage_format_cls: Type[CompressedStorageFormat] ): assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+" @@ -17,7 +19,7 @@ def __new__( size=shape, dtype=dtype, requires_grad=False) - self.storage_format_cls = SparseBitmaskStorageFormat + self.storage_format_cls = storage_format_cls self.compressed_data = None self.dense_data = None self._is_param = True diff --git a/vllm/model_executor/layers/sparsity/base_config.py b/vllm/model_executor/layers/sparsity/base_config.py index aa09fb623bc00..cfa6e93227476 100644 --- a/vllm/model_executor/layers/sparsity/base_config.py +++ 
b/vllm/model_executor/layers/sparsity/base_config.py @@ -2,12 +2,19 @@ from typing import Any, Dict, List import torch +from typing import Type from vllm.model_executor.layers.linear import LinearMethodBase - +from magic_wand import CompressedStorageFormat class SparsityConfig(ABC): """Base class for sparsity configs.""" + # storage_format_cls: Type[CompressedStorageFormat] = CompressedStorageFormat + + @abstractmethod + def get_storage_format_cls(self) -> Type[CompressedStorageFormat]: + """Sparse representation format""" + raise NotImplementedError @abstractmethod def get_name(self) -> str: diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py index 1a28fdb143cad..7f436873d3f4c 100644 --- a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -1,12 +1,17 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type +from magic_wand import CompressedStorageFormat import torch import torch.nn.functional as F +from torch.nn import Parameter +from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from vllm.model_executor.layers.parameters import SparseParameter +from .sparse_w16a16_linear_method import SparseW16A16LinearMethod + class SemiStructuredSparseW16A16Config(SparsityConfig): """Config class for SemiStructuredSparseW16A16. @@ -19,6 +24,10 @@ def __init__(self) -> None: def __repr__(self) -> str: return "SemiStructuredSparseW16A16Config()" + @classmethod + def get_storage_format_cls(cls) -> Type: + return super().get_storage_format_cls() + @classmethod def get_name(cls) -> str: return "semi_structured_sparse_w16a16" @@ -40,58 +49,5 @@ def get_config_filenames(cls) -> List[str]: def from_config(cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config": return cls() - def get_linear_method(self) -> "SemiStructuredSparseW16A16LinearMethod": - return SemiStructuredSparseW16A16LinearMethod(self) - - -class SemiStructuredSparseW16A16LinearMethod(LinearMethodBase): - """Linear method for Semi Structured Sparse W16A16. - - Args: - sparsity_config: The sparse config. 
- """ - - def __init__(self, sparsity_config: SemiStructuredSparseW16A16Config): - self.sparsity_config = sparsity_config - - def create_weights( - self, - input_size_per_partition: int, - output_size_per_partition: int, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - ) -> Dict[str, Any]: - weight = SparseParameter( - shape=torch.Size( - (output_size_per_partition, input_size_per_partition)), - dtype=params_dtype, - ) - - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - - return {"weight": weight} - - def apply_weights( - self, - weights: Dict[str, Any], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - sparse_weight = weights["weight"] - - # Uncompress to dense - dense_weight = sparse_weight.to_dense() - - # # Uncomment to verify sparsity - # density = torch.count_nonzero( - # dense_weight).item() / dense_weight.numel() - # print(f"sparsity = {1.0 - density}") - - # Standard matrix multiply - if bias is not None: - output = F.linear(x, dense_weight, bias) - else: - output = F.linear(x, dense_weight) - - return output + def get_linear_method(self) -> "SparseW16A16LinearMethod": + return SparseW16A16LinearMethod(self) \ No newline at end of file diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py index 771fae9b8ff45..31af2f38f89e2 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -1,12 +1,12 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Type import torch import torch.nn.functional as F -from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig -from vllm.model_executor.layers.parameters import SparseParameter +from .sparse_w16a16_linear_method import SparseW16A16LinearMethod +from magic_wand import CompressedStorageFormat,SparseBitmaskStorageFormat class SparseW16A16Config(SparsityConfig): """Config class for SparseW16A16. @@ -21,6 +21,10 @@ def __init__(self) -> None: def __repr__(self) -> str: return "SparseW16A16Config()" + @classmethod + def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]: + return SparseBitmaskStorageFormat + @classmethod def get_name(cls) -> str: return "sparse_w16a16" @@ -43,57 +47,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config": return cls() def get_linear_method(self) -> "SparseW16A16LinearMethod": - return SparseW16A16LinearMethod(self) - - -class SparseW16A16LinearMethod(LinearMethodBase): - """Linear method for Sparse W16A16. - - Args: - sparsity_config: The sparse config. 
- """ - - def __init__(self, sparsity_config: SparseW16A16Config): - self.sparsity_config = sparsity_config - - def create_weights( - self, - input_size_per_partition: int, - output_size_per_partition: int, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - ) -> Dict[str, Any]: - weight = SparseParameter( - shape=torch.Size( - (output_size_per_partition, input_size_per_partition)), - dtype=params_dtype, - ) - - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - - return {"weight": weight} - - def apply_weights( - self, - weights: Dict[str, Any], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - sparse_weight = weights["weight"] - - # Uncompress to dense - dense_weight = sparse_weight.to_dense() - - # # Uncomment to verify sparsity - # density = torch.count_nonzero( - # dense_weight).item() / dense_weight.numel() - # print(f"sparsity = {1.0 - density}") - - # Standard matrix multiply - if bias is not None: - output = F.linear(x, dense_weight, bias) - else: - output = F.linear(x, dense_weight) - - return output + return SparseW16A16LinearMethod(self,self.get_storage_format_cls()) \ No newline at end of file diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py new file mode 100644 index 0000000000000..3abef515131d8 --- /dev/null +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -0,0 +1,66 @@ +from typing import Any, Dict, Optional, Type + +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.sparsity.base_config import SparsityConfig +from vllm.model_executor.layers.parameters import SparseParameter +from magic_wand import ( + CompressedStorageFormat +) + +class SparseW16A16LinearMethod(LinearMethodBase): + """Linear method for Sparse W16A16. + + Args: + sparsity_config: The sparse config. 
+ """ + storage_format_cls: Type[CompressedStorageFormat] = None + + def __init__(self, sparsity_config: SparsityConfig, storage_format_cls: Type[CompressedStorageFormat]): + self.sparsity_config = sparsity_config + self.storage_format_cls = storage_format_cls + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype + ) -> Dict[str, Any]: + weight = SparseParameter( + shape=torch.Size( + (output_size_per_partition, input_size_per_partition)), + dtype=params_dtype, + storage_format_cls=self.storage_format_cls + ) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + return {"weight": weight} + + def apply_weights( + self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + sparse_weight = weights["weight"] + + # Uncompress to dense + dense_weight = sparse_weight.to_dense() + + # # Uncomment to verify sparsity + # density = torch.count_nonzero( + # dense_weight).item() / dense_weight.numel() + # print(f"sparsity = {1.0 - density}") + + # Standard matrix multiply + if bias is not None: + output = F.linear(x, dense_weight, bias) + else: + output = F.linear(x, dense_weight) + + return output \ No newline at end of file From 708fe1b8f2b58b9dea10d6bf2af7872e6fceb7d5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 4 Feb 2024 07:00:28 -0500 Subject: [PATCH 06/21] difficulty creating sparse parameter class --- ...ffline_inference_semi_structured_sparse.py | 2 +- vllm/config.py | 2 +- .../layers/parameters/__init__.py | 2 +- .../layers/parameters/sparsity.py | 88 ++++++++++++++++++- .../layers/sparsity/__init__.py | 2 + .../sparsity/semi_structured_sparse_w16a16.py | 22 ++--- .../layers/sparsity/sparse_w16a16.py | 5 +- .../sparsity/sparse_w16a16_linear_method.py | 63 ++++++++----- 8 files changed, 146 insertions(+), 40 deletions(-) diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py index 35063feed5bab..f17ee677e0d1d 100644 --- a/examples/offline_inference_semi_structured_sparse.py +++ b/examples/offline_inference_semi_structured_sparse.py @@ -5,7 +5,7 @@ model = LLM( "nm-testing/zephyr-50sparse-24", - sparsity="sparse_w16a16", # semi_structured_sparse_w16a16 # If left off, model will be loaded as dense + sparsity="semi_structured_sparse_w16a16", # sparse_w16a16 # If left off, model will be loaded as dense enforce_eager=True, # Does not work with cudagraphs yet dtype="float16", tensor_parallel_size=1, diff --git a/vllm/config.py b/vllm/config.py index d735819c0c2b1..4b82e2c18f78d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -148,7 +148,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_sparsity(self) -> None: - supported_sparsity = ["sparse_w16a16"] + supported_sparsity = ["sparse_w16a16","semi_structured_sparse_w16a16"] if self.quantization is not None: raise ValueError("Both sparsity and quantization detected. 
Only " diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py index 2d41190087a0d..73efb4a2606d6 100644 --- a/vllm/model_executor/layers/parameters/__init__.py +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -1,5 +1,5 @@ import torch -from vllm.model_executor.layers.parameters.sparsity import SparseParameter +from vllm.model_executor.layers.parameters.sparsity import SparseParameter, SemiStructuredSparseParameter def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py index 0e6280e67605c..bb70b29ed18ab 100644 --- a/vllm/model_executor/layers/parameters/sparsity.py +++ b/vllm/model_executor/layers/parameters/sparsity.py @@ -1,7 +1,8 @@ import torch +from torch.sparse import SparseSemiStructuredTensor from typing import Type -from magic_wand import SparseTensor, CompressedStorageFormat +from magic_wand import SparseTensor, CompressedStorageFormat, SparseSemiStructuredStorageFormat class SparseParameter(SparseTensor): @@ -15,6 +16,69 @@ def __new__( ): assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+" + + self = torch.Tensor._make_wrapper_subclass(cls, + size=shape, + dtype=dtype, + requires_grad=False) + self.storage_format_cls = storage_format_cls + self.compressed_data = None + self.dense_data = None + self._is_param = True + + return self + + def has_compressed_data(self) -> bool: + return (self.compressed_data is not None) + + def get_dense_data(self) -> torch.Tensor: + if self.dense_data is not None: + raise ValueError( + "Called get_data_dense() but dense_data already exists.") + self.dense_data = self._unpack() + return self.dense_data + + def _unpack(self) -> torch.Tensor: + if self.has_compressed_data(): + return self.compressed_data.decompress() + else: + return torch.empty(size=self.shape, + dtype=self.dtype, + device="cuda") + + @classmethod + def _copy(cls, arg0, arg1): + assert arg0.shape == arg1.shape + + if arg0.has_compressed_data(): + arg0.compressed_data.copy_(arg1) + else: + arg0.compressed_data = arg0.storage_format_cls.compress(arg1) + + return arg0 + + def copy_(self, src, non_blocking=False): + return SparseParameter._copy(self, src) + + def pack(self) -> None: + if self.dense_data is None: + raise ValueError("Called pack() but dense_data does not exist.") + self.copy_(self.dense_data) + self.dense_data = None + + +class SemiStructuredSparseParameter(torch.Tensor): + + @staticmethod + def __new__( + cls, + shape: torch.Size, + dtype: torch.dtype, + storage_format_cls: Type[CompressedStorageFormat] + ): + assert torch.__version__ > (1, + 10), "SparseTensor requires PyTorch 1.11+" + assert storage_format_cls == SparseSemiStructuredStorageFormat self = torch.Tensor._make_wrapper_subclass(cls, size=shape, dtype=dtype, @@ -26,6 +90,9 @@ def __new__( return self + def has_compressed_data(self) -> bool: + return (self.compressed_data is not None) + def get_dense_data(self) -> torch.Tensor: if self.dense_data is not None: raise ValueError( @@ -41,8 +108,27 @@ def _unpack(self) -> torch.Tensor: dtype=self.dtype, device="cuda") + @classmethod + def _copy(cls, arg0, arg1): + assert arg0.shape == arg1.shape + + if arg0.has_compressed_data(): + arg0.compressed_data.copy_(arg1) + else: + arg0.compressed_data = arg0.storage_format_cls.compress(arg1) + + return arg0 + + def copy_(self, src, non_blocking=False): + return self.__class__._copy(self, src) + def pack(self) -> 
None: if self.dense_data is None: raise ValueError("Called pack() but dense_data does not exist.") self.copy_(self.dense_data) self.dense_data = None + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + + # Forward the call to the original torch function with the same arguments + return super().__torch_dispatch__(func,types,args,kwargs) diff --git a/vllm/model_executor/layers/sparsity/__init__.py b/vllm/model_executor/layers/sparsity/__init__.py index 411d1ff642266..236620d130b8a 100644 --- a/vllm/model_executor/layers/sparsity/__init__.py +++ b/vllm/model_executor/layers/sparsity/__init__.py @@ -2,9 +2,11 @@ from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config +from vllm.model_executor.layers.sparsity.semi_structured_sparse_w16a16 import SemiStructuredSparseW16A16Config _SPARSITY_CONFIG_REGISTRY = { "sparse_w16a16": SparseW16A16Config, + "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config } diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py index 7f436873d3f4c..2963ec3de95bf 100644 --- a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -1,17 +1,13 @@ -from typing import Any, Dict, List, Optional, Type -from magic_wand import CompressedStorageFormat - import torch -import torch.nn.functional as F -from torch.nn import Parameter -from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor -from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs +from typing import Any, Dict, List, Type +from magic_wand import CompressedStorageFormat from vllm.model_executor.layers.sparsity.base_config import SparsityConfig -from vllm.model_executor.layers.parameters import SparseParameter - from .sparse_w16a16_linear_method import SparseW16A16LinearMethod - +from magic_wand import ( + CompressedStorageFormat, + SparseSemiStructuredStorageFormat +) class SemiStructuredSparseW16A16Config(SparsityConfig): """Config class for SemiStructuredSparseW16A16. 
@@ -25,8 +21,8 @@ def __repr__(self) -> str: return "SemiStructuredSparseW16A16Config()" @classmethod - def get_storage_format_cls(cls) -> Type: - return super().get_storage_format_cls() + def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]: + return SparseSemiStructuredStorageFormat @classmethod def get_name(cls) -> str: @@ -50,4 +46,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Confi return cls() def get_linear_method(self) -> "SparseW16A16LinearMethod": - return SparseW16A16LinearMethod(self) \ No newline at end of file + return SparseW16A16LinearMethod(self,self.get_storage_format_cls()) \ No newline at end of file diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py index 31af2f38f89e2..fb80b894caefa 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -6,7 +6,10 @@ from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from .sparse_w16a16_linear_method import SparseW16A16LinearMethod -from magic_wand import CompressedStorageFormat,SparseBitmaskStorageFormat +from magic_wand import ( + CompressedStorageFormat, + SparseBitmaskStorageFormat +) class SparseW16A16Config(SparsityConfig): """Config class for SparseW16A16. diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index 3abef515131d8..53880bcc752a2 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -5,9 +5,10 @@ from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig -from vllm.model_executor.layers.parameters import SparseParameter +from vllm.model_executor.layers.parameters import SparseParameter, SemiStructuredSparseParameter from magic_wand import ( - CompressedStorageFormat + CompressedStorageFormat, + SparseSemiStructuredStorageFormat ) class SparseW16A16LinearMethod(LinearMethodBase): @@ -30,16 +31,28 @@ def create_weights( output_size: int, params_dtype: torch.dtype ) -> Dict[str, Any]: - weight = SparseParameter( - shape=torch.Size( - (output_size_per_partition, input_size_per_partition)), - dtype=params_dtype, - storage_format_cls=self.storage_format_cls - ) + if self.storage_format_cls == SparseSemiStructuredStorageFormat: + weight = SemiStructuredSparseParameter( + shape=torch.Size( + (output_size_per_partition, input_size_per_partition)), + dtype=params_dtype, + storage_format_cls=self.storage_format_cls + ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - return {"weight": weight} + return {"weight": weight} + else: + weight = SparseParameter( + shape=torch.Size( + (output_size_per_partition, input_size_per_partition)), + dtype=params_dtype, + storage_format_cls=self.storage_format_cls + ) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + return {"weight": weight} def apply_weights( self, @@ -49,18 +62,24 @@ def apply_weights( ) -> torch.Tensor: sparse_weight = weights["weight"] - # Uncompress to dense - dense_weight = sparse_weight.to_dense() + if self.storage_format_cls == SparseSemiStructuredStorageFormat: + if bias is not None: + output = F.linear(x, sparse_weight.compressed_data.encapsulated_torch_sparse_tensor, bias) 
+ else: + output = F.linear(x, sparse_weight.compressed_data.encapsulated_torch_sparse_tensor) - # # Uncomment to verify sparsity - # density = torch.count_nonzero( - # dense_weight).item() / dense_weight.numel() - # print(f"sparsity = {1.0 - density}") - - # Standard matrix multiply - if bias is not None: - output = F.linear(x, dense_weight, bias) + return output else: - output = F.linear(x, dense_weight) + # # Uncomment to verify sparsity + # density = torch.count_nonzero( + # dense_weight).item() / dense_weight.numel() + # print(f"sparsity = {1.0 - density}") + + # Standard matrix multiply + # Uncompress to dense + if bias is not None: + output = F.linear(x, sparse_weight.to_dense(), bias) + else: + output = F.linear(x, sparse_weight.to_dense()) - return output \ No newline at end of file + return output \ No newline at end of file From 40a8afbaf6a74f3ad1849f68b4edc50d2b88a45f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 4 Feb 2024 15:10:07 -0500 Subject: [PATCH 07/21] first successful run with 2:4 sparse model; compat with magic_wand branch safe_expose_semi_structured_sparse_tensor --- .../layers/parameters/__init__.py | 2 +- .../layers/parameters/sparsity.py | 139 +++++++++--------- .../sparsity/sparse_w16a16_linear_method.py | 35 ++--- 3 files changed, 85 insertions(+), 91 deletions(-) diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py index 73efb4a2606d6..2d41190087a0d 100644 --- a/vllm/model_executor/layers/parameters/__init__.py +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -1,5 +1,5 @@ import torch -from vllm.model_executor.layers.parameters.sparsity import SparseParameter, SemiStructuredSparseParameter +from vllm.model_executor.layers.parameters.sparsity import SparseParameter def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py index bb70b29ed18ab..b0f78396652d1 100644 --- a/vllm/model_executor/layers/parameters/sparsity.py +++ b/vllm/model_executor/layers/parameters/sparsity.py @@ -2,7 +2,12 @@ from torch.sparse import SparseSemiStructuredTensor from typing import Type -from magic_wand import SparseTensor, CompressedStorageFormat, SparseSemiStructuredStorageFormat +from magic_wand import ( + SparseTensor, + CompressedStorageFormat, + SparseBitmaskStorageFormat, + SparseSemiStructuredStorageFormat +) class SparseParameter(SparseTensor): @@ -12,7 +17,7 @@ def __new__( cls, shape: torch.Size, dtype: torch.dtype, - storage_format_cls: Type[CompressedStorageFormat] + storage_format_cls: Type[CompressedStorageFormat] = SparseBitmaskStorageFormat ): assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+" @@ -67,68 +72,68 @@ def pack(self) -> None: self.dense_data = None -class SemiStructuredSparseParameter(torch.Tensor): - - @staticmethod - def __new__( - cls, - shape: torch.Size, - dtype: torch.dtype, - storage_format_cls: Type[CompressedStorageFormat] - ): - assert torch.__version__ > (1, - 10), "SparseTensor requires PyTorch 1.11+" - assert storage_format_cls == SparseSemiStructuredStorageFormat - self = torch.Tensor._make_wrapper_subclass(cls, - size=shape, - dtype=dtype, - requires_grad=False) - self.storage_format_cls = storage_format_cls - self.compressed_data = None - self.dense_data = None - self._is_param = True - - return self - - def has_compressed_data(self) -> bool: - return (self.compressed_data is not None) - - def get_dense_data(self) -> 
torch.Tensor: - if self.dense_data is not None: - raise ValueError( - "Called get_data_dense() but dense_data already exists.") - self.dense_data = self._unpack() - return self.dense_data - - def _unpack(self) -> torch.Tensor: - if self.has_compressed_data(): - return self.compressed_data.decompress() - else: - return torch.empty(size=self.shape, - dtype=self.dtype, - device="cuda") - - @classmethod - def _copy(cls, arg0, arg1): - assert arg0.shape == arg1.shape - - if arg0.has_compressed_data(): - arg0.compressed_data.copy_(arg1) - else: - arg0.compressed_data = arg0.storage_format_cls.compress(arg1) - - return arg0 - - def copy_(self, src, non_blocking=False): - return self.__class__._copy(self, src) - - def pack(self) -> None: - if self.dense_data is None: - raise ValueError("Called pack() but dense_data does not exist.") - self.copy_(self.dense_data) - self.dense_data = None - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - - # Forward the call to the original torch function with the same arguments - return super().__torch_dispatch__(func,types,args,kwargs) +# class SemiStructuredSparseParameter(torch.Tensor): + +# @staticmethod +# def __new__( +# cls, +# shape: torch.Size, +# dtype: torch.dtype, +# storage_format_cls: Type[CompressedStorageFormat] +# ): +# assert torch.__version__ > (1, +# 10), "SparseTensor requires PyTorch 1.11+" +# assert storage_format_cls == SparseSemiStructuredStorageFormat +# self = torch.Tensor._make_wrapper_subclass(cls, +# size=shape, +# dtype=dtype, +# requires_grad=False) +# self.storage_format_cls = storage_format_cls +# self.compressed_data = None +# self.dense_data = None +# self._is_param = True + +# return self + +# def has_compressed_data(self) -> bool: +# return (self.compressed_data is not None) + +# def get_dense_data(self) -> torch.Tensor: +# if self.dense_data is not None: +# raise ValueError( +# "Called get_data_dense() but dense_data already exists.") +# self.dense_data = self._unpack() +# return self.dense_data + +# def _unpack(self) -> torch.Tensor: +# if self.has_compressed_data(): +# return self.compressed_data.decompress() +# else: +# return torch.empty(size=self.shape, +# dtype=self.dtype, +# device="cuda") + +# @classmethod +# def _copy(cls, arg0, arg1): +# assert arg0.shape == arg1.shape + +# if arg0.has_compressed_data(): +# arg0.compressed_data.copy_(arg1) +# else: +# arg0.compressed_data = arg0.storage_format_cls.compress(arg1) + +# return arg0 + +# def copy_(self, src, non_blocking=False): +# return self.__class__._copy(self, src) + +# def pack(self) -> None: +# if self.dense_data is None: +# raise ValueError("Called pack() but dense_data does not exist.") +# self.copy_(self.dense_data) +# self.dense_data = None + +# def __torch_dispatch__(self, func, types, args=(), kwargs=None): + +# # Forward the call to the original torch function with the same arguments +# return super().__torch_dispatch__(func,types,args,kwargs) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index 53880bcc752a2..ece461cc2e5aa 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -5,7 +5,7 @@ from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig -from vllm.model_executor.layers.parameters import SparseParameter, 
SemiStructuredSparseParameter +from vllm.model_executor.layers.parameters import SparseParameter from magic_wand import ( CompressedStorageFormat, SparseSemiStructuredStorageFormat @@ -31,28 +31,16 @@ def create_weights( output_size: int, params_dtype: torch.dtype ) -> Dict[str, Any]: - if self.storage_format_cls == SparseSemiStructuredStorageFormat: - weight = SemiStructuredSparseParameter( - shape=torch.Size( - (output_size_per_partition, input_size_per_partition)), - dtype=params_dtype, - storage_format_cls=self.storage_format_cls - ) - - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - - return {"weight": weight} - else: - weight = SparseParameter( - shape=torch.Size( - (output_size_per_partition, input_size_per_partition)), - dtype=params_dtype, - storage_format_cls=self.storage_format_cls - ) + weight = SparseParameter( + shape=torch.Size( + (output_size_per_partition, input_size_per_partition)), + dtype=params_dtype, + storage_format_cls=self.storage_format_cls + ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - return {"weight": weight} + return {"weight": weight} def apply_weights( self, @@ -64,12 +52,13 @@ def apply_weights( if self.storage_format_cls == SparseSemiStructuredStorageFormat: if bias is not None: - output = F.linear(x, sparse_weight.compressed_data.encapsulated_torch_sparse_tensor, bias) + output = F.linear(x, sparse_weight, bias) else: - output = F.linear(x, sparse_weight.compressed_data.encapsulated_torch_sparse_tensor) + output = F.linear(x, sparse_weight, bias=None) return output else: + assert(False) # # Uncomment to verify sparsity # density = torch.count_nonzero( # dense_weight).item() / dense_weight.numel() From a344b60a6c3ac0184227ca4ef2da24e6480028bd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 5 Feb 2024 15:15:39 -0500 Subject: [PATCH 08/21] woops uncommenting assert statement --- .../layers/sparsity/sparse_w16a16_linear_method.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index ece461cc2e5aa..d9652b6ab643a 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -58,7 +58,6 @@ def apply_weights( return output else: - assert(False) # # Uncomment to verify sparsity # density = torch.count_nonzero( # dense_weight).item() / dense_weight.numel() From 7a2a7ed3b5e05b90ee027c98a5357792509109fd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 8 Feb 2024 13:51:05 -0500 Subject: [PATCH 09/21] fixes --- .gitignore | 3 - magic_wand | 1 + .../layers/parameters/sparsity.py | 69 +------------------ .../layers/sparsity/base_config.py | 1 - .../sparsity/semi_structured_sparse_w16a16.py | 1 - .../sparsity/sparse_w16a16_linear_method.py | 4 -- 6 files changed, 2 insertions(+), 77 deletions(-) create mode 160000 magic_wand diff --git a/.gitignore b/.gitignore index 8e46368ad7df4..b5195629e5cf3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# Dependency repos -magic_wand - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/magic_wand b/magic_wand new file mode 160000 index 0000000000000..5b74fb0c043c8 --- /dev/null +++ b/magic_wand @@ -0,0 +1 @@ +Subproject commit 5b74fb0c043c8d4005bca3f46601a4941e00c5aa diff --git a/vllm/model_executor/layers/parameters/sparsity.py 
b/vllm/model_executor/layers/parameters/sparsity.py index b0f78396652d1..53409c799d068 100644 --- a/vllm/model_executor/layers/parameters/sparsity.py +++ b/vllm/model_executor/layers/parameters/sparsity.py @@ -69,71 +69,4 @@ def pack(self) -> None: if self.dense_data is None: raise ValueError("Called pack() but dense_data does not exist.") self.copy_(self.dense_data) - self.dense_data = None - - -# class SemiStructuredSparseParameter(torch.Tensor): - -# @staticmethod -# def __new__( -# cls, -# shape: torch.Size, -# dtype: torch.dtype, -# storage_format_cls: Type[CompressedStorageFormat] -# ): -# assert torch.__version__ > (1, -# 10), "SparseTensor requires PyTorch 1.11+" -# assert storage_format_cls == SparseSemiStructuredStorageFormat -# self = torch.Tensor._make_wrapper_subclass(cls, -# size=shape, -# dtype=dtype, -# requires_grad=False) -# self.storage_format_cls = storage_format_cls -# self.compressed_data = None -# self.dense_data = None -# self._is_param = True - -# return self - -# def has_compressed_data(self) -> bool: -# return (self.compressed_data is not None) - -# def get_dense_data(self) -> torch.Tensor: -# if self.dense_data is not None: -# raise ValueError( -# "Called get_data_dense() but dense_data already exists.") -# self.dense_data = self._unpack() -# return self.dense_data - -# def _unpack(self) -> torch.Tensor: -# if self.has_compressed_data(): -# return self.compressed_data.decompress() -# else: -# return torch.empty(size=self.shape, -# dtype=self.dtype, -# device="cuda") - -# @classmethod -# def _copy(cls, arg0, arg1): -# assert arg0.shape == arg1.shape - -# if arg0.has_compressed_data(): -# arg0.compressed_data.copy_(arg1) -# else: -# arg0.compressed_data = arg0.storage_format_cls.compress(arg1) - -# return arg0 - -# def copy_(self, src, non_blocking=False): -# return self.__class__._copy(self, src) - -# def pack(self) -> None: -# if self.dense_data is None: -# raise ValueError("Called pack() but dense_data does not exist.") -# self.copy_(self.dense_data) -# self.dense_data = None - -# def __torch_dispatch__(self, func, types, args=(), kwargs=None): - -# # Forward the call to the original torch function with the same arguments -# return super().__torch_dispatch__(func,types,args,kwargs) + self.dense_data = None \ No newline at end of file diff --git a/vllm/model_executor/layers/sparsity/base_config.py b/vllm/model_executor/layers/sparsity/base_config.py index cfa6e93227476..953ba301e8d5b 100644 --- a/vllm/model_executor/layers/sparsity/base_config.py +++ b/vllm/model_executor/layers/sparsity/base_config.py @@ -9,7 +9,6 @@ class SparsityConfig(ABC): """Base class for sparsity configs.""" - # storage_format_cls: Type[CompressedStorageFormat] = CompressedStorageFormat @abstractmethod def get_storage_format_cls(self) -> Type[CompressedStorageFormat]: diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py index 2963ec3de95bf..e47d1e0f7ad90 100644 --- a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -14,7 +14,6 @@ class SemiStructuredSparseW16A16Config(SparsityConfig): """ def __init__(self) -> None: - # TODO: Add new configs here pass def __repr__(self) -> str: diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index d9652b6ab643a..e1ccf48dc5d5b 100644 --- 
a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -58,10 +58,6 @@ def apply_weights( return output else: - # # Uncomment to verify sparsity - # density = torch.count_nonzero( - # dense_weight).item() / dense_weight.numel() - # print(f"sparsity = {1.0 - density}") # Standard matrix multiply # Uncompress to dense From 0711a74062109dcced8f40e87a76166814af6dd0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 8 Feb 2024 20:12:04 -0500 Subject: [PATCH 10/21] bfloat16 --- examples/offline_inference_semi_structured_sparse.py | 6 +++--- .../layers/sparsity/semi_structured_sparse_w16a16.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py index f17ee677e0d1d..01539f650a034 100644 --- a/examples/offline_inference_semi_structured_sparse.py +++ b/examples/offline_inference_semi_structured_sparse.py @@ -5,9 +5,9 @@ model = LLM( "nm-testing/zephyr-50sparse-24", - sparsity="semi_structured_sparse_w16a16", # sparse_w16a16 # If left off, model will be loaded as dense - enforce_eager=True, # Does not work with cudagraphs yet - dtype="float16", + sparsity="semi_structured_sparse_w16a16", # If left off, model will be loaded as dense + enforce_eager=True, # Does not work with cudagraphs yet + dtype="float16", # bfloat16 tensor_parallel_size=1, max_model_len=1024 ) diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py index e47d1e0f7ad90..4b03ee9dca685 100644 --- a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -29,7 +29,7 @@ def get_name(cls) -> str: @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] + return [torch.half,torch.bfloat16] @classmethod def get_min_capability(cls) -> int: From fc85cac2e8edcaf80e036fda5532206e781ac492 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 8 Feb 2024 20:34:14 -0500 Subject: [PATCH 11/21] hopefully removed magic_wand submodule --- magic_wand | 1 - 1 file changed, 1 deletion(-) delete mode 160000 magic_wand diff --git a/magic_wand b/magic_wand deleted file mode 160000 index 5b74fb0c043c8..0000000000000 --- a/magic_wand +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5b74fb0c043c8d4005bca3f46601a4941e00c5aa From d7b2f41a6cf0d88989f3bf19f2a8b284e9f0d6ed Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 12 Feb 2024 11:56:21 -0500 Subject: [PATCH 12/21] wip bench --- benchmarks/simple_sparse_benchmark.py | 32 +++++++++++++++++++ ...ffline_inference_semi_structured_sparse.py | 6 ++-- .../sparsity/sparse_w16a16_linear_method.py | 13 ++------ 3 files changed, 37 insertions(+), 14 deletions(-) create mode 100644 benchmarks/simple_sparse_benchmark.py diff --git a/benchmarks/simple_sparse_benchmark.py b/benchmarks/simple_sparse_benchmark.py new file mode 100644 index 0000000000000..e4e3cb880ae85 --- /dev/null +++ b/benchmarks/simple_sparse_benchmark.py @@ -0,0 +1,32 @@ +from vllm import LLM, SamplingParams +import os,copy,time +#os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use only cuda:0 + + +model = LLM( + "nm-testing/zephyr-50sparse-24", + sparsity="semi_structured_sparse_w16a16", # If left off, model will be loaded as dense + enforce_eager=True, # Does not work with cudagraphs yet + dtype="bfloat16", # 
bfloat16 + tensor_parallel_size=2, + max_model_len=4096*2 +) + +num_prompts=64 +input_len=3072*2 +prompts=[copy.deepcopy(prompt) for prompt in (["Hi"*input_len]*num_prompts)] + +sampling_params = SamplingParams(max_tokens=100,temperature=0) + +start_time = time.time() + +outputs = model.generate(prompts, sampling_params=sampling_params) + +end_time = time.time() # Capture end time +duration = end_time - start_time # Calculate duration + +print(f"Elapsed time: {duration} seconds") + +#print(outputs[0]) +#print(outputs) +#print(outputs[0].outputs[0].text) \ No newline at end of file diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py index 01539f650a034..19f60cff3c119 100644 --- a/examples/offline_inference_semi_structured_sparse.py +++ b/examples/offline_inference_semi_structured_sparse.py @@ -1,14 +1,14 @@ from vllm import LLM, SamplingParams import os -os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use only cuda:0 +#os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use only cuda:0 model = LLM( "nm-testing/zephyr-50sparse-24", - sparsity="semi_structured_sparse_w16a16", # If left off, model will be loaded as dense + #sparsity="semi_structured_sparse_w16a16", # If left off, model will be loaded as dense enforce_eager=True, # Does not work with cudagraphs yet dtype="float16", # bfloat16 - tensor_parallel_size=1, + tensor_parallel_size=2, max_model_len=1024 ) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index e1ccf48dc5d5b..07b1e0bb7c03c 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -51,19 +51,10 @@ def apply_weights( sparse_weight = weights["weight"] if self.storage_format_cls == SparseSemiStructuredStorageFormat: - if bias is not None: - output = F.linear(x, sparse_weight, bias) - else: - output = F.linear(x, sparse_weight, bias=None) - + output = F.linear(x, sparse_weight, bias) return output else: - # Standard matrix multiply # Uncompress to dense - if bias is not None: - output = F.linear(x, sparse_weight.to_dense(), bias) - else: - output = F.linear(x, sparse_weight.to_dense()) - + output = F.linear(x, sparse_weight.to_dense(), bias) return output \ No newline at end of file From ef647115ea7e5e14ea5c4500f53632fe08a2e0ab Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 13 Feb 2024 20:59:02 -0500 Subject: [PATCH 13/21] initial integration --- benchmarks/simple_sparse_benchmark.py | 27 +++---- ...ffline_inference_semi_structured_sparse.py | 13 ++-- vllm/config.py | 2 +- vllm/model_executor/layers/linear.py | 58 ++++++++------- .../layers/parameters/__init__.py | 10 +-- .../layers/parameters/lazy_compressed.py | 74 +++++++++++++++++++ .../layers/parameters/sparsity.py | 72 ------------------ .../layers/sparsity/base_config.py | 1 + .../sparsity/semi_structured_sparse_w16a16.py | 15 ++-- .../layers/sparsity/sparse_w16a16.py | 11 +-- .../sparsity/sparse_w16a16_linear_method.py | 54 ++++++++------ vllm/model_executor/weight_utils.py | 7 +- 12 files changed, 171 insertions(+), 173 deletions(-) create mode 100644 vllm/model_executor/layers/parameters/lazy_compressed.py delete mode 100644 vllm/model_executor/layers/parameters/sparsity.py diff --git a/benchmarks/simple_sparse_benchmark.py b/benchmarks/simple_sparse_benchmark.py index e4e3cb880ae85..7236cda5a12ff 100644 --- a/benchmarks/simple_sparse_benchmark.py +++ 
b/benchmarks/simple_sparse_benchmark.py @@ -1,22 +1,23 @@ from vllm import LLM, SamplingParams -import os,copy,time +import copy +import time #os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use only cuda:0 - model = LLM( - "nm-testing/zephyr-50sparse-24", - sparsity="semi_structured_sparse_w16a16", # If left off, model will be loaded as dense - enforce_eager=True, # Does not work with cudagraphs yet - dtype="bfloat16", # bfloat16 + "nm-testing/zephyr-50sparse-24", + #sparsity="sparse_w16a16", # If left off, model will be loaded as dense + enforce_eager=True, # Does not work with cudagraphs yet + dtype="float16", # bfloat16 tensor_parallel_size=2, - max_model_len=4096*2 -) + max_model_len=1024) -num_prompts=64 -input_len=3072*2 -prompts=[copy.deepcopy(prompt) for prompt in (["Hi"*input_len]*num_prompts)] +num_prompts = 64 +input_len = 3072 * 2 +prompts = [ + copy.deepcopy(prompt) for prompt in (["Hi im a prompt"] * num_prompts) +] -sampling_params = SamplingParams(max_tokens=100,temperature=0) +sampling_params = SamplingParams(max_tokens=100, temperature=0) start_time = time.time() @@ -29,4 +30,4 @@ #print(outputs[0]) #print(outputs) -#print(outputs[0].outputs[0].text) \ No newline at end of file +#print(outputs[0].outputs[0].text) diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py index 19f60cff3c119..9b28e6e2655c4 100644 --- a/examples/offline_inference_semi_structured_sparse.py +++ b/examples/offline_inference_semi_structured_sparse.py @@ -1,17 +1,14 @@ from vllm import LLM, SamplingParams -import os #os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use only cuda:0 - model = LLM( - "nm-testing/zephyr-50sparse-24", + "nm-testing/zephyr-50sparse-24", #sparsity="semi_structured_sparse_w16a16", # If left off, model will be loaded as dense - enforce_eager=True, # Does not work with cudagraphs yet - dtype="float16", # bfloat16 + enforce_eager=True, # Does not work with cudagraphs yet + dtype="float16", # bfloat16 tensor_parallel_size=2, - max_model_len=1024 -) + max_model_len=1024) sampling_params = SamplingParams(max_tokens=100, temperature=0) outputs = model.generate("Hello my name is", sampling_params=sampling_params) -print(outputs[0].outputs[0].text) \ No newline at end of file +print(outputs[0].outputs[0].text) diff --git a/vllm/config.py b/vllm/config.py index 4b82e2c18f78d..a86fbc3cfde84 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -148,7 +148,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_sparsity(self) -> None: - supported_sparsity = ["sparse_w16a16","semi_structured_sparse_w16a16"] + supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"] if self.quantization is not None: raise ValueError("Both sparsity and quantization detected. 
Only " diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d09db721d712b..9af9b0c5f2f62 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,7 +13,7 @@ divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs from vllm.logger import init_logger -from vllm.model_executor.layers.parameters import SparseParameter, get_param_data +from vllm.model_executor.layers.parameters import LazyCompressedParameter logger = init_logger(__name__) @@ -196,7 +196,7 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) - param_data = get_param_data(param) + param_data = param.data if output_dim is not None: shard_size = param_data.shape[output_dim] @@ -206,9 +206,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If SparseParameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + # If LazyCompressedParameter, compress the data. + if isinstance(param, LazyCompressedParameter): + param.compress() def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -257,6 +257,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): self.output_sizes = output_sizes + self.loaded_shards = set() tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, @@ -266,14 +267,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): - param_data = get_param_data(param) + param_data = param.data output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: - if isinstance(param, SparseParameter): - raise NotImplementedError( - "Passing loaded_shard_id=None not yet supported for SparseParameter" - ) - # Loaded weight is already packed. if output_dim is None: assert param_data.shape == loaded_weight.shape @@ -320,12 +316,17 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + + self.loaded_shards.add(loaded_shard_id) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If Parameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + # This is super hacky for now but we basically want to only compress once all + # of the shards are loaded, right now we just check if the number of shards + # loaded matches the number of outputs expected, assuming one shard per output + all_shards_loaded = (len(self.loaded_shards) == len(self.output_sizes)) + if all_shards_loaded and isinstance(param, LazyCompressedParameter): + param.compress() class QKVParallelLinear(ColumnParallelLinear): @@ -369,6 +370,7 @@ def __init__( if total_num_kv_heads is None: total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads + self.loaded_shards = set() # Divide the weight matrix along the last dimension. 
tp_size = get_tensor_model_parallel_world_size() self.num_heads = divide(self.total_num_heads, tp_size) @@ -389,14 +391,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): - param_data = get_param_data(param) + param_data = param.data output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: - if isinstance(param, SparseParameter): - raise NotImplementedError( - "Passing loaded_shard_id=None not yet supported for SparseParameter" - ) - # Loaded weight is already packed. if output_dim is None: assert param_data.shape == loaded_weight.shape @@ -460,9 +457,16 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If SparseParameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + self.loaded_shards.add(loaded_shard_id) + + # This is super hacky for now but we basically want to only compress once + # all of the shards are loaded, for the QKV matrix this means + # loading shards "q", "k" and "v" + all_shards_loaded = (self.loaded_shards == set(["q", "k", "v"])) + + # If LazyCompressedParameter, compress the data. + if all_shards_loaded and isinstance(param, LazyCompressedParameter): + param.compress() class RowParallelLinear(torch.nn.Module): @@ -546,7 +550,7 @@ def __init__( def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) - param_data = get_param_data(param) + param_data = param.data if input_dim is not None: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size @@ -555,9 +559,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If SparseParameter, repack dense data as sparse. - if isinstance(param, SparseParameter): - param.pack() + # If LazyCompressedParameter, compress the data. + if isinstance(param, LazyCompressedParameter): + param.compress() def forward(self, input_): # Set up backprop all-reduce. 
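The weight-loader hunks above all follow the same deferred-compression pattern: each shard copy is recorded in a loaded-shards set, and param.compress() runs only once the set is complete (one entry per output for MergedColumnParallelLinear, the set {"q", "k", "v"} for QKVParallelLinear). Below is a minimal standalone sketch of that pattern; ToyFusedParam and its methods are illustrative stand-ins, not code from this patch:

import torch

class ToyFusedParam:
    """Dense buffer that compresses itself once all expected shards are loaded."""

    def __init__(self, shape, expected_shards):
        self.data = torch.empty(shape, dtype=torch.float16)
        self.expected_shards = set(expected_shards)
        self.loaded_shards = set()
        self.compressed_data = None

    def load_shard(self, shard_id, row_offset, weight):
        # Mirrors weight_loader(): copy one shard into the dense buffer,
        # record it, and compress only when every shard has landed.
        rows = weight.shape[0]
        self.data[row_offset:row_offset + rows].copy_(weight)
        self.loaded_shards.add(shard_id)
        if self.loaded_shards == self.expected_shards:
            self.compress()

    def compress(self):
        # Stand-in for LazyCompressedParameter.compress(); a real storage
        # format would pack self.data instead of cloning it.
        self.compressed_data = self.data.clone()
        self.data = None

param = ToyFusedParam((6, 4), expected_shards={"q", "k", "v"})
for i, shard_id in enumerate(["q", "k", "v"]):
    param.load_shard(shard_id, row_offset=2 * i,
                     weight=torch.randn(2, 4, dtype=torch.float16))
assert param.compressed_data is not None and param.data is None

Compressing eagerly on the first shard copy would be wasted work, since each later shard would force a decompress/recompress round trip; deferring until the shard set is complete (the "super hacky" check the patch comments describe) compresses each weight exactly once.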
diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py index 2d41190087a0d..5271295b2234b 100644 --- a/vllm/model_executor/layers/parameters/__init__.py +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -1,10 +1,2 @@ import torch -from vllm.model_executor.layers.parameters.sparsity import SparseParameter - - -def get_param_data(param: torch.nn.Parameter) -> torch.Tensor: - """Gets parameter data in dense format.""" - if isinstance(param, SparseParameter): - return param.get_dense_data() - else: - return param.data +from vllm.model_executor.layers.parameters.lazy_compressed import LazyCompressedParameter diff --git a/vllm/model_executor/layers/parameters/lazy_compressed.py b/vllm/model_executor/layers/parameters/lazy_compressed.py new file mode 100644 index 0000000000000..1224e467c56da --- /dev/null +++ b/vllm/model_executor/layers/parameters/lazy_compressed.py @@ -0,0 +1,74 @@ +import numpy +import torch +from torch.utils._pytree import tree_map, tree_flatten +from torch.sparse import SparseSemiStructuredTensor + +from typing import Type +from magic_wand import ( + SparseTensor, + CompressedStorageFormat, + SparseBitmaskStorageFormat, + SparseSemiStructuredStorageFormat +) + + +class LazyCompressedParameter(torch.Tensor): + + @staticmethod + def __new__( + cls, + uncompressed_data: torch.Tensor, + storage_format_cls: Type[CompressedStorageFormat] = SparseBitmaskStorageFormat, + compress_transposed: bool = False + ): + self = torch.Tensor._make_wrapper_subclass(cls, + size=uncompressed_data.shape, + dtype=uncompressed_data.dtype, + requires_grad=False) + self.storage_format_cls = storage_format_cls + self.compressed_data = None + self.uncompressed_data = uncompressed_data + self.compress_transposed = compress_transposed + self._is_param = True + + return self + + @property + def has_compressed_data(self) -> bool: + return (self.compressed_data is not None) + + @property + def has_uncompressed_data(self) -> bool: + return (self.uncompressed_data is not None) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + ret_storage_format_cls = None + + def unwrap(e): + nonlocal ret_storage_format_cls + if isinstance(e, LazyCompressedParameter): + assert ret_storage_format_cls is None or ret_storage_format_cls == e.storage_format_cls + ret_storage_format_cls = e.storage_format_cls + return e.uncompressed_data if isinstance(e, LazyCompressedParameter) else e + + rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + + def wrap(e): + if isinstance(e, torch.Tensor) and ret_storage_format_cls is not None: + return LazyCompressedParameter(e, storage_format_cls=ret_storage_format_cls) + return e + + rs = tree_map(wrap, rs) + return rs + + def compress(self) -> None: + density = torch.count_nonzero(self.uncompressed_data).item() / numpy.prod(self.shape) + print("density: ", density, torch.count_nonzero(self.uncompressed_data), numpy.prod(self.shape)) + + if self.uncompressed_data is None: + raise ValueError("Called compress() but uncompressed_data does not exist.") + self.compressed_data = self.storage_format_cls.compress( + self.uncompressed_data.t() if self.compress_transposed else self.uncompressed_data + ) + self.uncompressed_data = None \ No newline at end of file diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py deleted file mode 100644 index 53409c799d068..0000000000000 --- a/vllm/model_executor/layers/parameters/sparsity.py +++ 
/dev/null @@ -1,72 +0,0 @@ -import torch -from torch.sparse import SparseSemiStructuredTensor - -from typing import Type -from magic_wand import ( - SparseTensor, - CompressedStorageFormat, - SparseBitmaskStorageFormat, - SparseSemiStructuredStorageFormat -) - - -class SparseParameter(SparseTensor): - - @staticmethod - def __new__( - cls, - shape: torch.Size, - dtype: torch.dtype, - storage_format_cls: Type[CompressedStorageFormat] = SparseBitmaskStorageFormat - ): - assert torch.__version__ > (1, - 10), "SparseTensor requires PyTorch 1.11+" - - self = torch.Tensor._make_wrapper_subclass(cls, - size=shape, - dtype=dtype, - requires_grad=False) - self.storage_format_cls = storage_format_cls - self.compressed_data = None - self.dense_data = None - self._is_param = True - - return self - - def has_compressed_data(self) -> bool: - return (self.compressed_data is not None) - - def get_dense_data(self) -> torch.Tensor: - if self.dense_data is not None: - raise ValueError( - "Called get_data_dense() but dense_data already exists.") - self.dense_data = self._unpack() - return self.dense_data - - def _unpack(self) -> torch.Tensor: - if self.has_compressed_data(): - return self.compressed_data.decompress() - else: - return torch.empty(size=self.shape, - dtype=self.dtype, - device="cuda") - - @classmethod - def _copy(cls, arg0, arg1): - assert arg0.shape == arg1.shape - - if arg0.has_compressed_data(): - arg0.compressed_data.copy_(arg1) - else: - arg0.compressed_data = arg0.storage_format_cls.compress(arg1) - - return arg0 - - def copy_(self, src, non_blocking=False): - return SparseParameter._copy(self, src) - - def pack(self) -> None: - if self.dense_data is None: - raise ValueError("Called pack() but dense_data does not exist.") - self.copy_(self.dense_data) - self.dense_data = None \ No newline at end of file diff --git a/vllm/model_executor/layers/sparsity/base_config.py b/vllm/model_executor/layers/sparsity/base_config.py index 953ba301e8d5b..fe46b55cbf39f 100644 --- a/vllm/model_executor/layers/sparsity/base_config.py +++ b/vllm/model_executor/layers/sparsity/base_config.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.linear import LinearMethodBase from magic_wand import CompressedStorageFormat + class SparsityConfig(ABC): """Base class for sparsity configs.""" diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py index 4b03ee9dca685..78a67cb1f4483 100644 --- a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -1,13 +1,11 @@ import torch from typing import Any, Dict, List, Type -from magic_wand import CompressedStorageFormat from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from .sparse_w16a16_linear_method import SparseW16A16LinearMethod -from magic_wand import ( - CompressedStorageFormat, - SparseSemiStructuredStorageFormat -) +from magic_wand import (CompressedStorageFormat, + SparseSemiStructuredStorageFormat) + class SemiStructuredSparseW16A16Config(SparsityConfig): """Config class for SemiStructuredSparseW16A16. 
@@ -29,7 +27,7 @@ def get_name(cls) -> str: @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half,torch.bfloat16] + return [torch.half, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: @@ -41,8 +39,9 @@ def get_config_filenames(cls) -> List[str]: return ["sparsity_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config": + def from_config( + cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config": return cls() def get_linear_method(self) -> "SparseW16A16LinearMethod": - return SparseW16A16LinearMethod(self,self.get_storage_format_cls()) \ No newline at end of file + return SparseW16A16LinearMethod(self, self.get_storage_format_cls()) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py index fb80b894caefa..d3a93d9b1d945 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -1,15 +1,12 @@ from typing import Any, Dict, List, Type import torch -import torch.nn.functional as F from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from .sparse_w16a16_linear_method import SparseW16A16LinearMethod -from magic_wand import ( - CompressedStorageFormat, - SparseBitmaskStorageFormat -) +from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat) + class SparseW16A16Config(SparsityConfig): """Config class for SparseW16A16. @@ -26,7 +23,7 @@ def __repr__(self) -> str: @classmethod def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]: - return SparseBitmaskStorageFormat + return SparseBEGemmStorageFormat @classmethod def get_name(cls) -> str: @@ -50,4 +47,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config": return cls() def get_linear_method(self) -> "SparseW16A16LinearMethod": - return SparseW16A16LinearMethod(self,self.get_storage_format_cls()) \ No newline at end of file + return SparseW16A16LinearMethod(self, self.get_storage_format_cls()) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index 07b1e0bb7c03c..b3e414c99cd80 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -5,11 +5,11 @@ from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig -from vllm.model_executor.layers.parameters import SparseParameter -from magic_wand import ( - CompressedStorageFormat, - SparseSemiStructuredStorageFormat -) +from vllm.model_executor.layers.parameters import LazyCompressedParameter +from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat, + SparseSemiStructuredStorageFormat) +from magic_wand.ops import be_ds_gemm + class SparseW16A16LinearMethod(LinearMethodBase): """Linear method for Sparse W16A16. 
@@ -19,24 +19,23 @@ class SparseW16A16LinearMethod(LinearMethodBase): """ storage_format_cls: Type[CompressedStorageFormat] = None - def __init__(self, sparsity_config: SparsityConfig, storage_format_cls: Type[CompressedStorageFormat]): + def __init__(self, sparsity_config: SparsityConfig, + storage_format_cls: Type[CompressedStorageFormat]): self.sparsity_config = sparsity_config self.storage_format_cls = storage_format_cls - def create_weights( - self, - input_size_per_partition: int, - output_size_per_partition: int, - input_size: int, - output_size: int, - params_dtype: torch.dtype - ) -> Dict[str, Any]: - weight = SparseParameter( - shape=torch.Size( - (output_size_per_partition, input_size_per_partition)), - dtype=params_dtype, - storage_format_cls=self.storage_format_cls - ) + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + self.input_size_per_partition = input_size_per_partition + self.output_size_per_partition = output_size_per_partition + + weight = LazyCompressedParameter( + torch.empty((output_size_per_partition, input_size_per_partition), + dtype=params_dtype), + storage_format_cls=self.storage_format_cls, + compress_transposed=True) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ -48,13 +47,20 @@ def apply_weights( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - sparse_weight = weights["weight"] + w: LazyCompressedParameter = weights["weight"] if self.storage_format_cls == SparseSemiStructuredStorageFormat: - output = F.linear(x, sparse_weight, bias) + output = F.linear(x, w, bias) return output + if self.storage_format_cls == SparseBEGemmStorageFormat: + assert bias is None + assert w.compress_transposed + out_shape = (x.shape[:-1] + (w.shape[0], )) + reshaped_x = x.reshape(-1, x.shape[-1]) + y = be_ds_gemm(reshaped_x, w.compressed_data) + return y.reshape(out_shape) else: # Standard matrix multiply # Uncompress to dense - output = F.linear(x, sparse_weight.to_dense(), bias) - return output \ No newline at end of file + output = F.linear(x, w.to_dense(), bias) + return output diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 33332e77ae8e2..1aa34a068f7a0 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -19,8 +19,7 @@ QuantizationConfig) from vllm.model_executor.layers.sparsity import (get_sparsity_config, SparsityConfig) -from vllm.model_executor.layers.parameters import (get_param_data, - SparseParameter) +from vllm.model_executor.layers.parameters import LazyCompressedParameter logger = init_logger(__name__) @@ -300,8 +299,8 @@ def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() - get_param_data(param).copy_(loaded_weight) - if isinstance(param, SparseParameter): + param.data.copy_(loaded_weight) + if isinstance(param, LazyCompressedParameter): param.pack() From 202e655cb445967eb73c5a11d2faa6fad789d53d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 13 Feb 2024 22:45:02 -0500 Subject: [PATCH 14/21] disable the semi-sparse stuff temporarily --- .../layers/parameters/__init__.py | 5 +- .../layers/parameters/lazy_compressed.py | 68 ++++++++++--------- .../layers/sparsity/__init__.py | 2 +- .../sparsity/semi_structured_sparse_w16a16.py | 5 +- .../sparsity/sparse_w16a16_linear_method.py | 30 +++++--- 5 
files changed, 62 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/layers/parameters/__init__.py b/vllm/model_executor/layers/parameters/__init__.py index 5271295b2234b..c05cdf56e27a4 100644 --- a/vllm/model_executor/layers/parameters/__init__.py +++ b/vllm/model_executor/layers/parameters/__init__.py @@ -1,2 +1,5 @@ -import torch from vllm.model_executor.layers.parameters.lazy_compressed import LazyCompressedParameter + +__all__ = [ + "LazyCompressedParameter", +] diff --git a/vllm/model_executor/layers/parameters/lazy_compressed.py b/vllm/model_executor/layers/parameters/lazy_compressed.py index 1224e467c56da..1ef776170d53c 100644 --- a/vllm/model_executor/layers/parameters/lazy_compressed.py +++ b/vllm/model_executor/layers/parameters/lazy_compressed.py @@ -1,30 +1,24 @@ import numpy import torch -from torch.utils._pytree import tree_map, tree_flatten -from torch.sparse import SparseSemiStructuredTensor +from torch.utils._pytree import tree_map from typing import Type -from magic_wand import ( - SparseTensor, - CompressedStorageFormat, - SparseBitmaskStorageFormat, - SparseSemiStructuredStorageFormat -) +from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat) class LazyCompressedParameter(torch.Tensor): @staticmethod - def __new__( - cls, - uncompressed_data: torch.Tensor, - storage_format_cls: Type[CompressedStorageFormat] = SparseBitmaskStorageFormat, - compress_transposed: bool = False - ): - self = torch.Tensor._make_wrapper_subclass(cls, - size=uncompressed_data.shape, - dtype=uncompressed_data.dtype, - requires_grad=False) + def __new__(cls, + uncompressed_data: torch.Tensor, + storage_format_cls: Type[ + CompressedStorageFormat] = SparseBitmaskStorageFormat, + compress_transposed: bool = False): + self = torch.Tensor._make_wrapper_subclass( + cls, + size=uncompressed_data.shape, + dtype=uncompressed_data.dtype, + requires_grad=False) self.storage_format_cls = storage_format_cls self.compressed_data = None self.uncompressed_data = uncompressed_data @@ -32,7 +26,7 @@ def __new__( self._is_param = True return self - + @property def has_compressed_data(self) -> bool: return (self.compressed_data is not None) @@ -44,31 +38,41 @@ def has_uncompressed_data(self) -> bool: @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): ret_storage_format_cls = None - + def unwrap(e): nonlocal ret_storage_format_cls if isinstance(e, LazyCompressedParameter): assert ret_storage_format_cls is None or ret_storage_format_cls == e.storage_format_cls ret_storage_format_cls = e.storage_format_cls - return e.uncompressed_data if isinstance(e, LazyCompressedParameter) else e - + return e.uncompressed_data if isinstance( + e, LazyCompressedParameter) else e + rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) - + def wrap(e): - if isinstance(e, torch.Tensor) and ret_storage_format_cls is not None: - return LazyCompressedParameter(e, storage_format_cls=ret_storage_format_cls) + if isinstance(e, + torch.Tensor) and ret_storage_format_cls is not None: + return LazyCompressedParameter( + e, storage_format_cls=ret_storage_format_cls) return e rs = tree_map(wrap, rs) return rs def compress(self) -> None: - density = torch.count_nonzero(self.uncompressed_data).item() / numpy.prod(self.shape) - print("density: ", density, torch.count_nonzero(self.uncompressed_data), numpy.prod(self.shape)) - + density = torch.count_nonzero( + self.uncompressed_data).item() / numpy.prod(self.shape) + + # only compress if we have sufficient sparsity (>=45%), currently + # 
this applies globally across all formats including 2:4 + if (1 - density) < 0.45: + return + if self.uncompressed_data is None: - raise ValueError("Called compress() but uncompressed_data does not exist.") + raise ValueError( + "Called compress() but uncompressed_data does not exist.") self.compressed_data = self.storage_format_cls.compress( - self.uncompressed_data.t() if self.compress_transposed else self.uncompressed_data - ) - self.uncompressed_data = None \ No newline at end of file + self.uncompressed_data.t( + ) if self.compress_transposed else self.uncompressed_data) + del self.uncompressed_data # free memory + self.uncompressed_data = None diff --git a/vllm/model_executor/layers/sparsity/__init__.py b/vllm/model_executor/layers/sparsity/__init__.py index 236620d130b8a..82893916fde80 100644 --- a/vllm/model_executor/layers/sparsity/__init__.py +++ b/vllm/model_executor/layers/sparsity/__init__.py @@ -6,7 +6,7 @@ _SPARSITY_CONFIG_REGISTRY = { "sparse_w16a16": SparseW16A16Config, - "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config + "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config, } diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py index 78a67cb1f4483..2cdd34fd0ff1c 100644 --- a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -8,8 +8,7 @@ class SemiStructuredSparseW16A16Config(SparsityConfig): - """Config class for SemiStructuredSparseW16A16. - """ + """Config class for SemiStructuredSparseW16A16.""" def __init__(self) -> None: pass @@ -27,7 +26,7 @@ def get_name(cls) -> str: @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half, torch.bfloat16] + return [torch.float16, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index b3e414c99cd80..ff3535e9b2b09 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -6,8 +6,7 @@ from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig from vllm.model_executor.layers.parameters import LazyCompressedParameter -from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat, - SparseSemiStructuredStorageFormat) +from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat) from magic_wand.ops import be_ds_gemm @@ -28,14 +27,16 @@ def create_weights(self, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - self.input_size_per_partition = input_size_per_partition - self.output_size_per_partition = output_size_per_partition + supports_linear = (self.storage_format_cls != + SparseBEGemmStorageFormat) weight = LazyCompressedParameter( torch.empty((output_size_per_partition, input_size_per_partition), dtype=params_dtype), storage_format_cls=self.storage_format_cls, - compress_transposed=True) + # if we don't support F.linear or something analogous, + # transpose when we compress so we can use a basic matmul + compress_transposed=not supports_linear) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ 
-49,10 +50,16 @@ def apply_weights( ) -> torch.Tensor: w: LazyCompressedParameter = weights["weight"] - if self.storage_format_cls == SparseSemiStructuredStorageFormat: - output = F.linear(x, w, bias) - return output - if self.storage_format_cls == SparseBEGemmStorageFormat: + # if we never compressed (likely due to insufficient sparsity), i.e. have uncompressed_data run normally + if w.has_uncompressed_data: + assert not w.has_compressed_data + output = F.linear(x, w.uncompressed_data, bias) + # The current 2:4 implementation was running dense so ignore it + # for now and instead just explicitly decompress as usual + # elif self.storage_format_cls == SparseSemiStructuredStorageFormat: + # assert bias is None + # raise NotImplementedError + elif self.storage_format_cls == SparseBEGemmStorageFormat: assert bias is None assert w.compress_transposed out_shape = (x.shape[:-1] + (w.shape[0], )) @@ -62,5 +69,6 @@ def apply_weights( else: # Standard matrix multiply # Uncompress to dense - output = F.linear(x, w.to_dense(), bias) - return output + assert not w.compress_transposed + output = F.linear(x, w.compressed_data.decompress(), bias) + return output From 131a0a5fde6187937099944f73fdb04e16a193fe Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 13 Feb 2024 22:53:05 -0500 Subject: [PATCH 15/21] format fix --- vllm/model_executor/layers/parameters/lazy_compressed.py | 2 +- .../layers/sparsity/sparse_w16a16_linear_method.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/parameters/lazy_compressed.py b/vllm/model_executor/layers/parameters/lazy_compressed.py index 1ef776170d53c..96e892a03d1fb 100644 --- a/vllm/model_executor/layers/parameters/lazy_compressed.py +++ b/vllm/model_executor/layers/parameters/lazy_compressed.py @@ -63,7 +63,7 @@ def compress(self) -> None: density = torch.count_nonzero( self.uncompressed_data).item() / numpy.prod(self.shape) - # only compress if we have sufficient sparsity (>=45%), currently + # only compress if we have sufficient sparsity (>=45%), currently # this applies globally across all formats including 2:4 if (1 - density) < 0.45: return diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index b283b42c7e0ee..1420ee97be0ce 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -71,4 +71,3 @@ def apply_weights( assert not w.compress_transposed output = F.linear(x, w.compressed_data.decompress(), bias) return output - From 5c6a55e67d653f2b07c5654048d10b1b808cdfdf Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 13 Feb 2024 22:54:23 -0500 Subject: [PATCH 16/21] remove sparse benchmark --- benchmarks/simple_sparse_benchmark.py | 33 --------------------------- 1 file changed, 33 deletions(-) delete mode 100644 benchmarks/simple_sparse_benchmark.py diff --git a/benchmarks/simple_sparse_benchmark.py b/benchmarks/simple_sparse_benchmark.py deleted file mode 100644 index 7236cda5a12ff..0000000000000 --- a/benchmarks/simple_sparse_benchmark.py +++ /dev/null @@ -1,33 +0,0 @@ -from vllm import LLM, SamplingParams -import copy -import time -#os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use only cuda:0 - -model = LLM( - "nm-testing/zephyr-50sparse-24", - #sparsity="sparse_w16a16", # If left off, model will be loaded as dense - enforce_eager=True, # Does not work with cudagraphs yet - dtype="float16", # bfloat16 - 
tensor_parallel_size=2, - max_model_len=1024) - -num_prompts = 64 -input_len = 3072 * 2 -prompts = [ - copy.deepcopy(prompt) for prompt in (["Hi im a prompt"] * num_prompts) -] - -sampling_params = SamplingParams(max_tokens=100, temperature=0) - -start_time = time.time() - -outputs = model.generate(prompts, sampling_params=sampling_params) - -end_time = time.time() # Capture end time -duration = end_time - start_time # Calculate duration - -print(f"Elapsed time: {duration} seconds") - -#print(outputs[0]) -#print(outputs) -#print(outputs[0].outputs[0].text) From ae57f2c0f181c383c2b04579f09023e48ba96453 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 13 Feb 2024 22:55:35 -0500 Subject: [PATCH 17/21] small format fix --- .../layers/sparsity/sparse_w16a16_linear_method.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index 1420ee97be0ce..c4a6bcddf003f 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -49,7 +49,8 @@ def apply_weights( ) -> torch.Tensor: w: LazyCompressedParameter = weights["weight"] - # if we never compressed (likely due to insufficient sparsity), i.e. have uncompressed_data run normally + # if we never compressed (likely due to insufficient sparsity), + # i.e. have uncompressed_data run normally if w.has_uncompressed_data: assert not w.has_compressed_data output = F.linear(x, w.uncompressed_data, bias) From fb95394c0083a5c8ac7fd5f384ebd00aa6798c1b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 13 Feb 2024 23:19:38 -0500 Subject: [PATCH 18/21] remove useless comments --- vllm/model_executor/layers/linear.py | 4 ---- .../layers/sparsity/sparse_w16a16_linear_method.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9af9b0c5f2f62..2901ec5773fc4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -206,7 +206,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If LazyCompressedParameter, compress the data. if isinstance(param, LazyCompressedParameter): param.compress() @@ -463,8 +462,6 @@ def weight_loader(self, # all of the shards are loaded, for the QKV matrix this means # loading shards "q", "k" and "v" all_shards_loaded = (self.loaded_shards == set(["q", "k", "v"])) - - # If LazyCompressedParameter, compress the data. if all_shards_loaded and isinstance(param, LazyCompressedParameter): param.compress() @@ -559,7 +556,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - # If LazyCompressedParameter, compress the data. 
if isinstance(param, LazyCompressedParameter): param.compress() diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py index c4a6bcddf003f..65713a1bf15b3 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py @@ -49,7 +49,7 @@ def apply_weights( ) -> torch.Tensor: w: LazyCompressedParameter = weights["weight"] - # if we never compressed (likely due to insufficient sparsity), + # if we never compressed (likely due to insufficient sparsity), # i.e. have uncompressed_data run normally if w.has_uncompressed_data: assert not w.has_compressed_data From b5ffb39c53709c8662e20c86e36b14d64480cea4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 13 Feb 2024 23:20:59 -0500 Subject: [PATCH 19/21] cleanup spacing --- vllm/model_executor/layers/parameters/lazy_compressed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/parameters/lazy_compressed.py b/vllm/model_executor/layers/parameters/lazy_compressed.py index 96e892a03d1fb..f2747fba29902 100644 --- a/vllm/model_executor/layers/parameters/lazy_compressed.py +++ b/vllm/model_executor/layers/parameters/lazy_compressed.py @@ -7,7 +7,6 @@ class LazyCompressedParameter(torch.Tensor): - @staticmethod def __new__(cls, uncompressed_data: torch.Tensor, From 9b69f56fc69011cc7085f30e556b0d08211f5907 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 14 Feb 2024 00:29:23 -0500 Subject: [PATCH 20/21] revert --- vllm/model_executor/layers/parameters/lazy_compressed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/parameters/lazy_compressed.py b/vllm/model_executor/layers/parameters/lazy_compressed.py index f2747fba29902..96e892a03d1fb 100644 --- a/vllm/model_executor/layers/parameters/lazy_compressed.py +++ b/vllm/model_executor/layers/parameters/lazy_compressed.py @@ -7,6 +7,7 @@ class LazyCompressedParameter(torch.Tensor): + @staticmethod def __new__(cls, uncompressed_data: torch.Tensor, From 1fbc82f708814a81e021f7f4e0c4a1e33be2a3e3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 14 Feb 2024 09:43:16 -0500 Subject: [PATCH 21/21] missed pack --- vllm/model_executor/weight_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 1aa34a068f7a0..cc8cadfebdf1c 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -301,7 +301,7 @@ def default_weight_loader(param: torch.nn.Parameter, assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) if isinstance(param, LazyCompressedParameter): - param.pack() + param.compress() def initialize_dummy_weights(
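Taken together, the last commits converge on a simple contract: every loader funnels dense data through param.data.copy_() and then calls compress() (the final patch fixes the one leftover pack() call), while compress() itself declines to do anything below the 45% sparsity threshold added in PATCH 14 — which is why apply_weights keeps an explicit dense fallback. A rough sketch of that contract follows; toy_compress and SPARSITY_THRESHOLD are invented here for illustration, and to_sparse() is only a stand-in for magic_wand's storage formats.

from typing import Optional
import torch

SPARSITY_THRESHOLD = 0.45  # mirrors the density check added in PATCH 14

def toy_compress(weight: torch.Tensor) -> Optional[torch.Tensor]:
    """Return a compressed payload, or None when the tensor is too dense."""
    density = torch.count_nonzero(weight).item() / weight.numel()
    if (1 - density) < SPARSITY_THRESHOLD:
        return None                    # stay dense; F.linear runs as usual
    return weight.to_sparse()          # stand-in for a magic_wand format

dense_w = torch.randn(8, 8, dtype=torch.float16)   # ~0% sparse
sparse_w = torch.eye(8, dtype=torch.float16)       # 87.5% sparse
print(toy_compress(dense_w) is None)               # True: left uncompressed
print(toy_compress(sparse_w) is not None)          # True: compressed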