diff --git a/tests/conftest.py b/tests/conftest.py
index 1d67524..a31a445 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,22 @@
 from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
 
 
+def pytest_addoption(parser):
+    '''Add options to pytest.'''
+    parser.addoption(
+        '--batchsize',
+        type=int,
+        default=4,
+        help='Batch size for generated samples.'
+    )
+
+
+def pytest_generate_tests(metafunc):
+    '''Generate test fixture values based on CLI options.'''
+    if 'batchsize' in metafunc.fixturenames:
+        metafunc.parametrize('batchsize', [metafunc.config.getoption('batchsize')], scope='session')
+
+
 def prodict(**kwargs):
     '''Create a dictionary with values which are the cartesian product of the input keyword arguments.'''
     return [dict(zip(kwargs, val)) for val in product(*kwargs.values())]
@@ -30,6 +46,23 @@ def rng(request):
     return torch.manual_seed(request.param)
 
 
+@pytest.fixture(
+    scope='session',
+    params=[
+        (torch.nn.ReLU, {}),
+        (torch.nn.Softmax, dict(dim=1)),
+        (torch.nn.Tanh, {}),
+        (torch.nn.Sigmoid, {}),
+        (torch.nn.Softplus, dict(beta=1)),
+    ],
+    ids=lambda param: param[0].__name__
+)
+def module_simple(rng, request):
+    '''Fixture for simple modules.'''
+    module_type, kwargs = request.param
+    return module_type(**kwargs).to(torch.float64).eval()
+
+
 @pytest.fixture(
     scope='session',
     params=[
@@ -83,9 +116,9 @@ def module_batchnorm(module_linear):
 
 
 @pytest.fixture(scope='session')
-def data_input(rng, module_linear):
+def data_linear(rng, batchsize, module_linear):
     '''Fixture to create data for a linear module, given an RNG.'''
-    shape = (4,)
+    shape = (batchsize,)
     setups = [
         (Conv1d, 1, 1),
         (ConvTranspose1d, 0, 1),
@@ -102,3 +135,15 @@
     shape += (module_linear.weight.shape[dim],) + (4,) * ndims
 
     return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng)
+
+
+@pytest.fixture(scope='session', params=[
+    (16,),
+    (4,),
+    (4, 4),
+    (4, 4, 4),
+])
+def data_simple(request, rng, batchsize):
+    '''Fixture to create data for a simple module, given an RNG.'''
+    shape = (batchsize,) + request.param
+    return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng)
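Note on the conftest.py changes above: `pytest_generate_tests` turns the `--batchsize` command-line option into a single session-scoped parameter for the `batchsize` fixture name, which `data_linear` and `data_simple` consume for their leading dimension (the `type=int` keyword is needed so a CLI-supplied value arrives as an integer rather than a string). A minimal sketch of a consumer; the test below is hypothetical and not part of this diff:

```python
# Hypothetical consumer of the generated `batchsize` fixture; run e.g.
# `pytest tests/ --batchsize 8` to override the default of 4.
def test_data_has_batch_dimension(data_simple, batchsize):
    # `data_simple` is created with shape (batchsize,) + request.param,
    # so its leading dimension must equal the CLI-provided batch size.
    assert data_simple.shape[0] == batchsize
```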
diff --git a/tests/test_canonizers.py b/tests/test_canonizers.py
index 047b837..f820572 100644
--- a/tests/test_canonizers.py
+++ b/tests/test_canonizers.py
@@ -4,20 +4,20 @@ from zennit.canonizers import SequentialMergeBatchNorm
 
 
-def test_merge_batchnorm_consistency(module_linear, module_batchnorm, data_input):
+def test_merge_batchnorm_consistency(module_linear, module_batchnorm, data_linear):
     '''Test whether the output of the merged batchnorm is consistent with its original output.'''
-    output_linear_before = module_linear(data_input)
+    output_linear_before = module_linear(data_linear)
     output_batchnorm_before = module_batchnorm(output_linear_before)
 
     canonizer = SequentialMergeBatchNorm()
 
     try:
         canonizer.register((module_linear,), module_batchnorm)
-        output_linear_canonizer = module_linear(data_input)
+        output_linear_canonizer = module_linear(data_linear)
         output_batchnorm_canonizer = module_batchnorm(output_linear_canonizer)
     finally:
         canonizer.remove()
 
-    output_linear_after = module_linear(data_input)
+    output_linear_after = module_linear(data_linear)
     output_batchnorm_after = module_batchnorm(output_linear_after)
 
     assert all(torch.allclose(left, right, atol=1e-5) for left, right in [
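The test_rules.py diff below introduces a `with_grad` helper that hands the replicated rule the module's gradient (a vector-Jacobian product) instead of the module itself. A self-contained sketch of what that computes for ReLU, assuming nothing beyond plain PyTorch; all tensors here are made up for illustration:

```python
import torch

# For ReLU, the vector-Jacobian product that `with_grad` obtains via
# torch.autograd.grad equals the upstream tensor masked by input > 0.
input = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
upstream = torch.randn(4, 8, dtype=torch.float64)
gradient, = torch.autograd.grad(torch.relu(input), input, upstream)
assert torch.allclose(gradient, upstream * (input > 0))
```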
diff --git a/tests/test_rules.py b/tests/test_rules.py
index 916e034..78688b6 100644
--- a/tests/test_rules.py
+++ b/tests/test_rules.py
@@ -7,6 +7,7 @@
 import pytest
 import torch
 from zennit.rules import Epsilon, ZPlus, AlphaBeta, Gamma, ZBox, Norm, WSquare, Flat
+from zennit.rules import Pass, ReLUDeconvNet, ReLUGuidedBackprop
 
 
 def stabilize(input, epsilon=1e-6):
@@ -22,14 +23,15 @@ def as_matrix(module_linear, input, output):
     return weight, bias
 
 
-RULEPAIRS = []
+RULES_LINEAR = []
+RULES_SIMPLE = []
 
 
-def replicates(replicated_func, **kwargs):
+def replicates(target_list, replicated_func, **kwargs):
     '''Decorator to indicate a replication of a function for testing.'''
     def wrapper(func):
-        '''Append to ``RULEPAIRS`` as partial, given ``kwargs``.'''
-        RULEPAIRS.append(
+        '''Append to ``target_list`` as partial, given ``kwargs``.'''
+        target_list.append(
             pytest.param(
                 (partial(replicated_func, **kwargs), partial(func, **kwargs)),
                 id=replicated_func.__name__
@@ -68,16 +70,31 @@ def wrapped(module_linear, input, output, **kwargs):
     return wrapped
 
 
-@replicates(Epsilon, epsilon=1e-6)
-@replicates(Epsilon, epsilon=1.0)
-@replicates(Norm)
+def with_grad(func):
+    '''Decorator to wrap function such that the gradient is computed and passed to the function instead of the module.'''
+    @wraps(func)
+    def wrapped(module, input, output, **kwargs):
+        '''Get gradient and pass along input, output and keyword arguments to func.'''
+        gradient, = torch.autograd.grad(module(input), input, output)
+        return func(
+            gradient,
+            input,
+            output,
+            **kwargs
+        )
+    return wrapped
+
+
+@replicates(RULES_LINEAR, Epsilon, epsilon=1e-6)
+@replicates(RULES_LINEAR, Epsilon, epsilon=1.0)
+@replicates(RULES_LINEAR, Norm)
 @matrix_form
 def rule_epsilon(weight, bias, input, relevance, epsilon=1e-6):
     '''Replicates the Epsilon rule.'''
     return input * ((relevance / stabilize(input @ weight.t() + bias, epsilon)) @ weight)
 
 
-@replicates(ZPlus)
+@replicates(RULES_LINEAR, ZPlus)
 @matrix_form
 def rule_zplus(weight, bias, input, relevance):
     '''Replicates the ZPlus rule.'''
@@ -90,8 +107,8 @@ def rule_zplus(weight, bias, input, relevance):
     return xplus * (rfac @ wplus) + xminus * (rfac @ wminus)
 
 
-@replicates(Gamma, gamma=0.25)
-@replicates(Gamma, gamma=0.5)
+@replicates(RULES_LINEAR, Gamma, gamma=0.25)
+@replicates(RULES_LINEAR, Gamma, gamma=0.5)
 @matrix_form
 def rule_gamma(weight, bias, input, relevance, gamma):
     '''Replicates the Gamma rule.'''
@@ -100,8 +117,8 @@ def rule_gamma(weight, bias, input, relevance, gamma):
     return input * ((relevance / stabilize(input @ wgamma.t() + bgamma)) @ wgamma)
 
 
-@replicates(AlphaBeta, alpha=2.0, beta=1.0)
-@replicates(AlphaBeta, alpha=1.0, beta=0.0)
+@replicates(RULES_LINEAR, AlphaBeta, alpha=2.0, beta=1.0)
+@replicates(RULES_LINEAR, AlphaBeta, alpha=1.0, beta=0.0)
 @matrix_form
 def rule_alpha_beta(weight, bias, input, relevance, alpha, beta):
     '''Replicates the AlphaBeta rule.'''
@@ -118,7 +135,7 @@ def rule_alpha_beta(weight, bias, input, relevance, alpha, beta):
     return alpha * result_alpha - beta * result_beta
 
 
-@replicates(ZBox, low=-3.0, high=3.0)
+@replicates(RULES_LINEAR, ZBox, low=-3.0, high=3.0)
 @matrix_form
 def rule_zbox(weight, bias, input, relevance, low, high):
     '''Replicates the ZBox rule.'''
@@ -131,7 +148,7 @@ def rule_zbox(weight, bias, input, relevance, low, high):
     return input * (rfac @ weight) - low * (rfac @ wplus) - high * (rfac @ wminus)
 
 
-@replicates(WSquare)
+@replicates(RULES_LINEAR, WSquare)
 @matrix_form
 def rule_wsquare(weight, bias, input, relevance):
     '''Replicates the WSquare rule.'''
@@ -141,7 +158,7 @@ def rule_wsquare(weight, bias, input, relevance):
     return rfac @ wsquare
 
 
-@replicates(Flat)
+@replicates(RULES_LINEAR, Flat)
 @flat_module_params
 @matrix_form
 def rule_flat(wflat, bias, input, relevance):
@@ -151,25 +168,59 @@ def rule_flat(wflat, bias, input, relevance):
     return rfac @ wflat
 
 
-@pytest.fixture(scope='session', params=RULEPAIRS)
-def rule_pair(request):
-    '''Fixture to supply ``RULEPAIRS``.'''
+@replicates(RULES_SIMPLE, Pass)
+def rule_pass(module, input, relevance):
+    '''Replicates the Pass rule.'''
+    return relevance
+
+
+@replicates(RULES_SIMPLE, ReLUDeconvNet)
+def rule_relu_deconvnet(module, input, relevance):
+    '''Replicates the ReLUDeconvNet rule.'''
+    return relevance.clamp(min=0)
+
+
+@replicates(RULES_SIMPLE, ReLUGuidedBackprop)
+@with_grad
+def rule_relu_guidedbackprop(gradient, input, relevance):
+    '''Replicates the ReLUGuidedBackprop rule.'''
+    return gradient * (relevance > 0.)
+
+
+@pytest.fixture(scope='session', params=RULES_LINEAR)
+def rule_pair_linear(request):
+    '''Fixture to supply ``RULES_LINEAR``.'''
     return request.param
 
 
-def test_linear_rule(module_linear, data_input, rule_pair):
-    '''Test whether replicated and original implementations of rules for linear layers agree.'''
+@pytest.fixture(scope='session', params=RULES_SIMPLE)
+def rule_pair_simple(request):
+    '''Fixture to supply ``RULES_SIMPLE``.'''
+    return request.param
+
+
+def compare_rule_pair(module, data, rule_pair):
+    '''Compare rules with their replicated versions.'''
     rule_hook, rule_replicated = rule_pair
-    input = data_input.clone().requires_grad_()
-    handle = rule_hook().register(module_linear)
+    input = data.clone().requires_grad_()
+    handle = rule_hook().register(module)
 
     try:
-        output = module_linear(input)
+        output = module(input)
         relevance_hook, = torch.autograd.grad(output, input, grad_outputs=output)
     finally:
         handle.remove()
 
-    with torch.no_grad():
-        relevance_replicated = rule_replicated(module_linear, input, output)
+    relevance_replicated = rule_replicated(module, input, output)
 
     assert torch.allclose(relevance_hook, relevance_replicated, atol=1e-5)
+
+
+def test_linear_rule(module_linear, data_linear, rule_pair_linear):
+    '''Test whether replicated and original implementations of rules for linear layers agree.'''
+    compare_rule_pair(module_linear, data_linear, rule_pair_linear)
+
+
+def test_simple_rule(module_simple, data_simple, rule_pair_simple):
+    '''Test whether replicated and original implementations of rules for simple layers agree.'''
+    compare_rule_pair(module_simple, data_simple, rule_pair_simple)
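For reference, the hook workflow exercised by `compare_rule_pair` can be sketched standalone as below; the module and data are arbitrary placeholders, with `Epsilon` standing in for any rule from the parametrized pairs:

```python
import torch
from zennit.rules import Epsilon

# Sketch of the pattern tested above: register a rule hook on a module,
# backpropagate the output as the initial relevance, then read the
# rule-modified gradient at the input.
module = torch.nn.Linear(8, 4).to(torch.float64).eval()
input = torch.randn(2, 8, dtype=torch.float64, requires_grad=True)

handle = Epsilon(epsilon=1e-6).register(module)
try:
    output = module(input)
    relevance, = torch.autograd.grad(output, input, grad_outputs=output)
finally:
    handle.remove()
```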