Merge pull request #102 from april-tools/revert_90
Revert #90
lkct authored Jul 9, 2023
2 parents 5d842b3 + 2099016 commit 0c8a2bd
Showing 18 changed files with 142 additions and 123 deletions.
2 changes: 1 addition & 1 deletion benchmark/cirkit/run_cirkit.py
@@ -9,8 +9,8 @@
from torch import Tensor, optim
from torch.utils.data import DataLoader, TensorDataset

from cirkit.layers.einsum.cp import CPLayer # TODO: rework interfaces for import
from cirkit.layers.exp_family import CategoricalLayer
from cirkit.layers.sum_product.cp import CPLayer # TODO: rework interfaces for import
from cirkit.models import TensorizedPC
from cirkit.region_graph import RegionGraph
from cirkit.utils import RandomCtx, set_determinism
2 changes: 0 additions & 2 deletions cirkit/layers/__init__.py
@@ -1,2 +0,0 @@
from .layer import Layer as Layer
from .mixing import MixingLayer as MixingLayer
1 change: 1 addition & 0 deletions cirkit/layers/einsum/__init__.py
@@ -0,0 +1 @@
from .einsum import EinsumLayer as EinsumLayer
File renamed without changes.
@@ -3,14 +3,15 @@
import torch
from torch import Tensor, nn

from cirkit.layers.sum_product import SumProductLayer
from cirkit.region_graph import PartitionNode
from cirkit.utils import log_func_exp

from .einsum import EinsumLayer

# TODO: rework docstrings


class CPLayer(SumProductLayer):
class CPLayer(EinsumLayer):
"""Candecomp Parafac (decomposition) layer."""

# TODO: better way to call init by base class?
@@ -3,15 +3,16 @@

from torch import Tensor, nn

from cirkit.layers.layer import Layer
from cirkit.region_graph import PartitionNode

from ..layer import Layer

# TODO: relative import or absolute
# TODO: rework docstrings


class SumProductLayer(Layer):
"""Base for all "fused" sum-product layers."""
class EinsumLayer(Layer):
"""Base for all einsums."""

# TODO: kwargs should be public interface instead of `_`. How to supress this warning?
# all subclasses should accept all args as kwargs except for layer and k
@@ -47,13 +48,13 @@ def reset_parameters(self) -> None:
@abstractmethod
# pylint: disable-next=arguments-differ
def forward(self, log_left: Tensor, log_right: Tensor) -> Tensor: # type: ignore[override]
"""Compute the main einsum operation of the layer.
"""Compute the main Einsum operation of the layer.
Do SumProductLayer forward pass.
Do EinsumLayer forward pass.
We assume that all parameters are in the correct range (no checks done).
Skeleton for each SumProductLayer (options Xa and Xb are mutual exclusive \
Skeleton for each EinsumLayer (options Xa and Xb are mutual exclusive \
and follows an a-path o b-path)
1) Go To exp-space (with maximum subtraction) -> NON SPECIFIC
2a) Do the einsum operation and go to the log space || 2b) Do the einsum operation
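Note: the forward skeleton described in the docstring above is the usual log-einsum-exp pattern (subtract the per-fold maximum, exponentiate, contract, return to log-space). A minimal sketch of that pattern, assuming made-up tensor shapes and an illustrative einsum equation rather than the layer's real ones:

    import torch
    from torch import Tensor

    def log_einsum_exp(log_left: Tensor, log_right: Tensor, params: Tensor) -> Tensor:
        """Illustrative only: contract two log-space inputs with max subtraction.

        Assumed shapes: log_left/log_right are (batch, units, folds),
        params is (units, units, units, folds).
        """
        # 1) go to exp-space, subtracting the per-fold maximum for numerical stability
        left_max = log_left.max(dim=1, keepdim=True).values
        right_max = log_right.max(dim=1, keepdim=True).values
        left = (log_left - left_max).exp()
        right = (log_right - right_max).exp()
        # 2a) do the einsum operation and go back to log-space,
        #     adding the subtracted maxima again
        prob = torch.einsum("bip,bjp,ijop->bop", left, right, params)
        return prob.log() + left_max + right_max

    out = log_einsum_exp(torch.randn(2, 4, 3), torch.randn(2, 4, 3), torch.rand(4, 4, 4, 3))
    print(out.shape)  # torch.Size([2, 4, 3])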
38 changes: 20 additions & 18 deletions cirkit/layers/mixing.py → cirkit/layers/einsum/mixing.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List

import torch
from torch import Tensor, nn
@@ -10,7 +10,7 @@
# TODO: rework docstrings


class MixingLayer(Layer):
class EinsumMixingLayer(Layer):
# TODO: how we fold line here?
r"""Implement the Mixing Layer, in order to handle sum nodes with multiple children.
@@ -41,30 +41,27 @@ class MixingLayer(Layer):
The input nodes N have already been computed. The product nodes P and the \
first sum layer are computed using an
SumProductLayer, yielding a log-density tensor of shape
EinsumLayer, yielding a log-density tensor of shape
(batch_size, vector_length, num_nodes).
In this example num_nodes is 5, since the are 5 product nodes (or 5 singleton \
sum nodes). The MixingLayer
sum nodes). The EinsumMixingLayer
then simply mixes sums from the first layer, to yield 2 sums. This is just an \
over-parametrization of the original
excerpt.
"""

# TODO: might be good to doc params and buffers here
# to be registered as buffer
params_mask: Tensor

# TODO: num_output_units is num_input_units
def __init__(
self,
rg_nodes: List[RegionNode],
num_output_units: int,
max_components: int,
mask: Optional[Tensor] = None,
):
def __init__(self, rg_nodes: List[RegionNode], num_output_units: int, max_components: int):
"""Init class.
Args:
rg_nodes (List[PartitionNode]): The region graph's partition node of the layer.
num_output_units (int): The number of output units.
max_components (int): Max number of mixing components.
mask (Optional[Tensor]): The mask to apply to the parameters.
"""
super().__init__()
self.fold_count = len(rg_nodes)
@@ -75,22 +72,24 @@ def __init__(
# TODO: test best perf?
# param_shape = (len(self.nodes), self.max_components) for better perf
self.params = nn.Parameter(torch.empty(num_output_units, len(rg_nodes), max_components))
self.mask = mask

# TODO: what's the use of params_mask?
self.register_buffer("params_mask", torch.ones_like(self.params))
self.param_clamp_value["min"] = torch.finfo(self.params.dtype).smallest_normal
self.reset_parameters()

def reset_parameters(self) -> None:
"""Reset parameters to default initialization: U(0.01, 0.99) with normalization."""
nn.init.uniform_(self.params, 0.01, 0.99)

def apply_params_mask(self) -> None:
"""Apply the parameters mask."""
# TODO: What is this? Is it needed?
with torch.no_grad():
if self.mask is not None:
self.params *= self.mask # type: ignore[misc]
# TODO: assume mypy bug with __mul__ and __div__
self.params *= self.params_mask # type: ignore[misc]
self.params /= self.params.sum(dim=2, keepdim=True) # type: ignore[misc]

def _forward_linear(self, x: Tensor) -> Tensor:
if self.mask is not None:
torch.einsum("bonc,onc->bon", x, self.params * self.mask)
return torch.einsum("bonc,onc->bon", x, self.params)

# TODO: make forward return something
@@ -104,6 +103,9 @@ def forward(self, log_input: Tensor) -> Tensor: # type: ignore[override]
Returns:
Tensor: the output.
"""
# TODO: use a mul or gather? or do we need this?
assert (self.params * self.params_mask == self.params).all()

return log_func_exp(log_input, func=self._forward_linear, dim=3, keepdim=False)

# TODO: see commit 084a3685c6c39519e42c24a65d7eb0c1b0a1cab1 for backtrack
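For reference, the core of the restored EinsumMixingLayer is the masked mixture computed in _forward_linear above. A self-contained sketch of that computation in linear space (shapes and the masked region are invented; the real layer additionally wraps this in log_func_exp for log-space stability):

    import torch

    batch, units, regions, comps = 2, 3, 4, 5
    x = torch.rand(batch, units, regions, comps)   # component densities per region
    params = torch.rand(units, regions, comps)     # unnormalized mixture weights

    # Regions with fewer than `comps` children get their padded slots masked to zero,
    # then the weights are renormalized to sum to one over the component axis.
    params_mask = torch.ones_like(params)
    params_mask[:, 0, 3:] = 0.0                    # pretend region 0 has only 3 components
    params = params * params_mask
    params = params / params.sum(dim=2, keepdim=True)

    # Mix the components of every region, contracting over the component axis.
    mixed = torch.einsum("bonc,onc->bon", x, params)
    print(mixed.shape)  # torch.Size([2, 3, 4])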
15 changes: 11 additions & 4 deletions cirkit/layers/exp_family/exp_family.py
@@ -24,7 +24,10 @@ class ExpFamilyLayer(Layer): # pylint: disable=too-many-instance-attributes
num_replica is picked large enough such that "we compute enough \
leaf densities". At the moment we rely that
the PC structure (see Class Graph) provides the necessary information \
to determine num_replica.
to determine num_replica. In
particular, we require that each leaf of the graph has the field \
einet_address.replica_idx defined;
num_replica is simply the max over all einet_address.replica_idx.
In the future, it would convenient to have an automatic allocation \
of leaves to replica, without requiring
the user to specify this.
@@ -62,7 +65,7 @@ def __init__(
self.num_stats = num_stats
self.fold_count = len(rg_nodes)

replica_indices = set(n.get_replica_idx() for n in self.rg_nodes)
replica_indices = set(n.einet_address.replica_idx for n in self.rg_nodes)
num_replica = len(replica_indices)
assert replica_indices == set(
range(num_replica)
@@ -73,7 +76,11 @@
# I have experimented a bit with this, but it is not always faster.
self.register_buffer("scope_tensor", torch.zeros(num_var, num_replica, len(self.rg_nodes)))
for i, node in enumerate(self.rg_nodes):
self.scope_tensor[list(node.scope), node.get_replica_idx(), i] = 1 # type: ignore[misc]
self.scope_tensor[
list(node.scope), node.einet_address.replica_idx, i # type: ignore[misc]
] = 1
node.einet_address.layer = self
node.einet_address.idx = i

self.params_shape = (num_var, num_units, num_replica, num_stats)
self.params = nn.Parameter(torch.empty(self.params_shape))
@@ -241,7 +248,7 @@ def backtrack( # type: ignore[misc]
assert len(dist_idx[n]) == len(node_idx[n]), "Invalid input."
for c, k in enumerate(node_idx[n]):
scope = list(self.rg_nodes[k].scope)
rep: int = self.rg_nodes[k].get_replica_idx()
rep = self.rg_nodes[k].einet_address.replica_idx
cur_value[scope, :] = (
ef_values[n, scope, :, dist_idx[n][c], rep]
if mode == "sample"
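The replica bookkeeping restored above can be illustrated with a toy example (the leaf scopes and replica indices below are made up; the real code reads them from node.scope and node.einet_address.replica_idx):

    import torch

    leaves = [({0, 1}, 0), ({2, 3}, 0), ({0, 1}, 1), ({2, 3}, 1)]   # (scope, replica_idx)
    num_var = 4

    replica_indices = {rep for _, rep in leaves}
    num_replica = len(replica_indices)
    assert replica_indices == set(range(num_replica)), "replica indices must be 0..R-1"

    # scope_tensor[v, r, i] == 1 iff variable v is in the scope of leaf i under replica r
    scope_tensor = torch.zeros(num_var, num_replica, len(leaves))
    for i, (scope, rep) in enumerate(leaves):
        scope_tensor[list(scope), rep, i] = 1
    print(scope_tensor.sum(dim=(0, 1)))  # each leaf covers 2 variables: tensor([2., 2., 2., 2.])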
2 changes: 0 additions & 2 deletions cirkit/layers/sum_product/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion cirkit/models/__init__.py
@@ -1 +1 @@
from .tensorized_circuit import TensorizedPC as TensorizedPC
from .einet import TensorizedPC as TensorizedPC
84 changes: 38 additions & 46 deletions cirkit/models/tensorized_circuit.py → cirkit/models/einet.py
@@ -4,10 +4,10 @@
import torch
from torch import Tensor, nn

from cirkit.layers.einsum import EinsumLayer
from cirkit.layers.einsum.mixing import EinsumMixingLayer
from cirkit.layers.exp_family import ExpFamilyLayer
from cirkit.layers.layer import Layer
from cirkit.layers.mixing import MixingLayer
from cirkit.layers.sum_product import SumProductLayer
from cirkit.region_graph import RegionGraph, RegionNode

# TODO: check all type casts. There should not be any without a good reason
@@ -29,7 +29,7 @@ def __init__( # type: ignore[misc]
self,
graph: RegionGraph,
num_vars: int,
layer_cls: Type[SumProductLayer],
layer_cls: Type[EinsumLayer],
efamily_cls: Type[ExpFamilyLayer],
layer_kwargs: Dict[str, Any],
efamily_kwargs: Dict[str, Any],
@@ -43,7 +43,7 @@ def __init__( # type: ignore[misc]
Args:
graph (RegionGraph): The region graph.
num_vars (int): The number of variables.
layer_cls (Type[SumProductLayer]): The inner layer class.
layer_cls (Type[EinsumLayer]): The inner layer class.
efamily_cls (Type[ExpFamilyLayer]): The exponential family class.
layer_kwargs (Dict[str, Any]): The parameters for the inner layer class.
efamily_kwargs (Dict[str, Any]): The parameters for the exponential family class.
@@ -86,17 +86,12 @@ def __init__( # type: ignore[misc]
**efamily_kwargs, # type: ignore[misc]
)

# A dictionary mapping each region node ID to
# (i) its index in the corresponding fold, and
# (ii) the layer that computes such fold.
region_id_fold: Dict[int, Tuple[int, Layer]] = {}
for i, region in enumerate(self.graph_layers[0][1]):
region_id_fold[region.get_id()] = (i, self.input_layer)

# Book-keeping: None for input, Tensor for mixing, Tuple for einsum
self.bookkeeping: List[
Union[
Tuple[Tuple[List[Layer], Tensor], Tuple[List[Layer], Tensor]], Tuple[Layer, Tensor]
Tuple[Tuple[List[Layer], Tensor], Tuple[List[Layer], Tensor]],
Tuple[Tensor, Tensor],
Tuple[None, None],
]
] = []

@@ -127,37 +122,37 @@ def __init__( # type: ignore[misc]
# TODO: again, why do we need sorting
# collect all layers which contain left/right children
# TODO: duplicate code
left_region_ids = list(r.left.get_id() for r in two_inputs)
right_region_ids = list(r.right.get_id() for r in two_inputs)
left_layers = list(region_id_fold[i][1] for i in left_region_ids)
right_layers = list(region_id_fold[i][1] for i in right_region_ids)
left_starts = torch.tensor([0] + [layer.fold_count for layer in left_layers]).cumsum(
dim=0
)
right_starts = torch.tensor([0] + [layer.fold_count for layer in right_layers]).cumsum(
left_layer = list(set(inputs.left.einet_address.layer for inputs in two_inputs))
left_starts = torch.tensor([0] + [layer.fold_count for layer in left_layer]).cumsum(
dim=0
)
left_indices = torch.tensor(
left_idx = torch.tensor(
[ # type: ignore[misc]
region_id_fold[r.left.get_id()][0] + left_starts[i]
for i, r in enumerate(two_inputs)
inputs.left.einet_address.idx
+ left_starts[left_layer.index(inputs.left.einet_address.layer)]
for inputs in two_inputs
]
)
right_indices = torch.tensor(
right_layer = list(set(inputs.right.einet_address.layer for inputs in two_inputs))
right_starts = torch.tensor([0] + [layer.fold_count for layer in right_layer]).cumsum(
dim=0
)
right_idx = torch.tensor(
[ # type: ignore[misc]
region_id_fold[r.right.get_id()][0] + right_starts[i]
for i, r in enumerate(two_inputs)
inputs.right.einet_address.idx
+ right_starts[right_layer.index(inputs.right.einet_address.layer)]
for inputs in two_inputs
]
)
self.bookkeeping.append(((left_layers, left_indices), (right_layers, right_indices)))
self.bookkeeping.append(((left_layer, left_idx), (right_layer, right_idx)))

# when the SumProductLayer is followed by a MixingLayer, we produce a
# when the EinsumLayer is followed by a EinsumMixingLayer, we produce a
# dummy "node" which outputs 0 (-inf in log-domain) for zero-padding.
dummy_idx: Optional[int] = None

# the dictionary mixing_component_idx stores which nodes (axis 2 of the
# log-density tensor) need to get mixed
# in the following MixingLayer
# in the following EinsumMixingLayer
mixing_component_idx: Dict[RegionNode, List[int]] = defaultdict(list)

for part_idx, partition in enumerate(partition_layer):
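The left_starts/left_idx bookkeeping above turns per-layer fold indices into positions along the concatenated fold axis. A small sketch of that offset arithmetic, with invented layers and fold counts:

    import torch

    fold_counts = [3, 2]     # two hypothetical input layers with 3 and 2 folds
    starts = torch.tensor([0] + fold_counts).cumsum(dim=0)   # tensor([0, 3, 5])

    # (layer_number, index_within_layer) of each partition's left input
    left_inputs = [(0, 1), (1, 0), (0, 2)]
    left_idx = torch.tensor([idx + int(starts[layer]) for layer, idx in left_inputs])
    print(left_idx)  # tensor([1, 3, 2]): global indices into the concatenated outputs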
@@ -166,39 +161,36 @@ def __init__( # type: ignore[misc]
out_region = partition.outputs[0]

if len(out_region.inputs) == 1:
region_id_fold[out_region.get_id()] = (part_idx, inner_layer)
else: # case followed by MixingLayer
out_region.einet_address.layer = inner_layer
out_region.einet_address.idx = part_idx
else: # case followed by EinsumMixingLayer
mixing_component_idx[out_region].append(part_idx)
dummy_idx = len(partition_layer)

# The Mixing layer is only for regions which have multiple partitions as children.
if multi_sums := [region for region in region_layer if len(region.inputs) > 1]:
assert dummy_idx is not None
max_components = max(len(region.inputs) for region in multi_sums)
mixing_layer = EinsumMixingLayer(multi_sums, num_outputs, max_components)
inner_layers.append(mixing_layer)

# The following code does some bookkeeping.
# padded_idx indexes into the log-density tensor of the previous
# SumProductLayer, padded with a dummy input which
# outputs constantly 0 (-inf in the log-domain), see class SumProductLayer.
# EinsumLayer, padded with a dummy input which
# outputs constantly 0 (-inf in the log-domain), see class EinsumLayer.
padded_idx: List[List[int]] = []
params_mask: Optional[Tensor] = None
for reg_idx, region in enumerate(multi_sums):
num_components = len(mixing_component_idx[region])
this_idx = mixing_component_idx[region] + [dummy_idx] * (
max_components - num_components
)
padded_idx.append(this_idx)
if max_components > num_components:
if params_mask is None:
params_mask = torch.ones(num_outputs, len(multi_sums), max_components)
params_mask[:, reg_idx, num_components:] = 0.0
mixing_layer = MixingLayer(
multi_sums, num_outputs, max_components, mask=params_mask
)
for reg_idx, region in enumerate(multi_sums):
region_id_fold[region.get_id()] = (reg_idx, mixing_layer)
self.bookkeeping.append((inner_layers[-1], torch.tensor(padded_idx)))
inner_layers.append(mixing_layer)
mixing_layer.params_mask[:, reg_idx, num_components:] = 0.0
region.einet_address.layer = mixing_layer
region.einet_address.idx = reg_idx
mixing_layer.apply_params_mask()
self.bookkeeping.append((mixing_layer.params_mask, torch.tensor(padded_idx)))

# TODO: can we annotate a list here?
# TODO: actually we should not mix all the input/mix/ein different types in one list
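The dummy_idx / padded_idx scheme above pads every region to max_components by pointing the missing slots at an extra fold that is constantly -inf (log of zero), while params_mask zeroes the matching mixture weights. A sketch with invented sizes:

    import torch

    batch, units, num_partitions, max_components = 2, 4, 2, 3
    log_prob = torch.randn(batch, units, num_partitions)

    dummy_idx = num_partitions               # extra "node" that always outputs -inf
    padded = torch.cat([log_prob, torch.full((batch, units, 1), float("-inf"))], dim=2)

    # region 0 mixes partitions [0, 1]; region 1 has only partition [1] and is padded
    padded_idx = torch.tensor([[0, 1, dummy_idx], [1, dummy_idx, dummy_idx]])
    gathered = padded[:, :, padded_idx]      # (batch, units, num_regions, max_components)
    print(gathered.shape)  # torch.Size([2, 4, 2, 3])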
@@ -287,7 +279,7 @@ def forward(self, x: Tensor) -> Tensor:
# TODO: use zip instead
# TODO: Generalize if statements here, they should be layer agnostic
for idx, inner_layer in enumerate(self.inner_layers):
if isinstance(inner_layer, SumProductLayer): # type: ignore[misc]
if isinstance(inner_layer, EinsumLayer): # type: ignore[misc]
left_addr, right_addr = self.bookkeeping[idx]
assert isinstance(left_addr, tuple) and isinstance(right_addr, tuple)
# TODO: we should use dim=2, check all code
Expand All @@ -297,7 +289,7 @@ def forward(self, x: Tensor) -> Tensor:
log_right_prob = torch.cat([outputs[layer] for layer in right_addr[0]], dim=2)
log_right_prob = log_right_prob[:, :, right_addr[1]]
out = inner_layer(log_left_prob, log_right_prob)
elif isinstance(inner_layer, MixingLayer):
elif isinstance(inner_layer, EinsumMixingLayer):
_, padded_idx = self.bookkeeping[idx]
assert isinstance(padded_idx, Tensor) # type: ignore[misc]
# TODO: a better way to pad?
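In forward() above, each EinsumLayer's inputs are assembled by concatenating the outputs of the layers recorded in the bookkeeping along the fold axis and then re-indexing with the precomputed indices. A sketch of that gather with invented shapes:

    import torch

    out_a = torch.randn(2, 4, 3)          # (batch, units, folds) of one earlier layer
    out_b = torch.randn(2, 4, 2)          # (batch, units, folds) of another earlier layer
    left_idx = torch.tensor([1, 3, 2])    # global fold indices computed at build time

    log_left_prob = torch.cat([out_a, out_b], dim=2)   # (2, 4, 5)
    log_left_prob = log_left_prob[:, :, left_idx]      # (2, 4, 3): one slice per partition
    print(log_left_prob.shape)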
4 changes: 4 additions & 0 deletions cirkit/region_graph/poon_domingos.py
@@ -250,4 +250,8 @@ def PoonDomingos(
if found_cut_on_level:
break

# TODO: do we need this? already defaults to 0
# for node in get_leaves(graph):
# node.einet_address.replica_idx = 0

return graph