-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LRP for resnet model #4
Comments
I implement the resnet convert as follows: import torch
import torchvision
from lrp.conv import Conv2d
from lrp.linear import Linear
from lrp.sequential import Sequential, Bottleneck
# Maps the class name of a torch layer to the lrp class whose ``from_torch``
# builds the explainable counterpart (used by ``convert_resnet``).
conversion_table = dict(
    Linear=Linear,
    Conv2d=Conv2d,
)
# # # # # Convert torch.models.resnetxx to lrp model
def convert_resnet(module, modules=None):
    """Convert a torchvision ResNet into an lrp ``Sequential``.

    Top-level call (``modules is None``): every direct child of *module* is
    converted in order. A ``torch.nn.Flatten`` is inserted right after each
    ``AdaptiveAvgPool2d`` child, and the collected layers are returned as a
    ``Sequential``.

    Recursive call: the converted form of *module* is appended to the list
    *modules*, and ``None`` is returned.

    Conversion rules:
      * ``torch.nn.Sequential``    -> its children are converted (flattened).
      * ``Linear`` / ``Conv2d``    -> lrp counterpart via ``conversion_table``.
      * ``ReLU``                   -> a fresh, non-inplace ``torch.nn.ReLU``.
      * torchvision ``Bottleneck`` -> lrp ``Bottleneck`` with lrp convolutions.
      * anything else (BatchNorm, pooling, torchvision ``BasicBlock``, ...) is
        appended unchanged. ``Sequential.forward`` then uses the plain
        gradient for it.

    No submodule of the source model is reassigned. BatchNorm layers are
    shared with the source model, not copied.
    """
    if modules is None:
        modules = []
        for child in module.children():
            convert_resnet(child, modules=modules)
            # The flatten between pooling and the classifier is not a child
            # module, so the child walk cannot pick it up. Insert it explicitly.
            if isinstance(child, torch.nn.AdaptiveAvgPool2d):
                modules.append(torch.nn.Flatten())
        return Sequential(*modules)

    if isinstance(module, torch.nn.Sequential):
        for child in module.children():
            convert_resnet(child, modules=modules)
    elif isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
        lrp_cls = conversion_table[module.__class__.__name__]
        modules.append(lrp_cls.from_torch(module))
    # maxpool is handled with gradient for the moment
    elif isinstance(module, torch.nn.ReLU):
        # Avoid in-place operations; they might ruin PatternNet pattern
        # computations.
        modules.append(torch.nn.ReLU())
    elif isinstance(module, torchvision.models.resnet.Bottleneck):
        modules.append(_convert_bottleneck(module))
    else:
        modules.append(module)


def _convert_bottleneck(block):
    """Build an lrp ``Bottleneck`` mirroring the torchvision *block*."""
    converted = Bottleneck()
    converted.conv1 = Conv2d.from_torch(block.conv1)
    converted.conv2 = Conv2d.from_torch(block.conv2)
    converted.conv3 = Conv2d.from_torch(block.conv3)
    converted.bn1 = block.bn1
    converted.bn2 = block.bn2
    converted.bn3 = block.bn3
    # Fresh, non-inplace ReLU, for the same reason as standalone ReLUs.
    converted.relu = torch.nn.ReLU()
    if block.downsample is not None:
        # Build a new container. Assigning into block.downsample[0] would
        # also swap the convolution inside the source model.
        conv, *rest = block.downsample
        converted.downsample = torch.nn.Sequential(Conv2d.from_torch(conv), *rest)
    return converted


# and edit the sequential.py as follows:
import torch
from . import Linear, Conv2d
from .maxpool import MaxPool2d
from .functional.utils import normalize
def grad_decorator_fn(module):
    """Return a gradient hook that maps an incoming gradient to ``normalize(grad)``.

    ``Sequential.forward`` registers the hook with ``Tensor.register_hook``
    whenever ``do_normalization`` is true. *module* is accepted but not used
    by the hook.
    """
    return lambda grad: normalize(grad)
# Lower-cased first four characters of str(module) for which gradient
# normalization is skipped (ReLU and MaxPool layers).
avoid_normalization_on = ['relu', 'maxp']


def do_normalization(rule, module):
    """Decide whether gradient-normalization hooks should wrap *module*.

    Only rules whose name contains "pattern" (case-insensitive) normalize.
    Even then, modules whose ``str()`` starts with an entry of
    ``avoid_normalization_on`` are skipped.
    """
    if "pattern" in rule.lower():
        prefix = str(module)[:4].lower()
        return prefix not in avoid_normalization_on
    return False
def is_kernel_layer(module):
    """Return True for layers that ``Sequential.forward`` drives with a rule and an optional pattern."""
    return isinstance(module, (Conv2d, Linear, Bottleneck))
def is_rule_specific_layer(module):
    """Return True for layers that take a rule but no pattern (lrp ``MaxPool2d``)."""
    rule_specific_types = (MaxPool2d,)
    return isinstance(module, rule_specific_types)
class Sequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` that can propagate explanations through its layers.

    With ``explain=False`` this is a plain sequential forward pass. With
    ``explain=True``:
      * kernel layers (``is_kernel_layer``) receive ``rule`` and, when
        *pattern* is given, the next entry of *pattern*;
      * rule-specific layers (``is_rule_specific_layer``) receive ``rule``;
      * every other layer runs its normal forward (plain gradient).
    When ``do_normalization`` is true for a layer, a gradient hook from
    ``grad_decorator_fn`` is registered on both its input and its output.
    """

    def forward(self, input, explain=False, rule="epsilon", pattern=None):
        if not explain:
            return super(Sequential, self).forward(input)

        # Consume a copy so the caller can reuse their pattern list.
        remaining = list(pattern) if pattern is not None else None

        for module in self:
            if do_normalization(rule, module):
                input.register_hook(grad_decorator_fn(module))

            if is_kernel_layer(module):
                layer_pattern = remaining.pop(0) if remaining is not None else None
                input = module.forward(input, explain=True, rule=rule, pattern=layer_pattern)
            elif is_rule_specific_layer(module):
                input = module.forward(input, explain=True, rule=rule)
            else:  # Use gradient as default for remaining layer types
                input = module(input)

            if do_normalization(rule, module):
                input.register_hook(grad_decorator_fn(module))
        return input
class Bottleneck(torch.nn.Module):
    """ResNet bottleneck block whose convolutions accept explain/rule/pattern.

    ``__init__`` only sets ``downsample = None``. ``conv1``-``conv3``,
    ``bn1``-``bn3``, ``relu`` and (optionally) ``downsample`` are assigned
    after construction by the converter. ``downsample`` is indexed as
    ``[conv, norm]``.
    """

    def __init__(self):
        super(Bottleneck, self).__init__()
        self.downsample = None

    def forward(self, x, explain=True, rule="epsilon", pattern=None):
        """Run the block, forwarding *explain* and *rule* to every convolution.

        *pattern*, when given, is indexed as
        ``[conv1, conv2, conv3, downsample[0]]``. The last entry is used only
        when a downsample exists. The residual sum is a plain tensor
        addition, so no LRP rule is applied to it.
        """
        def conv_kwargs(idx):
            # Only pass ``pattern`` when one was supplied.
            kwargs = {"explain": explain, "rule": rule}
            if pattern is not None:
                kwargs["pattern"] = pattern[idx]
            return kwargs

        out = self.relu(self.bn1(self.conv1(x, **conv_kwargs(0))))
        out = self.relu(self.bn2(self.conv2(out, **conv_kwargs(1))))
        out = self.bn3(self.conv3(out, **conv_kwargs(2)))

        identity = x
        if self.downsample is not None:
            identity = self.downsample[1](self.downsample[0](x, **conv_kwargs(3)))

        out += identity
        return self.relu(out)


# For patternnet, also need to modify the _fit_pattern function in patterns.py:


def _without_bias(layer, y):
    """Return *y*, the output of *layer*, with the layer's bias subtracted.

    A copy is returned when the layer has no bias, so the result never
    aliases the layer output.
    """
    if layer.bias is None:
        return y.clone()
    if isinstance(layer, torch.nn.Conv2d):
        # Conv bias is per output channel: broadcast it over the spatial dims.
        return y - layer.bias.view(-1, 1, 1)
    return y - layer.bias


def _fit_pattern(model, train_loader, max_iter, device, mask_fn=lambda y: torch.ones_like(y)):
    """Estimate PatternNet patterns for the kernel layers of *model*.

    Batches from *train_loader* are iterated as ``(input, _)`` pairs, moved to
    *device*, and propagated through *model* layer by layer. Running means are
    accumulated for every ``Linear``/``Conv2d`` layer, and for ``conv1``,
    ``conv2``, ``conv3`` and ``downsample[0]`` inside every ``Bottleneck``.
    The statistics come from ``_prod``, computed on the bias-free output and
    ``mask_fn(output)``. At most *max_iter* batches are used (all of them
    when *max_iter* is None).

    Returns a list in model order. It holds one pattern per plain kernel
    layer. For each ``Bottleneck`` it holds a list of patterns ordered
    ``[conv1, conv2, conv3(, downsample[0])]``, which is the layout
    ``Bottleneck.forward`` indexes.
    """
    stats_x = []
    stats_y = []
    stats_xy = []
    weights = []
    first = True

    def track(layer, x_in, y_out, sx, sy, sxy, ws, j):
        """Fold one batch of statistics for *layer* into slot *j* of the given lists."""
        mask = mask_fn(y_out).float().to(device)
        cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(layer, x_in, _without_bias(layer, y_out), mask)
        if first:
            # First batch: create the accumulators for this slot.
            sx.append(RunningMean(x_.shape, device))
            sy.append(RunningMean(y_.shape, device))  # Use all y
            sxy.append(RunningMean(xy_.shape, device))
            ws.append((w, w_fn))
        sx[j].update(x_, cnt_)
        sy[j].update(y_.sum(0), cnt_all_)
        sxy[j].update(xy_, cnt_)

    for b, (x, _) in enumerate(tqdm(train_loader)):
        x = x.to(device)
        i = 0  # index of the current kernel layer / Bottleneck in the stats lists
        for m in model:
            if isinstance(m, Bottleneck):
                if first:
                    for per_layer in (stats_x, stats_y, stats_xy, weights):
                        per_layer.append([])
                slots = (stats_x[i], stats_y[i], stats_xy[i], weights[i])

                y = m.conv1(x)
                track(m.conv1, x, y, *slots, 0)
                h = m.relu(m.bn1(y))

                y = m.conv2(h)
                track(m.conv2, h, y, *slots, 1)
                h = m.relu(m.bn2(y))

                y = m.conv3(h)
                track(m.conv3, h, y, *slots, 2)
                out = m.bn3(y)

                # Shortcut: the block input itself unless a downsample exists.
                # It must be reset per block, otherwise a stale shortcut from a
                # previous block would be added.
                identity = x
                if m.downsample is not None:
                    shortcut = m.downsample[0](x)
                    # Fit the downsample conv against its own output, not the
                    # main-branch output.
                    track(m.downsample[0], x, shortcut, *slots, 3)
                    identity = m.downsample[1](shortcut)

                out += identity
                x = m.relu(out)
                i += 1
                continue

            y = m(x)  # Note, this includes bias.
            if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
                track(m, x, y, stats_x, stats_y, stats_xy, weights, i)
                i += 1
            x = y.clone()

        first = False
        if max_iter is not None and b + 1 == max_iter:
            break

    def pattern(x_mean, y_mean, xy_mean, W2d):
        """PatternNet estimator: cov(x, y) divided by diag(W @ cov(x, y)), reshaped by w_fn."""
        x_ = x_mean.value
        y_ = y_mean.value
        xy_ = xy_mean.value
        W, w_fn = W2d
        ExEy = x_ * y_
        cov_xy = xy_ - ExEy  # [in, out]
        w_cov_xy = torch.diag(W @ cov_xy)  # [out,]
        A = safe_divide(cov_xy, w_cov_xy[None, :])
        A = w_fn(A)  # Reshape to original kernel size
        return A

    patterns = []
    for entry in zip(stats_x, stats_y, stats_xy, weights):
        if isinstance(entry[0], RunningMean):
            patterns.append(pattern(*entry))
        else:
            # Bottleneck: one pattern per inner conv, in tracking order.
            patterns.append([pattern(*sub) for sub in zip(*entry)])
    return patterns

# The LRP for the adding manipulation (the residual sum) is not added yet;
# it will probably need to be implemented.
Thank you for your post, @sdw95927.
I can see the patterns from ResNet too, just not as clearly as with VGG. I think this is mainly due to ResNet's more complex structure, such as its residual connections, whereas VGG is simple and straightforward.
@miladsikaroudi Can you please share what you did to make this work with ResNet? |
Thanks for your work!
I see that you implemented LRP for the VGG model. But VGG is a simple model consisting of a single Sequential with no residual connections. Could you help me implement LRP for a more complex model, such as ResNet?
Thank you so much!
The text was updated successfully, but these errors were encountered: