Commit 16ace43
Merge pull request #105 from april-tools/refactor-bookkeeping
Refactor bookkeeping
loreloc authored Jul 11, 2023
2 parents 5445dbe + 1ed21c4 commit 16ace43
Showing 7 changed files with 174 additions and 189 deletions.
1 change: 0 additions & 1 deletion cirkit/layers/exp_family/exp_family.py
@@ -60,7 +60,6 @@ def __init__(
         self.num_dims = num_dims
         self.num_units = num_units
         self.num_stats = num_stats
-        self.fold_count = len(rg_nodes)

         replica_indices = set(n.get_replica_idx() for n in self.rg_nodes)
         num_replica = len(replica_indices)
1 change: 0 additions & 1 deletion cirkit/layers/layer.py
@@ -22,7 +22,6 @@ def __init__(self) -> None:
         """Init class."""
         super().__init__()  # TODO: do we need multi-inherit init?
         self.param_clamp_value: _ClampValue = {}
-        self.fold_count = 0

     @abstractmethod
     def reset_parameters(self) -> None:
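A pattern running through this commit: the stored `fold_count` attribute is deleted, and layers keep `rg_nodes` instead, so the fold count stays recoverable as `len(self.rg_nodes)`. A minimal sketch of that idea, assuming this invariant (the `LayerSketch` class and its property are illustrative, not code from the diff):

```python
from typing import List


class LayerSketch:
    """Illustrative stand-in for a layer that keeps its region-graph nodes."""

    def __init__(self, rg_nodes: List[object]) -> None:
        # Stored instead of a separate fold_count attribute.
        self.rg_nodes = rg_nodes

    @property
    def fold_count(self) -> int:
        # Recover on demand what the deleted attributes computed eagerly.
        return len(self.rg_nodes)
```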
27 changes: 8 additions & 19 deletions cirkit/layers/mixing.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List

 import torch
 from torch import Tensor, nn
@@ -51,46 +51,35 @@ class MixingLayer(Layer):
     """

     # TODO: num_output_units is num_input_units
-    def __init__(
-        self,
-        rg_nodes: List[RegionNode],
-        num_output_units: int,
-        max_components: int,
-        mask: Optional[Tensor] = None,
-    ):
+    def __init__(self, rg_nodes: List[RegionNode], num_output_units: int, max_components: int):
         """Init class.

         Args:
             rg_nodes (List[PartitionNode]): The region graph's partition node of the layer.
             num_output_units (int): The number of output units.
             max_components (int): Max number of mixing components.
-            mask (Optional[Tensor]): The mask to apply to the parameters.
         """
         super().__init__()
-        self.fold_count = len(rg_nodes)
+        self.rg_nodes = rg_nodes

         # TODO: what need to be saved to self?
         self.num_output_units = num_output_units

         # TODO: test best perf?
         # param_shape = (len(self.nodes), self.max_components) for better perf
-        self.params = nn.Parameter(torch.empty(max_components, len(rg_nodes), num_output_units))
-        self.mask = mask
+        self.params = nn.Parameter(torch.empty(len(rg_nodes), max_components, num_output_units))

         self.param_clamp_value["min"] = torch.finfo(self.params.dtype).smallest_normal
         self.reset_parameters()

     def reset_parameters(self) -> None:
         """Reset parameters to default initialization: U(0.01, 0.99) with normalization."""
-        nn.init.uniform_(self.params, 0.01, 0.99)
         with torch.no_grad():
-            if self.mask is not None:
-                # TODO: assume mypy bug with __mul__ and __div_
-                self.params *= self.mask  # type: ignore[misc]
-            self.params /= self.params.sum(dim=2, keepdim=True)  # type: ignore[misc]
+            nn.init.uniform_(self.params, 0.01, 0.99)
+            self.params /= self.params.sum(dim=1, keepdim=True)  # type: ignore[misc]

     def _forward_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("cfk,cfkb->fkb", self.params, x)
+        return torch.einsum("fck,fckb->fkb", self.params, x)

     # TODO: make forward return something
     # pylint: disable-next=arguments-differ
@@ -103,6 +92,6 @@ def forward(self, log_input: Tensor) -> Tensor:  # type: ignore[override]
         Returns:
             Tensor: the output.
         """
-        return log_func_exp(log_input, func=self._forward_linear, dim=0, keepdim=False)
+        return log_func_exp(log_input, func=self._forward_linear, dim=1, keepdim=False)

     # TODO: see commit 084a3685c6c39519e42c24a65d7eb0c1b0a1cab1 for backtrack
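The fold axis now leads everywhere in `MixingLayer`: the parameter tensor moves from `(components, folds, units)` to `(folds, components, units)`, normalization and `log_func_exp` switch to `dim=1` (the component axis), and the einsum subscripts flip from `cfk` to `fck` to match. A minimal sketch of the new linear mixing step, with illustrative sizes (F folds, C components, K units, B batch) that are not from the diff:

```python
import torch

F, C, K, B = 4, 3, 8, 2  # illustrative sizes: folds, components, units, batch

# Parameters in the new (fold, component, unit) layout, normalized over
# components as reset_parameters now does (uniform init, then dim=1 sum).
params = torch.empty(F, C, K).uniform_(0.01, 0.99)
params /= params.sum(dim=1, keepdim=True)

# Component outputs in the matching (fold, component, unit, batch) layout.
x = torch.rand(F, C, K, B)

# The refactored einsum: per fold and unit, mix the C components.
out = torch.einsum("fck,fckb->fkb", params, x)
assert out.shape == (F, K, B)
```

In `forward`, the same contraction runs in log space via `log_func_exp` with `dim=1`, matching the relocated component axis.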
8 changes: 3 additions & 5 deletions cirkit/layers/sum_product/base.py
@@ -32,8 +32,7 @@ def __init__(  # type: ignore[misc]
             kwargs (Any): Passed to subclasses.
         """
         super().__init__()
-        self.fold_count = len(rg_nodes)
-
+        self.rg_nodes = rg_nodes
         self.num_input_units = num_input_units
         self.num_output_units = num_output_units

@@ -46,7 +45,7 @@ def reset_parameters(self) -> None:
     # TODO: what about abstract?
     @abstractmethod
     # pylint: disable-next=arguments-differ
-    def forward(self, log_left: Tensor, log_right: Tensor) -> Tensor:  # type: ignore[override]
+    def forward(self, inputs: Tensor) -> Tensor:  # type: ignore[override]
         """Compute the main einsum operation of the layer.

         Do SumProductLayer forward pass.
@@ -61,7 +60,6 @@
         4a) go to exp space do the einsum and back to log || 4b) do the einsum operation [OPT]
         5a) do nothing || 5b) back to log space

-        :param log_left: value in log space for left child.
-        :param log_right: value in log space for right child.
+        :param inputs: the input tensor.
         :return: result of the left operations, in log-space.
         """
7 changes: 4 additions & 3 deletions cirkit/layers/sum_product/cp.py
@@ -67,13 +67,14 @@ def _forward_linear(self, left: Tensor, right: Tensor) -> Tensor:
         right_hidden = self._forward_right_linear(right)
         return self._forward_out_linear(left_hidden * right_hidden)

-    def forward(self, log_left: Tensor, log_right: Tensor) -> Tensor:  # type: ignore[override]
+    def forward(self, inputs: Tensor) -> Tensor:  # type: ignore[override]
         """Compute the main Einsum operation of the layer.

-        :param log_left: value in log space for left child.
-        :param log_right: value in log space for right child.
+        :param inputs: value in log space for left child.
         :return: result of the left operations, in log-space.
         """
+        log_left, log_right = inputs[:, 0], inputs[:, 1]
+
         # TODO: do we split into two impls?
         if self.prod_exp:
             return log_func_exp(log_left, log_right, func=self._forward_linear, dim=1, keepdim=True)
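`log_func_exp` is defined elsewhere in cirkit and its implementation is not part of this diff. The sketch below shows the standard trick such a helper implements: factor each argument's max along `dim` out of the log-space inputs, apply a `func` that is linear in each argument in exp space, then return to log space. It is an assumption about the helper's behavior, not its actual code:

```python
from typing import Callable

import torch
from torch import Tensor


def log_func_exp_sketch(
    *log_args: Tensor, func: Callable[..., Tensor], dim: int, keepdim: bool
) -> Tensor:
    """Stably compute log(func(exp(log_args))) for a func linear in each arg."""
    # Factor out each argument's max along `dim` so exponentials stay bounded.
    maxes = [a.max(dim=dim, keepdim=True).values for a in log_args]
    out = torch.log(func(*(torch.exp(a - m) for a, m in zip(log_args, maxes))))
    if not keepdim:
        maxes = [m.squeeze(dim) for m in maxes]
    # Add the factored-out maxes back in log space.
    return out + sum(maxes)
```

Read this way, the `prod_exp` branch above evaluates the CP einsum on `exp(log_left)` and `exp(log_right)` without leaving the representable range, then returns the result in log space.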