From 0baae1d791ad9a97285986839bc20e0087eeeedf Mon Sep 17 00:00:00 2001
From: Guy Jacob
Date: Wed, 27 Feb 2019 12:08:21 +0200
Subject: [PATCH] Bugfix in pruning masks access when model was modified by a Quantizer

---
 distiller/scheduler.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index d8df5bf6e..238769dea 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -170,11 +170,24 @@ def apply_mask(self, is_forward=True):
                     # the weights.
                     self.zeros_mask_dict[name].apply_mask(param)
             except KeyError:
-                # Quantizers for training modify some model parameters by adding a prefix
-                # If this is the source of the error, workaround and move on
+                # Quantizers for training might modify model parameters in a couple of ways:
+                # 1. By adding a prefix to the parameter tensor name
+                # 2. By wrapping the module holding the parameter in a wrapper module
+                # If the source of the error is one of the above, workaround and move on
+                #
+                # Quantizers might also add new learnable parameters (e.g. the clip value in PACT quantization)
+                # These parameters will also be missing from the masks mapping. For now, we'll assume that we're
+                # not interested in pruning these parameters - and we just ignore them.
+                #
+                # TODO: This is not scalable at all. Find a solution that doesn't "hard-code" these conditions...
                 name_parts = name.split('.')
-                if name_parts[-1].startswith(FP_BKP_PREFIX):
-                    name_parts[-1] = name_parts[-1].replace(FP_BKP_PREFIX, '', 1)
+                prefixed = name_parts[-1].startswith(FP_BKP_PREFIX)
+                wrapped = name_parts[-2] == 'wrapped_module'
+                if prefixed or wrapped:
+                    if prefixed:
+                        name_parts[-1] = name_parts[-1].replace(FP_BKP_PREFIX, '', 1)
+                    if wrapped:
+                        name_parts.pop(-2)
                     name = '.'.join(name_parts)
                     self.zeros_mask_dict[name].apply_mask(param)
 
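
As an aside, here is a minimal standalone sketch of the name normalization the hunk above applies, useful for seeing the expected mappings at a glance. It assumes FP_BKP_PREFIX is 'float_' (the actual constant comes from distiller's quantization code), and normalize_param_name() plus the parameter names in the asserts are hypothetical, written only for illustration.

# Sketch (illustration only): how a quantizer-modified parameter name is mapped
# back to the key used in zeros_mask_dict.
FP_BKP_PREFIX = 'float_'  # assumed value; the real constant lives in distiller's quantization package


def normalize_param_name(name):
    """Map a quantizer-modified parameter name back to its original masks-dict key.

    Returns None when the name does not look like a quantizer-modified copy of an
    existing parameter (e.g. a newly added learnable clip value), in which case the
    caller simply skips masking it.
    """
    name_parts = name.split('.')
    prefixed = name_parts[-1].startswith(FP_BKP_PREFIX)
    wrapped = len(name_parts) > 1 and name_parts[-2] == 'wrapped_module'
    if not (prefixed or wrapped):
        return None
    if prefixed:
        # Strip the FP32-backup prefix from the tensor name
        name_parts[-1] = name_parts[-1].replace(FP_BKP_PREFIX, '', 1)
    if wrapped:
        # Drop the wrapper module inserted by the quantizer
        name_parts.pop(-2)
    return '.'.join(name_parts)


# Hypothetical parameter names, showing the expected mappings:
assert normalize_param_name('conv1.float_weight') == 'conv1.weight'
assert normalize_param_name('conv1.wrapped_module.weight') == 'conv1.weight'
assert normalize_param_name('conv1.wrapped_module.float_weight') == 'conv1.weight'
assert normalize_param_name('conv1.clip_val') is None  # e.g. a PACT clip value - ignored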