
Saving and loading pruning masks? #428

Open

lxaw opened this issue Oct 17, 2024 · 0 comments

lxaw commented Oct 17, 2024

Thanks for the wonderful repository; it is really a great gift to the community.

I was wondering if there is any example code showing how one can save/load a binary mask of the pruned modules. For instance, in this code:

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Importance criterion
imp = tp.importance.GroupNormImportance(p=2) # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # pruning_ratio_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized pruning ratios for layers or blocks
    ignored_layers=ignored_layers,
)

# 3. Prune & finetune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
# finetune the pruned model here
# finetune(model)
# ...

...is it possible to save a binary mask of the selected and dropped modules, and then load that mask onto the model (without pruning)? Or to get a dictionary mapping each module name to whether it was selected or dropped?
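
To make the saving side concrete, here is roughly what I have in mind (an untested sketch, not an official API). It relies on pruner.step(interactive=True), which yields each group before pruning it, with the root dependency and its pruned channel indices available as group[0], as in the README's interactive-pruning example. It would replace the plain pruner.step() call above:

import torch

# Rough sketch (untested): collect a boolean keep/drop mask for the root
# module of each pruning group before the group is actually pruned.
module_names = {module: name for name, module in model.named_modules()}

masks = {}
for group in pruner.step(interactive=True):
    dep, idxs = group[0]                # root dependency and its pruned indices
    module = dep.target.module          # module whose output channels are removed
    n_out = getattr(module, "out_channels", getattr(module, "out_features", None))
    if n_out is not None:
        mask = torch.ones(n_out, dtype=torch.bool)
        mask[idxs] = False              # False = channel gets dropped
        masks[module_names[module]] = mask
    group.prune()                       # perform the structural pruning as usual

torch.save(masks, "pruning_masks.pt")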

If this is not possible at the current moment, could you provide some hints as to where to look in the code for a similar feature?
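
Even if there is no built-in way, for completeness here is the kind of thing I imagine on the loading side: applying a saved mask to an unpruned copy of the model by zeroing the dropped channels instead of physically removing them (again an untested sketch):

import torch
from torchvision.models import resnet18

# Rough sketch (untested): "load" a saved mask onto an unpruned model by
# zeroing the dropped output channels rather than removing them.
masks = torch.load("pruning_masks.pt")
fresh_model = resnet18(pretrained=True)
with torch.no_grad():
    for name, module in fresh_model.named_modules():
        if name in masks:
            dropped = ~masks[name]              # True where a channel was dropped
            module.weight[dropped] = 0.0        # zero the dropped output channels
            if getattr(module, "bias", None) is not None:
                module.bias[dropped] = 0.0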

Thanks again, and I look forward to the growth of this repo!
