diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index 7e92bda..464e121 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -418,8 +418,8 @@ def _prune(self) -> typing.Generator: if group is None: continue ch_groups = self._get_channel_groups(group) imp = self.estimate_importance(group) # raw importance score - group_size = len(imp) // ch_groups if imp is None: continue + group_size = len(imp) // ch_groups if ch_groups > 1: # layers with dimension grouping, such as GroupConv, GroupNorm, Multi-head attention, etc. # We average importance across groups here. For example: # imp = [1, 2, 3, 4, 5, 6] with ch_groups=2.