diff --git a/cirkit/layers/einsum/cp.py b/cirkit/layers/einsum/cp.py
index 61effcfd..e7c3a4c3 100644
--- a/cirkit/layers/einsum/cp.py
+++ b/cirkit/layers/einsum/cp.py
@@ -39,9 +39,9 @@ def __init__(  # type: ignore[misc]
         super().__init__(rg_nodes, num_input_units, num_output_units)
         self.prod_exp = prod_exp

-        self.params_left = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
-        self.params_right = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
-        self.params_out = nn.Parameter(torch.empty(num_output_units, rank, len(rg_nodes)))
+        self.params_left = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+        self.params_right = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+        self.params_out = nn.Parameter(torch.empty(len(rg_nodes), rank, num_output_units))

         # TODO: get torch.default_float_dtype
         # (float ** float) is not guaranteed to be float, but here we know it is
@@ -55,13 +55,13 @@ def __init__(  # type: ignore[misc]

     # TODO: use bmm to replace einsum? also axis order?
     def _forward_left_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("bip,irp->brp", x, self.params_left)
+        return torch.einsum("fkr,fkb->frb", self.params_left, x)

     def _forward_right_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("bip,irp->brp", x, self.params_right)
+        return torch.einsum("fkr,fkb->frb", self.params_right, x)

     def _forward_out_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("brp,orp->bop", x, self.params_out)
+        return torch.einsum("frk,frb->fkb", self.params_out, x)

     def _forward_linear(self, left: Tensor, right: Tensor) -> Tensor:
         left_hidden = self._forward_left_linear(left)
diff --git a/cirkit/layers/einsum/mixing.py b/cirkit/layers/einsum/mixing.py
index 04c61e82..e52fc5a2 100644
--- a/cirkit/layers/einsum/mixing.py
+++ b/cirkit/layers/einsum/mixing.py
@@ -71,7 +71,7 @@ def __init__(self, rg_nodes: List[RegionNode], num_output_units: int, max_compon

         # TODO: test best perf?
         # param_shape = (len(self.nodes), self.max_components) for better perf
-        self.params = nn.Parameter(torch.empty(num_output_units, len(rg_nodes), max_components))
+        self.params = nn.Parameter(torch.empty(max_components, len(rg_nodes), num_output_units))
         # TODO: what's the use of params_mask?
         self.register_buffer("params_mask", torch.ones_like(self.params))
         self.param_clamp_value["min"] = torch.finfo(self.params.dtype).smallest_normal
@@ -90,7 +90,7 @@ def apply_params_mask(self) -> None:
         self.params /= self.params.sum(dim=2, keepdim=True)  # type: ignore[misc]

     def _forward_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("bonc,onc->bon", x, self.params)
+        return torch.einsum("cfk,cfkb->fkb", self.params, x)

     # TODO: make forward return something
     # pylint: disable-next=arguments-differ
@@ -106,6 +106,6 @@ def forward(self, log_input: Tensor) -> Tensor:  # type: ignore[override]

         # TODO: use a mul or gather? or do we need this?
         assert (self.params * self.params_mask == self.params).all()
-        return log_func_exp(log_input, func=self._forward_linear, dim=3, keepdim=False)
+        return log_func_exp(log_input, func=self._forward_linear, dim=0, keepdim=False)

     # TODO: see commit 084a3685c6c39519e42c24a65d7eb0c1b0a1cab1 for backtrack
diff --git a/cirkit/layers/exp_family/exp_family.py b/cirkit/layers/exp_family/exp_family.py
index cafb0acb..7b38ba5f 100644
--- a/cirkit/layers/exp_family/exp_family.py
+++ b/cirkit/layers/exp_family/exp_family.py
@@ -170,7 +170,7 @@ def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]

         # I try with an einsum operation (x, o, d, s), (b, x, s) -> b, x, o, d.
         # That should have the same result
-        crucial_quantity_einsum = torch.einsum("xods,bxs->bxod", theta, self.suff_stats)
+        crucial_quantity_einsum = torch.einsum("xkds,bxs->bxkd", theta, self.suff_stats)

        # assert not torch.isnan(crucial_quantity_einsum).any()

@@ -205,7 +205,9 @@ def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
             self.marginalization_mask = None
             output = self.ll

-        return torch.einsum("bxir,xro->bio", output, self.scope_tensor)
+        # why bxkr instead of bxkd?
+        # TODO: the axes order for input layer? better remove this contiguous
+        return torch.einsum("bxkr,xrf->fkb", output, self.scope_tensor).contiguous()

     # TODO: how to fix?
     # pylint: disable-next=arguments-differ
diff --git a/cirkit/models/einet.py b/cirkit/models/einet.py
index d81c2c1a..4c44c186 100644
--- a/cirkit/models/einet.py
+++ b/cirkit/models/einet.py
@@ -190,7 +190,7 @@ def __init__(  # type: ignore[misc]
                 region.einet_address.layer = mixing_layer
                 region.einet_address.idx = reg_idx
             mixing_layer.apply_params_mask()
-            self.bookkeeping.append((mixing_layer.params_mask, torch.tensor(padded_idx)))
+            self.bookkeeping.append((mixing_layer.params_mask, torch.tensor(padded_idx).T))

         # TODO: can we annotate a list here?
         # TODO: actually we should not mix all the input/mix/ein different types in one list
@@ -284,10 +284,10 @@ def forward(self, x: Tensor) -> Tensor:
                 assert isinstance(left_addr, tuple) and isinstance(right_addr, tuple)
                 # TODO: we should use dim=2, check all code
                 # TODO: duplicate code
-                log_left_prob = torch.cat([outputs[layer] for layer in left_addr[0]], dim=2)
-                log_left_prob = log_left_prob[:, :, left_addr[1]]
-                log_right_prob = torch.cat([outputs[layer] for layer in right_addr[0]], dim=2)
-                log_right_prob = log_right_prob[:, :, right_addr[1]]
+                log_left_prob = torch.cat([outputs[layer] for layer in left_addr[0]], dim=0)
+                log_left_prob = log_left_prob[left_addr[1]]
+                log_right_prob = torch.cat([outputs[layer] for layer in right_addr[0]], dim=0)
+                log_right_prob = log_right_prob[right_addr[1]]
                 out = inner_layer(log_left_prob, log_right_prob)
             elif isinstance(inner_layer, EinsumMixingLayer):
                 _, padded_idx = self.bookkeeping[idx]
@@ -298,13 +298,13 @@ def forward(self, x: Tensor) -> Tensor:
                 # outputs[self.inner_layers[idx - 1]] = F.pad(
                 #     outputs[self.inner_layers[idx - 1]], [0, 1], "constant", float("-inf")
                 # )
-                log_input_prob = outputs[self.inner_layers[idx - 1]][:, :, padded_idx]
+                log_input_prob = outputs[self.inner_layers[idx - 1]][padded_idx]
                 out = inner_layer(log_input_prob)
             else:
                 assert False
             outputs[inner_layer] = out

-        return outputs[self.inner_layers[-1]][:, :, 0]
+        return outputs[self.inner_layers[-1]][0].T  # return shape (B, K)

     # TODO: and what's the meaning of this?
     # def backtrack(self, num_samples=1, class_idx=0, x=None, mode='sampling', **kwargs):
diff --git a/tests/models/test_einet.py b/tests/models/test_einet.py
index 001c11af..4b13fe83 100644
--- a/tests/models/test_einet.py
+++ b/tests/models/test_einet.py
@@ -86,13 +86,13 @@ def _get_einet() -> TensorizedPC:
 def _get_param_shapes() -> Dict[str, Tuple[int, ...]]:
     return {
         "input_layer.params": (4, 1, 1, 2),
-        "inner_layers.0.params_left": (1, 1, 4),
-        "inner_layers.0.params_right": (1, 1, 4),
-        "inner_layers.0.params_out": (1, 1, 4),
-        "inner_layers.1.params_left": (1, 1, 2),
-        "inner_layers.1.params_right": (1, 1, 2),
-        "inner_layers.1.params_out": (1, 1, 2),
-        "inner_layers.2.params": (1, 1, 2),
+        "inner_layers.0.params_left": (4, 1, 1),
+        "inner_layers.0.params_right": (4, 1, 1),
+        "inner_layers.0.params_out": (4, 1, 1),
+        "inner_layers.1.params_left": (2, 1, 1),
+        "inner_layers.1.params_right": (2, 1, 1),
+        "inner_layers.1.params_out": (2, 1, 1),
+        "inner_layers.2.params": (2, 1, 1),
     }
@@ -109,15 +109,15 @@ def _set_params(einet: TensorizedPC) -> None:
                 [math.log(3), 0],  # type: ignore[misc]  # 3/4, 1/4
             ]
         ).reshape(4, 1, 1, 2),
-        "inner_layers.0.params_left": torch.ones(1, 1, 4) / 2,
-        "inner_layers.0.params_right": torch.ones(1, 1, 4) * 2,
-        "inner_layers.0.params_out": torch.ones(1, 1, 4),
-        "inner_layers.1.params_left": torch.ones(1, 1, 2) * 2,
-        "inner_layers.1.params_right": torch.ones(1, 1, 2) / 2,
-        "inner_layers.1.params_out": torch.ones(1, 1, 2),
+        "inner_layers.0.params_left": torch.ones(4, 1, 1) / 2,
+        "inner_layers.0.params_right": torch.ones(4, 1, 1) * 2,
+        "inner_layers.0.params_out": torch.ones(4, 1, 1),
+        "inner_layers.1.params_left": torch.ones(2, 1, 1) * 2,
+        "inner_layers.1.params_right": torch.ones(2, 1, 1) / 2,
+        "inner_layers.1.params_out": torch.ones(2, 1, 1),
         "inner_layers.2.params": torch.tensor(
             [1 / 3, 2 / 3],  # type: ignore[misc]
-        ).reshape(1, 1, 2),
+        ).reshape(2, 1, 1),
     }
 )
 einet.load_state_dict(state_dict)  # type: ignore[misc]
@@ -159,9 +159,9 @@ def test_einet_partition_func() -> None:
 @pytest.mark.parametrize(  # type: ignore[misc]
     "rg_cls,kwargs,log_answer",
     [
-        (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.188161849975586),
-        (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 51.31766128540039),
-        (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.198360443115234),
+        (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.935434341430664),
+        (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 44.412864685058594),
+        (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.313674926757812),
         (PoonDomingos, {"shape": [3, 3], "delta": 2}, None),
         (QuadTree, {"width": 3, "height": 3, "struct_decomp": False}, None),
         (RandomBinaryTree, {"num_vars": 9, "depth": 3, "num_repetitions": 2}, None),
@@ -223,4 +223,6 @@ def test_einet_partition_function(
     assert torch.isclose(einet.partition_function(), sum_out, rtol=1e-6, atol=0)

     if log_answer is not None:
-        assert torch.isclose(sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0)
+        assert torch.isclose(
+            sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0
+        ), f"{sum_out.item()}"
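The recurring change in this patch is an axis reorder: parameters and activations move from a batch-first layout such as (batch, units, folds) to a folds-first layout (folds, units, batch), and every einsum equation, indexing expression, and expected parameter shape is rewritten to match. Below is a small standalone sketch, not part of the patch, that checks the equivalence for the CP layer's left linear step; the names F, K, R, B (folds, input units, rank, batch) are illustrative placeholders, not identifiers from the code base.

import torch

# Hypothetical sizes: F folds, K input units, R rank, B batch.
F, K, R, B = 3, 4, 5, 2

params = torch.randn(K, R, F)   # old parameter layout: (units, rank, folds)
x_old = torch.randn(B, K, F)    # old activation layout: (batch, units, folds)

# Old contraction: sum over the unit axis, result laid out as (batch, rank, folds).
out_old = torch.einsum("bip,irp->brp", x_old, params)

# New layouts: folds first, batch last.
params_new = params.permute(2, 0, 1)  # (folds, units, rank)
x_new = x_old.permute(2, 1, 0)        # (folds, units, batch)

# New contraction: same sum over units, result laid out as (folds, rank, batch).
out_new = torch.einsum("fkr,fkb->frb", params_new, x_new)

# The two results agree up to the axis permutation.
assert torch.allclose(out_old.permute(2, 1, 0), out_new, atol=1e-6)

The same permutation explains the remaining edits: the mixing layer's parameters become (components, folds, units), the exponential-family layer now emits (folds, units, batch), and the model's forward therefore concatenates and indexes along dim 0 instead of dim 2, transposing only at the very end to return a (batch, units) tensor.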