Merge pull request #103 from april-tools/axes_order
Swap axes order -- data flow `(C)FKB`
lkct authored Jul 9, 2023
2 parents 0c8a2bd + 1985e09 commit 63e9140
Showing 5 changed files with 40 additions and 36 deletions.
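The letters in the new einsum strings follow the convention in the title: C indexes mixture components (max_components), F the region-graph nodes a layer processes in parallel (len(rg_nodes)), K the units, R the CP rank, and B the batch. Activations that previously flowed as (B, K, F) now flow as (F, K, B), and parameters move the node axis to the front. A minimal sketch of the relabelling, with illustrative sizes only:

import torch

B, K, F = 8, 16, 4               # batch, units, region-graph nodes -- illustrative sizes
x_old = torch.randn(B, K, F)     # old activation layout: batch first, node axis last
x_new = x_old.permute(2, 1, 0)   # new activation layout (F, K, B): node axis first, batch last
assert x_new.shape == (F, K, B)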
12 changes: 6 additions & 6 deletions cirkit/layers/einsum/cp.py
@@ -39,9 +39,9 @@ def __init__( # type: ignore[misc]
super().__init__(rg_nodes, num_input_units, num_output_units)
self.prod_exp = prod_exp

- self.params_left = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
- self.params_right = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
- self.params_out = nn.Parameter(torch.empty(num_output_units, rank, len(rg_nodes)))
+ self.params_left = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+ self.params_right = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+ self.params_out = nn.Parameter(torch.empty(len(rg_nodes), rank, num_output_units))

# TODO: get torch.default_float_dtype
# (float ** float) is not guaranteed to be float, but here we know it is
@@ -55,13 +55,13 @@ def __init__( # type: ignore[misc]

# TODO: use bmm to replace einsum? also axis order?
def _forward_left_linear(self, x: Tensor) -> Tensor:
- return torch.einsum("bip,irp->brp", x, self.params_left)
+ return torch.einsum("fkr,fkb->frb", self.params_left, x)

def _forward_right_linear(self, x: Tensor) -> Tensor:
- return torch.einsum("bip,irp->brp", x, self.params_right)
+ return torch.einsum("fkr,fkb->frb", self.params_right, x)

def _forward_out_linear(self, x: Tensor) -> Tensor:
- return torch.einsum("brp,orp->bop", x, self.params_out)
+ return torch.einsum("frk,frb->fkb", self.params_out, x)

def _forward_linear(self, left: Tensor, right: Tensor) -> Tensor:
left_hidden = self._forward_left_linear(left)
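The old and new CP contractions compute the same values; only the axis placement changes, as the following standalone check illustrates (sizes are arbitrary and not taken from the repository):

import torch

B, K, R, F = 5, 3, 7, 2                               # batch, units, rank, nodes -- illustrative
x_old = torch.randn(B, K, F)                          # old input layout (B, K, F)
w_old = torch.randn(K, R, F)                          # old params_left/params_right layout
x_new = x_old.permute(2, 1, 0)                        # (F, K, B)
w_new = w_old.permute(2, 0, 1)                        # (F, K, R)

out_old = torch.einsum("bip,irp->brp", x_old, w_old)  # (B, R, F)
out_new = torch.einsum("fkr,fkb->frb", w_new, x_new)  # (F, R, B)
assert torch.allclose(out_new, out_old.permute(2, 1, 0), atol=1e-6)

The same relabelling applies to _forward_out_linear, where the contraction runs over the rank axis instead of the unit axis.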
6 changes: 3 additions & 3 deletions cirkit/layers/einsum/mixing.py
@@ -71,7 +71,7 @@ def __init__(self, rg_nodes: List[RegionNode], num_output_units: int, max_compon

# TODO: test best perf?
# param_shape = (len(self.nodes), self.max_components) for better perf
- self.params = nn.Parameter(torch.empty(num_output_units, len(rg_nodes), max_components))
+ self.params = nn.Parameter(torch.empty(max_components, len(rg_nodes), num_output_units))
# TODO: what's the use of params_mask?
self.register_buffer("params_mask", torch.ones_like(self.params))
self.param_clamp_value["min"] = torch.finfo(self.params.dtype).smallest_normal
@@ -90,7 +90,7 @@ def apply_params_mask(self) -> None:
self.params /= self.params.sum(dim=2, keepdim=True) # type: ignore[misc]

def _forward_linear(self, x: Tensor) -> Tensor:
- return torch.einsum("bonc,onc->bon", x, self.params)
+ return torch.einsum("cfk,cfkb->fkb", self.params, x)

# TODO: make forward return something
# pylint: disable-next=arguments-differ
@@ -106,6 +106,6 @@ def forward(self, log_input: Tensor) -> Tensor: # type: ignore[override]
# TODO: use a mul or gather? or do we need this?
assert (self.params * self.params_mask == self.params).all()

- return log_func_exp(log_input, func=self._forward_linear, dim=3, keepdim=False)
+ return log_func_exp(log_input, func=self._forward_linear, dim=0, keepdim=False)

# TODO: see commit 084a3685c6c39519e42c24a65d7eb0c1b0a1cab1 for backtrack
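With the components stacked on a leading C axis, the mixing contraction reduces that axis directly, so the surrounding log-sum-exp now runs over dim=0. A minimal stand-in for what log_func_exp presumably does around _forward_linear (hedged sketch; the sizes and the weight normalization are illustrative, not the layer's actual initialization):

import torch

C, F, K, B = 3, 2, 4, 6                      # components, nodes, units, batch -- illustrative
log_x = torch.randn(C, F, K, B)              # stacked log-outputs of the C mixed layers
w = torch.rand(C, F, K)
w = w / w.sum(dim=0, keepdim=True)           # toy mixture weights, normalized over components

def mix(x):                                  # the new _forward_linear contraction
    return torch.einsum("cfk,cfkb->fkb", w, x)

m = log_x.max(dim=0, keepdim=True).values    # stabilizer for the log-sum-exp trick
out = torch.log(mix(torch.exp(log_x - m))) + m.squeeze(0)
assert out.shape == (F, K, B)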
6 changes: 4 additions & 2 deletions cirkit/layers/exp_family/exp_family.py
@@ -170,7 +170,7 @@ def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
# I try with an einsum operation (x, o, d, s), (b, x, s) -> b, x, o, d.
# That should have the same result

- crucial_quantity_einsum = torch.einsum("xods,bxs->bxod", theta, self.suff_stats)
+ crucial_quantity_einsum = torch.einsum("xkds,bxs->bxkd", theta, self.suff_stats)

# assert not torch.isnan(crucial_quantity_einsum).any()

@@ -205,7 +205,9 @@ def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
self.marginalization_mask = None
output = self.ll

- return torch.einsum("bxir,xro->bio", output, self.scope_tensor)
+ # why bxkr instead of bxkd?
+ # TODO: the axes order for input layer? better remove this contiguous
+ return torch.einsum("bxkr,xrf->fkb", output, self.scope_tensor).contiguous()

# TODO: how to fix?
# pylint: disable-next=arguments-differ
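The input layer now emits the same (F, K, B) layout: the contraction with scope_tensor aggregates the per-variable terms belonging to each region and moves batch to the last axis in one step. A toy illustration, assuming scope_tensor is a 0/1 membership tensor over (variable, replica, region) as the einsum subscripts suggest:

import torch

B, X, K, R, F = 4, 6, 3, 2, 5                     # batch, variables, units, replicas, regions -- illustrative
ll = torch.randn(B, X, K, R)                      # per-variable, per-unit log-likelihoods
scope = torch.randint(0, 2, (X, R, F)).float()    # toy 0/1 scope membership

out = torch.einsum("bxkr,xrf->fkb", ll, scope).contiguous()
assert out.shape == (F, K, B) and out.is_contiguous()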
14 changes: 7 additions & 7 deletions cirkit/models/einet.py
@@ -190,7 +190,7 @@ def __init__( # type: ignore[misc]
region.einet_address.layer = mixing_layer
region.einet_address.idx = reg_idx
mixing_layer.apply_params_mask()
- self.bookkeeping.append((mixing_layer.params_mask, torch.tensor(padded_idx)))
+ self.bookkeeping.append((mixing_layer.params_mask, torch.tensor(padded_idx).T))

# TODO: can we annotate a list here?
# TODO: actually we should not mix all the input/mix/ein different types in one list
@@ -284,10 +284,10 @@ def forward(self, x: Tensor) -> Tensor:
assert isinstance(left_addr, tuple) and isinstance(right_addr, tuple)
# TODO: we should use dim=2, check all code
# TODO: duplicate code
- log_left_prob = torch.cat([outputs[layer] for layer in left_addr[0]], dim=2)
- log_left_prob = log_left_prob[:, :, left_addr[1]]
- log_right_prob = torch.cat([outputs[layer] for layer in right_addr[0]], dim=2)
- log_right_prob = log_right_prob[:, :, right_addr[1]]
+ log_left_prob = torch.cat([outputs[layer] for layer in left_addr[0]], dim=0)
+ log_left_prob = log_left_prob[left_addr[1]]
+ log_right_prob = torch.cat([outputs[layer] for layer in right_addr[0]], dim=0)
+ log_right_prob = log_right_prob[right_addr[1]]
out = inner_layer(log_left_prob, log_right_prob)
elif isinstance(inner_layer, EinsumMixingLayer):
_, padded_idx = self.bookkeeping[idx]
@@ -298,13 +298,13 @@ def forward(self, x: Tensor) -> Tensor:
# outputs[self.inner_layers[idx - 1]] = F.pad(
# outputs[self.inner_layers[idx - 1]], [0, 1], "constant", float("-inf")
# )
- log_input_prob = outputs[self.inner_layers[idx - 1]][:, :, padded_idx]
+ log_input_prob = outputs[self.inner_layers[idx - 1]][padded_idx]
out = inner_layer(log_input_prob)
else:
assert False
outputs[inner_layer] = out

- return outputs[self.inner_layers[-1]][:, :, 0]
+ return outputs[self.inner_layers[-1]][0].T # return shape (B, K)

# TODO: and what's the meaning of this?
# def backtrack(self, num_samples=1, class_idx=0, x=None, mode='sampling', **kwargs):
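The payoff in einet.py is that gathering inputs for the next layer becomes plain first-axis indexing instead of slicing the trailing axis, and the root output only needs a transpose back to (B, K). A small sketch of that pattern with made-up shapes:

import torch

K, B = 4, 8                                   # units, batch -- illustrative
out_a = torch.randn(2, K, B)                  # (F=2, K, B) output of one layer
out_b = torch.randn(3, K, B)                  # (F=3, K, B) output of another layer

stacked = torch.cat([out_a, out_b], dim=0)    # nodes stacked along dim 0 -> (5, K, B)
idx = torch.tensor([0, 2, 4])                 # which entries feed the next layer
gathered = stacked[idx]                       # (3, K, B); no trailing-axis slicing needed

root = torch.randn(1, K, B)                   # root layer output
final = root[0].T                             # (B, K), matching the new forward() return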
38 changes: 20 additions & 18 deletions tests/models/test_einet.py
@@ -86,13 +86,13 @@ def _get_einet() -> TensorizedPC:
def _get_param_shapes() -> Dict[str, Tuple[int, ...]]:
return {
"input_layer.params": (4, 1, 1, 2),
- "inner_layers.0.params_left": (1, 1, 4),
- "inner_layers.0.params_right": (1, 1, 4),
- "inner_layers.0.params_out": (1, 1, 4),
- "inner_layers.1.params_left": (1, 1, 2),
- "inner_layers.1.params_right": (1, 1, 2),
- "inner_layers.1.params_out": (1, 1, 2),
- "inner_layers.2.params": (1, 1, 2),
+ "inner_layers.0.params_left": (4, 1, 1),
+ "inner_layers.0.params_right": (4, 1, 1),
+ "inner_layers.0.params_out": (4, 1, 1),
+ "inner_layers.1.params_left": (2, 1, 1),
+ "inner_layers.1.params_right": (2, 1, 1),
+ "inner_layers.1.params_out": (2, 1, 1),
+ "inner_layers.2.params": (2, 1, 1),
}


@@ -109,15 +109,15 @@ def _set_params(einet: TensorizedPC) -> None:
[math.log(3), 0], # type: ignore[misc] # 3/4, 1/4
]
).reshape(4, 1, 1, 2),
- "inner_layers.0.params_left": torch.ones(1, 1, 4) / 2,
- "inner_layers.0.params_right": torch.ones(1, 1, 4) * 2,
- "inner_layers.0.params_out": torch.ones(1, 1, 4),
- "inner_layers.1.params_left": torch.ones(1, 1, 2) * 2,
- "inner_layers.1.params_right": torch.ones(1, 1, 2) / 2,
- "inner_layers.1.params_out": torch.ones(1, 1, 2),
+ "inner_layers.0.params_left": torch.ones(4, 1, 1) / 2,
+ "inner_layers.0.params_right": torch.ones(4, 1, 1) * 2,
+ "inner_layers.0.params_out": torch.ones(4, 1, 1),
+ "inner_layers.1.params_left": torch.ones(2, 1, 1) * 2,
+ "inner_layers.1.params_right": torch.ones(2, 1, 1) / 2,
+ "inner_layers.1.params_out": torch.ones(2, 1, 1),
"inner_layers.2.params": torch.tensor(
[1 / 3, 2 / 3], # type: ignore[misc]
- ).reshape(1, 1, 2),
+ ).reshape(2, 1, 1),
}
)
einet.load_state_dict(state_dict) # type: ignore[misc]
@@ -159,9 +159,9 @@ def test_einet_partition_func() -> None:
@pytest.mark.parametrize( # type: ignore[misc]
"rg_cls,kwargs,log_answer",
[
- (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.188161849975586),
- (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 51.31766128540039),
- (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.198360443115234),
+ (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.935434341430664),
+ (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 44.412864685058594),
+ (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.313674926757812),
(PoonDomingos, {"shape": [3, 3], "delta": 2}, None),
(QuadTree, {"width": 3, "height": 3, "struct_decomp": False}, None),
(RandomBinaryTree, {"num_vars": 9, "depth": 3, "num_repetitions": 2}, None),
@@ -223,4 +223,6 @@ def test_einet_partition_function(

assert torch.isclose(einet.partition_function(), sum_out, rtol=1e-6, atol=0)
if log_answer is not None:
- assert torch.isclose(sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0)
+ assert torch.isclose(
+ sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0
+ ), f"{sum_out.item()}"
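The expected values in the parametrized test change in this PR; independent of those constants, a partition-function check like the one above can be reproduced generically for a small model over binary variables (illustrative sketch, not the repository's test; the model object and its call signature are assumptions):

import itertools
import torch

def brute_force_log_partition(log_score_fn, num_vars):
    # enumerate all 2**num_vars binary assignments and log-sum-exp their log-scores
    all_x = torch.tensor(list(itertools.product((0, 1), repeat=num_vars)))
    return torch.logsumexp(log_score_fn(all_x), dim=0)

# usage sketch: compare against the model's own normalizer
# assert torch.isclose(model.partition_function(),
#                      brute_force_log_partition(lambda x: model(x).squeeze(), 16),
#                      rtol=1e-6, atol=0)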
