Skip to content

Commit

Permalink
Fixed variance redistribution function.
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDettmers committed Aug 8, 2019
1 parent 93db164 commit b3f7ff3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
## Features:
- Added FP16 support. Any model can now be run in 16-bit by passing the [apex](https://github.com/NVIDIA/apex) `FP16_Optimizer` into the `Masking` class and replacing `loss.backward()` with `optimizer.backward(loss)`.
- Added adapted [Dynamic Sparse Reparameterization](https://arxiv.org/abs/1902.05967) [codebase](https://github.com/IntelAI/dynamic-reparameterization) that works with sparse momentum.
- Added modular architecture for growth/prune/redistribution algorithms which is decoupled from the main library. This enables you to write your own prune/growth/redistribution algorithms without touched the library internals. A tutorial on how to add your own functions was also added: [How to Add Your Own Algorithms](How_to_add_your_own_algorithms.md]).
- Added modular architecture for growth/prune/redistribution algorithms which is decoupled from the main library. This enables you to write your own prune/growth/redistribution algorithms without touched the library internals. A tutorial on how to add your own functions was also added: [How to Add Your Own Algorithms](How_to_add_your_own_algorithms.md).
34 changes: 20 additions & 14 deletions How_to_add_your_own_algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,33 @@ masking.total_removed

Here I added two example extensions for redistribution and pruning. These two examples look at the variance of the gradient. If we look at weights with high and low variance in their gradients over time, then we can have the following interpretations.

For high variance weights, we can have two perspectives. The first one would assume that weights with high variance are unable to model the interactions in the inputs to classify the outputs due to a lack of capacity. For example a weight might have a problem to be useful for both the digit 0 and digit 7 when classifying MNIST and thus has high variance between these examples. If we add capacity to high variance layers, then we should reduce the variance over time (one weight for 7 one weight for 0). According to this perspective we want to add more parameters to layers with high average variance. In other words, we want to redistribute pruned parameters to layers with high gradient variance.
For high variance weights, we can have two perspectives. The first one would assume that weights with high variance are unable to model the interactions in the inputs to classify the outputs due to a lack of capacity. For example a weight might have a problem to be useful for both the digit 0 and digit 7 when classifying MNIST and thus has high variance between these examples. If we add capacity to high variance layers, then we should reduce the variance over time meaning the new weights can now fully model the different classes (one weight for 7 one weight for 0). According to this perspective we want to add more parameters to layers with high average variance. In other words, we want to redistribute pruned parameters to layers with high gradient variance.

The second perspective is a "potential of be useful" perspective. Here we see weights with high variance as having "potential to do the right classification, but they might just not have found the right decision boundary between classes yet". For example, a weight might have problems being useful for both the digit 7 and 0 but overtime it can find a feature which is useful for both classes. Thus gradient variance should reduce over time as features become more stable. If we take this perspective then it is important to keep some medium-to-high variance weights. Low variance weights have "settled in" and follow the gradient for a specific set of classes. These weights will not change much anymore while high variance weights might change a lot. So high variance weights might have "potential" while the potential of low variance weights is easily assessed by looking at the magnitude of that weights. Thus we might improve pruning if we look at both the variance of the gradient _and_ the magnitude of weights. You can find these examples in ['mnist_cifar/extensions.py']('mnist_cifar/extensions.py').
The second perspective is a "potential of be useful" perspective. Here we see weights with high variance as having "potential to do the right classification, but they might just not have found the right decision boundary between classes yet". For example, a weight might have problems being useful for both the digit 7 and 0 but overtime it can find a feature which is useful for both classes. Thus gradient variance should reduce over time as features become more stable. If we take this perspective then it is important to keep some medium-to-high variance weights. Low variance weights have "settled in" and follow the gradient for a specific set of classes. These weights will not change much anymore while high variance weights might change a lot. So high variance weights might have "potential" while the potential of low variance weights is easily assessed by looking at the magnitude of that weights. Thus we might improve pruning if we look at both the variance of the gradient _and_ the magnitude of weights. You can find these examples in ['mnist_cifar/extensions.py']('sparse_learning/mnist_cifar/extensions.py').

### Implementation

```python
def variance_redistribution(masking, name, weight, mask):
'''Return the mean variance of existing weights.
Higher variance means the layer does not have enough
capacity to model the inputs with the number of current weights.
If weights stabilize this means that some weights might
be useless/not needed.
Intuition: Higher gradient variance means a layer does not have enough
capacity to model the inputs with the current number of weights.
Thus we want to add more weights if we have higher variance.
If variance of the gradient stabilizes this means
that some weights might be useless/not needed.
'''
layer_importance = torch.var(weight.grad[mask.byte()]).mean().item()
# Adam calculates the running average of the sum of square for us
# This is similar to RMSProp.
if 'exp_avg_sq' not in masking.optimizer.state[weight]:
print('Variance redistribution requires the adam optimizer to be run!')
raise Exception('Variance redistribution requires the adam optimizer to be run!')
iv_adam_sumsq = torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])

layer_importance = iv_adam_sumsq[mask.byte()].mean().item()
return layer_importance


def magnitude_variance_pruning(masking, mask, weight, name):
''' Prunes weights which have high gradient variance and low magnitude.
Expand Down Expand Up @@ -135,13 +144,10 @@ Running 10 additional iterations (add `--iters 10`) of our new method with 5% we
```bash
python get_results_from_logs.py

Test set results for log: ./logs/lenet5_0.05_520776ed.log
Arguments:
augment=False, batch_size=100, bench=False, data='mnist', decay_frequency=25000, dense=False, density=0.05, epochs=100, fp16=False, growth='momentum', iters=10, l1=0.0, l2=0.0005, log_interval=100, lr=0.001, model='lenet5', momentum=0.9, no_cuda=False, optimizer='adam', prune='magnitude_variance', prune_rate=0.5, redistribution='variance', resume=None, save_features=False, save_model='./models/model.pt', seed=17, start_epoch=1, test_batch_size=100, valid_split=0.1, verbose=True
Accuracy. Median: 0.99300, Mean: 0.99300, Standard Error: 0.00019, Sample size: 11, 95% CI: (0.99262,0.99338)
Error. Median: 0.00700, Mean: 0.00700, Standard Error: 0.00019, Sample size: 11, 95% CI: (0.00662,0.00738)
Loss. Median: 0.02200, Mean: 0.02175, Standard Error: 0.00027, Sample size: 11, 95% CI: (0.02122,0.02228)

Accuracy. Mean: 0.99349, Standard Error: 0.00013, Sample size: 11, 95% CI: (0.99323,0.99375)
Error. Mean: 0.00651, Standard Error: 0.00013, Sample size: 11, 95% CI: (0.00625,0.00677)
Loss. Mean: 0.02078, Standard Error: 0.00035, Sample size: 11, 95% CI: (0.02010,0.02146)
```

Sparse momentum achieves an error of 0.0069 for this setting and the lower 95% confidence interval is 0.00649. Thus for this setting our results overlap with the confidence intervals of sparse momentum. Thus our new variance method is _as good or better_ than sparse momentum for this particular problem (Caffe LeNet-5 with 5% weights on MNIST).
Sparse momentum achieves an error of 0.0069 for this setting and the upper 95% confidence interval is 0.00739. Thus for this setting our results overlap with the confidence intervals of sparse momentum. Thus our new variance method is _as good_ as sparse momentum for this particular problem (Caffe LeNet-5 with 5% weights on MNIST).
18 changes: 13 additions & 5 deletions mnist_cifar/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,20 @@ def your_redistribution(masking, name, weight, mask):
def variance_redistribution(masking, name, weight, mask):
'''Return the mean variance of existing weights.
Higher variance means the layer does not have enough
capacity to model the inputs with the number of current weights.
If weights stabilize this means that some weights might
be useless/not needed.
Higher gradient variance means a layer does not have enough
capacity to model the inputs with the current number of weights.
Thus we want to add more weights if we have higher variance.
If variance of the gradient stabilizes this means
that some weights might be useless/not needed.
'''
layer_importance = torch.var(weight.grad[mask.byte()]).mean().item()
# Adam calculates the running average of the sum of square for us
# This is similar to RMSProp.
if 'exp_avg_sq' not in masking.optimizer.state[weight]:
print('Variance redistribution requires the adam optimizer to be run!')
raise Exception('Variance redistribution requires the adam optimizer to be run!')
iv_adam_sumsq = torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])

layer_importance = iv_adam_sumsq[mask.byte()].mean().item()
return layer_importance


Expand Down
5 changes: 3 additions & 2 deletions mnist_cifar/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def main():

args = parser.parse_args()
setup_logger(args)
print_and_log(args)

if args.fp16:
try:
Expand Down Expand Up @@ -202,13 +203,13 @@ def main():

# add custom prune/growth/redisribution here
if args.prune == 'magnitude_variance':
print('Using magnitude-variance pruning. Switching to Adam optimizer...')
args.prune = magnitude_variance_pruning
args.optimizer = 'adam'
args.lr /= 100.0
if args.redistribution == 'variance':
print('Using variance redistribution. Switching to Adam optimizer...')
args.redistribution = variance_redistribution
args.optimizer = 'adam'
args.lr /= 100.0

optimizer = None
if args.optimizer == 'sgd':
Expand Down

0 comments on commit b3f7ff3

Please sign in to comment.