Refactor bookkeeping #105

Merged · 9 commits · Jul 11, 2023
1 change: 0 additions & 1 deletion cirkit/layers/exp_family/exp_family.py
@@ -60,7 +60,6 @@ def __init__(
         self.num_dims = num_dims
         self.num_units = num_units
         self.num_stats = num_stats
-        self.fold_count = len(rg_nodes)
 
         replica_indices = set(n.get_replica_idx() for n in self.rg_nodes)
         num_replica = len(replica_indices)
1 change: 0 additions & 1 deletion cirkit/layers/layer.py
@@ -22,7 +22,6 @@ def __init__(self) -> None:
"""Init class."""
super().__init__() # TODO: do we need multi-inherit init?
self.param_clamp_value: _ClampValue = {}
self.fold_count = 0

@abstractmethod
def reset_parameters(self) -> None:
27 changes: 8 additions & 19 deletions cirkit/layers/mixing.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List
 
 import torch
 from torch import Tensor, nn
@@ -51,46 +51,35 @@ class MixingLayer(Layer):
"""

# TODO: num_output_units is num_input_units
def __init__(
self,
rg_nodes: List[RegionNode],
num_output_units: int,
max_components: int,
mask: Optional[Tensor] = None,
):
def __init__(self, rg_nodes: List[RegionNode], num_output_units: int, max_components: int):
"""Init class.

Args:
rg_nodes (List[PartitionNode]): The region graph's partition node of the layer.
num_output_units (int): The number of output units.
max_components (int): Max number of mixing components.
mask (Optional[Tensor]): The mask to apply to the parameters.
"""
super().__init__()
self.fold_count = len(rg_nodes)
self.rg_nodes = rg_nodes

# TODO: what need to be saved to self?
self.num_output_units = num_output_units

# TODO: test best perf?
# param_shape = (len(self.nodes), self.max_components) for better perf
self.params = nn.Parameter(torch.empty(max_components, len(rg_nodes), num_output_units))
self.mask = mask
self.params = nn.Parameter(torch.empty(len(rg_nodes), max_components, num_output_units))

self.param_clamp_value["min"] = torch.finfo(self.params.dtype).smallest_normal
self.reset_parameters()

def reset_parameters(self) -> None:
"""Reset parameters to default initialization: U(0.01, 0.99) with normalization."""
nn.init.uniform_(self.params, 0.01, 0.99)
with torch.no_grad():
if self.mask is not None:
# TODO: assume mypy bug with __mul__ and __div_
self.params *= self.mask # type: ignore[misc]
self.params /= self.params.sum(dim=2, keepdim=True) # type: ignore[misc]
nn.init.uniform_(self.params, 0.01, 0.99)
self.params /= self.params.sum(dim=1, keepdim=True) # type: ignore[misc]

def _forward_linear(self, x: Tensor) -> Tensor:
return torch.einsum("cfk,cfkb->fkb", self.params, x)
return torch.einsum("fck,fckb->fkb", self.params, x)

@lkct (Member) commented on Jul 9, 2023:
I was wrong about this

     # TODO: make forward return something
     # pylint: disable-next=arguments-differ
@@ -103,6 +92,6 @@ def forward(self, log_input: Tensor) -> Tensor:  # type: ignore[override]
         Returns:
             Tensor: the output.
         """
-        return log_func_exp(log_input, func=self._forward_linear, dim=0, keepdim=False)
+        return log_func_exp(log_input, func=self._forward_linear, dim=1, keepdim=False)
 
     # TODO: see commit 084a3685c6c39519e42c24a65d7eb0c1b0a1cab1 for backtrack
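
For reference, a minimal standalone sketch (not part of the PR) of the fold-first layout the diff above switches to: params is allocated as (folds, components, units), so both the normalization in reset_parameters and the "fck,fckb->fkb" einsum reduce over the component axis, which is now dim=1. The shape values below are invented examples.

# Standalone sketch of the fold-first layout (shapes are made-up examples).
import torch
from torch import nn

num_folds, max_components, num_units, batch = 4, 3, 8, 2

# params: one set of mixing weights per fold, normalized over the component axis.
params = nn.Parameter(torch.empty(num_folds, max_components, num_units))
with torch.no_grad():
    nn.init.uniform_(params, 0.01, 0.99)
    params /= params.sum(dim=1, keepdim=True)  # dim=1 is the component axis

# Input of a mixing step: (fold, component, unit, batch), mixed over components.
x = torch.rand(num_folds, max_components, num_units, batch)
out = torch.einsum("fck,fckb->fkb", params, x)
print(out.shape)  # torch.Size([4, 8, 2])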
8 changes: 3 additions & 5 deletions cirkit/layers/sum_product/base.py
@@ -32,8 +32,7 @@ def __init__(  # type: ignore[misc]
             kwargs (Any): Passed to subclasses.
         """
         super().__init__()
-        self.fold_count = len(rg_nodes)
-
+        self.rg_nodes = rg_nodes
         self.num_input_units = num_input_units
         self.num_output_units = num_output_units
 
@@ -46,7 +45,7 @@ def reset_parameters(self) -> None:
     # TODO: what about abstract?
     @abstractmethod
     # pylint: disable-next=arguments-differ
-    def forward(self, log_left: Tensor, log_right: Tensor) -> Tensor:  # type: ignore[override]
+    def forward(self, inputs: Tensor) -> Tensor:  # type: ignore[override]
         """Compute the main einsum operation of the layer.
 
         Do SumProductLayer forward pass.
@@ -61,7 +60,6 @@ def forward(self, log_left: Tensor, log_right: Tensor) -> Tensor:  # type: ignore[override]
         4a) go to exp space do the einsum and back to log || 4b) do the einsum operation [OPT]
         5a) do nothing || 5b) back to log space
 
-        :param log_left: value in log space for left child.
-        :param log_right: value in log space for right child.
+        :param inputs: the input tensor.
         :return: result of the left operations, in log-space.
         """
7 changes: 4 additions & 3 deletions cirkit/layers/sum_product/cp.py
@@ -67,13 +67,14 @@ def _forward_linear(self, left: Tensor, right: Tensor) -> Tensor:
         right_hidden = self._forward_right_linear(right)
         return self._forward_out_linear(left_hidden * right_hidden)
 
-    def forward(self, log_left: Tensor, log_right: Tensor) -> Tensor:  # type: ignore[override]
+    def forward(self, inputs: Tensor) -> Tensor:  # type: ignore[override]
         """Compute the main Einsum operation of the layer.
 
-        :param log_left: value in log space for left child.
-        :param log_right: value in log space for right child.
+        :param inputs: value in log space for left child.
         :return: result of the left operations, in log-space.
         """
+        log_left, log_right = inputs[:, 0], inputs[:, 1]

A maintainer (Member) commented on the line above:
These two are non-contiguous inputs for the code that follows. However, there seems to be no major performance impact for CP; still, we don't know whether this will harm other layers, so if there is no performance difference we should use contiguous tensors.

+
         # TODO: do we split into two impls?
         if self.prod_exp:
             return log_func_exp(log_left, log_right, func=self._forward_linear, dim=1, keepdim=True)
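
A small sketch of the option raised in the review comment above (not part of this PR): slicing the stacked inputs along dim 1 yields non-contiguous views, and calling .contiguous() on the slices (or unbinding and then copying) materializes packed tensors if profiling shows the strided layout hurts the downstream einsum. The shape is an invented example.

# Sketch only: checking and fixing contiguity of the sliced children.
import torch

inputs = torch.randn(4, 2, 8, 3)                  # (fold, 2 children, unit, batch)
log_left, log_right = inputs[:, 0], inputs[:, 1]
print(log_left.is_contiguous())                   # False: a strided view into inputs

# Option 1: materialize contiguous copies of each slice.
log_left, log_right = log_left.contiguous(), log_right.contiguous()
print(log_left.is_contiguous())                   # True

# Option 2: unbind along the child dimension (still views); add .contiguous() only
# if profiling shows the strided layout actually slows the downstream einsum.
left_view, right_view = torch.unbind(inputs, dim=1)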