Swap axes order -- data flow (C)FKB #103

Merged (2 commits, Jul 9, 2023)
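
For reference, my reading of the axis convention implied by the diffs below (not an official description from the PR): activations now flow fold-first as (F, K, B) = (folds/regions, units, batch), with an extra leading C (components) axis at the mixing layer, hence "(C)FKB"; the old convention was batch-first with the fold axis last. A tiny illustration with made-up sizes:

import torch

# Illustrative sizes only (not from the PR); F = folds/regions, K = units, B = batch, C = mixing components.
F, K, B, C = 2, 4, 3, 5
inner_activation = torch.empty(F, K, B)   # what input/CP layers now exchange
mixing_input = torch.empty(C, F, K, B)    # what EinsumMixingLayer consumes before summing out C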
12 changes: 6 additions & 6 deletions cirkit/layers/einsum/cp.py
@@ -39,9 +39,9 @@ def __init__(  # type: ignore[misc]
         super().__init__(rg_nodes, num_input_units, num_output_units)
         self.prod_exp = prod_exp

-        self.params_left = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
-        self.params_right = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
-        self.params_out = nn.Parameter(torch.empty(num_output_units, rank, len(rg_nodes)))
+        self.params_left = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+        self.params_right = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+        self.params_out = nn.Parameter(torch.empty(len(rg_nodes), rank, num_output_units))

         # TODO: get torch.default_float_dtype
         # (float ** float) is not guaranteed to be float, but here we know it is
@@ -55,13 +55,13 @@ def __init__(  # type: ignore[misc]

     # TODO: use bmm to replace einsum? also axis order?
     def _forward_left_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("bip,irp->brp", x, self.params_left)
+        return torch.einsum("fkr,fkb->frb", self.params_left, x)

     def _forward_right_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("bip,irp->brp", x, self.params_right)
+        return torch.einsum("fkr,fkb->frb", self.params_right, x)

     def _forward_out_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("brp,orp->bop", x, self.params_out)
+        return torch.einsum("frk,frb->fkb", self.params_out, x)

     def _forward_linear(self, left: Tensor, right: Tensor) -> Tensor:
         left_hidden = self._forward_left_linear(left)
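
A quick sanity check of the CP layer change (my sketch, not part of the PR): the new "fkr,fkb->frb" contraction on fold-first tensors matches the old "bip,irp->brp" on batch-first tensors up to a permutation of the output axes. Sizes and names below are illustrative.

import torch

# Sanity-check sketch (not part of the PR); sizes are illustrative.
B, K, R, F = 3, 5, 4, 2                 # batch, input units, rank, folds (= len(rg_nodes))
x_old = torch.randn(B, K, F)            # old activation layout
w_old = torch.randn(K, R, F)            # old params_left layout

x_new = x_old.permute(2, 1, 0)          # new activation layout: (folds, units, batch)
w_new = w_old.permute(2, 0, 1)          # new params_left layout: (folds, units, rank)

out_old = torch.einsum("bip,irp->brp", x_old, w_old)   # (B, R, F)
out_new = torch.einsum("fkr,fkb->frb", w_new, x_new)   # (F, R, B)

assert torch.allclose(out_new, out_old.permute(2, 1, 0), atol=1e-6)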
6 changes: 3 additions & 3 deletions cirkit/layers/einsum/mixing.py
@@ -71,7 +71,7 @@ def __init__(self, rg_nodes: List[RegionNode], num_output_units: int, max_compon

         # TODO: test best perf?
         # param_shape = (len(self.nodes), self.max_components) for better perf
-        self.params = nn.Parameter(torch.empty(num_output_units, len(rg_nodes), max_components))
+        self.params = nn.Parameter(torch.empty(max_components, len(rg_nodes), num_output_units))
         # TODO: what's the use of params_mask?
         self.register_buffer("params_mask", torch.ones_like(self.params))
         self.param_clamp_value["min"] = torch.finfo(self.params.dtype).smallest_normal
@@ -90,7 +90,7 @@ def apply_params_mask(self) -> None:
         self.params /= self.params.sum(dim=2, keepdim=True)  # type: ignore[misc]

     def _forward_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("bonc,onc->bon", x, self.params)
+        return torch.einsum("cfk,cfkb->fkb", self.params, x)

     # TODO: make forward return something
     # pylint: disable-next=arguments-differ
@@ -106,6 +106,6 @@ def forward(self, log_input: Tensor) -> Tensor:  # type: ignore[override]
         # TODO: use a mul or gather? or do we need this?
         assert (self.params * self.params_mask == self.params).all()

-        return log_func_exp(log_input, func=self._forward_linear, dim=3, keepdim=False)
+        return log_func_exp(log_input, func=self._forward_linear, dim=0, keepdim=False)

     # TODO: see commit 084a3685c6c39519e42c24a65d7eb0c1b0a1cab1 for backtrack
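
The same kind of check for the mixing layer (my sketch, not part of the PR): the component axis moves from last to first, which is also why the log-sum-exp in forward now reduces dim=0 instead of dim=3. Sizes are illustrative.

import torch

# Sanity-check sketch (not part of the PR); sizes are illustrative.
B, K, F, C = 3, 5, 2, 4                 # batch, output units, folds, mixing components
x_old = torch.randn(B, K, F, C)
w_old = torch.rand(K, F, C)

x_new = x_old.permute(3, 2, 1, 0)       # (C, F, K, B)
w_new = w_old.permute(2, 1, 0)          # (C, F, K)

out_old = torch.einsum("bonc,onc->bon", x_old, w_old)   # (B, K, F)
out_new = torch.einsum("cfk,cfkb->fkb", w_new, x_new)   # (F, K, B)

assert torch.allclose(out_new, out_old.permute(2, 1, 0), atol=1e-6)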
6 changes: 4 additions & 2 deletions cirkit/layers/exp_family/exp_family.py
@@ -170,7 +170,7 @@ def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
         # I try with an einsum operation (x, o, d, s), (b, x, s) -> b, x, o, d.
         # That should have the same result

-        crucial_quantity_einsum = torch.einsum("xods,bxs->bxod", theta, self.suff_stats)
+        crucial_quantity_einsum = torch.einsum("xkds,bxs->bxkd", theta, self.suff_stats)

         # assert not torch.isnan(crucial_quantity_einsum).any()

@@ -205,7 +205,9 @@ def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
             self.marginalization_mask = None
         output = self.ll

-        return torch.einsum("bxir,xro->bio", output, self.scope_tensor)
+        # why bxkr instead of bxkd?
+        # TODO: the axes order for input layer? better remove this contiguous
+        return torch.einsum("bxkr,xrf->fkb", output, self.scope_tensor).contiguous()

     # TODO: how to fix?
     # pylint: disable-next=arguments-differ
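
For the input layer the operands are unchanged and only the output axis order flips, from (batch, units, folds) to (folds, units, batch); the added .contiguous() affects memory layout only, not values. A sketch with illustrative axis sizes (my reading: x = variables, r = replicas):

import torch

# Sanity-check sketch (not part of the PR); sizes and axis names are illustrative.
B, X, K, R, F = 3, 16, 4, 2, 6
output = torch.randn(B, X, K, R)
scope_tensor = torch.rand(X, R, F)

old = torch.einsum("bxir,xro->bio", output, scope_tensor)               # (B, K, F)
new = torch.einsum("bxkr,xrf->fkb", output, scope_tensor).contiguous()  # (F, K, B)

assert torch.allclose(new, old.permute(2, 1, 0), atol=1e-5)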
14 changes: 7 additions & 7 deletions cirkit/models/einet.py
@@ -190,7 +190,7 @@ def __init__(  # type: ignore[misc]
                 region.einet_address.layer = mixing_layer
                 region.einet_address.idx = reg_idx
             mixing_layer.apply_params_mask()
-            self.bookkeeping.append((mixing_layer.params_mask, torch.tensor(padded_idx)))
+            self.bookkeeping.append((mixing_layer.params_mask, torch.tensor(padded_idx).T))

         # TODO: can we annotate a list here?
         # TODO: actually we should not mix all the input/mix/ein different types in one list
@@ -284,10 +284,10 @@ def forward(self, x: Tensor) -> Tensor:
                 assert isinstance(left_addr, tuple) and isinstance(right_addr, tuple)
                 # TODO: we should use dim=2, check all code
                 # TODO: duplicate code
-                log_left_prob = torch.cat([outputs[layer] for layer in left_addr[0]], dim=2)
-                log_left_prob = log_left_prob[:, :, left_addr[1]]
-                log_right_prob = torch.cat([outputs[layer] for layer in right_addr[0]], dim=2)
-                log_right_prob = log_right_prob[:, :, right_addr[1]]
+                log_left_prob = torch.cat([outputs[layer] for layer in left_addr[0]], dim=0)
+                log_left_prob = log_left_prob[left_addr[1]]
+                log_right_prob = torch.cat([outputs[layer] for layer in right_addr[0]], dim=0)
+                log_right_prob = log_right_prob[right_addr[1]]
                 out = inner_layer(log_left_prob, log_right_prob)
             elif isinstance(inner_layer, EinsumMixingLayer):
                 _, padded_idx = self.bookkeeping[idx]
@@ -298,13 +298,13 @@ def forward(self, x: Tensor) -> Tensor:
                 # outputs[self.inner_layers[idx - 1]] = F.pad(
                 #     outputs[self.inner_layers[idx - 1]], [0, 1], "constant", float("-inf")
                 # )
-                log_input_prob = outputs[self.inner_layers[idx - 1]][:, :, padded_idx]
+                log_input_prob = outputs[self.inner_layers[idx - 1]][padded_idx]
                 out = inner_layer(log_input_prob)
             else:
                 assert False
             outputs[inner_layer] = out

-        return outputs[self.inner_layers[-1]][:, :, 0]
+        return outputs[self.inner_layers[-1]][0].T  # return shape (B, K)

     # TODO: and what's the meaning of this?
     # def backtrack(self, num_samples=1, class_idx=0, x=None, mode='sampling', **kwargs):
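
On the forward pass (my summary, not part of the diff): with fold-first activations of shape (F, K, B), concatenation and gathering of regions happen on dim 0 instead of dim 2, which is also why padded_idx is stored transposed, and the final result is transposed back to (batch, units) for the caller. A tiny sketch with hypothetical indices:

import torch

# Sketch only (not part of the PR); padded_idx values are hypothetical.
F, K, B = 6, 4, 3
layer_out = torch.randn(F, K, B)                    # fold-first activations of one inner layer
padded_idx = torch.tensor([[0, 2, 4], [1, 3, 5]])   # hypothetical (components, regions) fold indices

gathered = layer_out[padded_idx]   # (2, 3, K, B): advanced indexing on the fold axis, ready for "cfkb"
final = layer_out[0].T             # (B, K): mirrors the batch-first return at the end of forward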
38 changes: 20 additions & 18 deletions tests/models/test_einet.py
@@ -86,13 +86,13 @@ def _get_einet() -> TensorizedPC:
 def _get_param_shapes() -> Dict[str, Tuple[int, ...]]:
     return {
         "input_layer.params": (4, 1, 1, 2),
-        "inner_layers.0.params_left": (1, 1, 4),
-        "inner_layers.0.params_right": (1, 1, 4),
-        "inner_layers.0.params_out": (1, 1, 4),
-        "inner_layers.1.params_left": (1, 1, 2),
-        "inner_layers.1.params_right": (1, 1, 2),
-        "inner_layers.1.params_out": (1, 1, 2),
-        "inner_layers.2.params": (1, 1, 2),
+        "inner_layers.0.params_left": (4, 1, 1),
+        "inner_layers.0.params_right": (4, 1, 1),
+        "inner_layers.0.params_out": (4, 1, 1),
+        "inner_layers.1.params_left": (2, 1, 1),
+        "inner_layers.1.params_right": (2, 1, 1),
+        "inner_layers.1.params_out": (2, 1, 1),
+        "inner_layers.2.params": (2, 1, 1),
     }


@@ -109,15 +109,15 @@ def _set_params(einet: TensorizedPC) -> None:
                    [math.log(3), 0],  # type: ignore[misc]  # 3/4, 1/4
                ]
            ).reshape(4, 1, 1, 2),
-            "inner_layers.0.params_left": torch.ones(1, 1, 4) / 2,
-            "inner_layers.0.params_right": torch.ones(1, 1, 4) * 2,
-            "inner_layers.0.params_out": torch.ones(1, 1, 4),
-            "inner_layers.1.params_left": torch.ones(1, 1, 2) * 2,
-            "inner_layers.1.params_right": torch.ones(1, 1, 2) / 2,
-            "inner_layers.1.params_out": torch.ones(1, 1, 2),
+            "inner_layers.0.params_left": torch.ones(4, 1, 1) / 2,
+            "inner_layers.0.params_right": torch.ones(4, 1, 1) * 2,
+            "inner_layers.0.params_out": torch.ones(4, 1, 1),
+            "inner_layers.1.params_left": torch.ones(2, 1, 1) * 2,
+            "inner_layers.1.params_right": torch.ones(2, 1, 1) / 2,
+            "inner_layers.1.params_out": torch.ones(2, 1, 1),
             "inner_layers.2.params": torch.tensor(
                 [1 / 3, 2 / 3],  # type: ignore[misc]
-            ).reshape(1, 1, 2),
+            ).reshape(2, 1, 1),
         }
     )
     einet.load_state_dict(state_dict)  # type: ignore[misc]
@@ -159,9 +159,9 @@ def test_einet_partition_func() -> None:
 @pytest.mark.parametrize(  # type: ignore[misc]
     "rg_cls,kwargs,log_answer",
     [
-        (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.188161849975586),
-        (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 51.31766128540039),
-        (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.198360443115234),
+        (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.935434341430664),
+        (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 44.412864685058594),
+        (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.313674926757812),
         (PoonDomingos, {"shape": [3, 3], "delta": 2}, None),
         (QuadTree, {"width": 3, "height": 3, "struct_decomp": False}, None),
         (RandomBinaryTree, {"num_vars": 9, "depth": 3, "num_repetitions": 2}, None),
@@ -223,4 +223,6 @@ def test_einet_partition_function(

     assert torch.isclose(einet.partition_function(), sum_out, rtol=1e-6, atol=0)
     if log_answer is not None:
-        assert torch.isclose(sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0)
+        assert torch.isclose(
+            sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0
+        ), f"{sum_out.item()}"