Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stevenabreu7 committed Oct 20, 2023
1 parent fe7188a commit a21819f
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_to_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,38 @@ def extractor(m):
}


def test_ignore_batch_dim():
model = nn.Linear(3, 1)

def extractor(module: nn.Module):
return nir.Affine(module.weight, module.bias)

raw_input_shape = (1, 3)
g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0])
exp_input_shape = (3,)
assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
assert g.nodes["model"].weight.shape == (1, 3)
assert np.alltrue(g.nodes["output"].output_type["output"] == np.array([1]))


def test_ignore_time_and_batch_dim():
model = nn.Linear(3, 1)

def extractor(module: nn.Module):
return nir.Affine(module.weight, module.bias)

raw_input_shape = (1, 10, 3)
g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, -2])
exp_input_shape = (3,)
assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
assert g.nodes["model"].weight.shape == (1, 3)

raw_input_shape = (1, 10, 3)
g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, 1])
exp_input_shape = (3,)
assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))


# def test_extract_stateful():
# model = norse.SequentialState(norse.LIFBoxCell(), nn.Linear(3, 1))

Expand Down

0 comments on commit a21819f

Please sign in to comment.