Skip to content

Commit

Permalink
Merge pull request #213 from yhtang/work/bug-duplicate-factors
Browse files Browse the repository at this point in the history
Factorization objects only return unique factors. Fixes #210.
  • Loading branch information
yhtang authored Feb 18, 2022
2 parents d74eb46 + bfcda5b commit 21b2f1b
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 203 deletions.
26 changes: 13 additions & 13 deletions docs/examples/matrix-approximation.ipynb

Large diffs are not rendered by default.

72 changes: 22 additions & 50 deletions docs/examples/quantum-compilation.ipynb

Large diffs are not rendered by default.

48 changes: 29 additions & 19 deletions funfact/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import funfact.loss
from funfact import Factorization
from funfact.backend import active_backend as ab
from funfact.vectorization import vectorize, view
from funfact.vectorization import view


def factorize(
tsrex, target, optimizer='Adam', loss='MSE', lr=0.1, tol=1e-6,
max_steps=10000, nvec=1, append=False, stop_by='first', returns='best',
max_steps=10000, vec_size=1, vec_axis=0, stop_by='first', returns='best',
checkpoint_freq=50, dtype=None, penalty_weight=1.0
):
'''Factorize a target tensor using the given tensor expression. The
Expand All @@ -38,8 +38,8 @@ def factorize(
lr (float): SGD learning rate.
tol (float): convergence tolerance.
max_steps (int): maximum number of SGD steps to run.
nvec (int): Number of parallel instances to compute.
append (bool): If vectorizing axis is appended or prepended.
vec_size (int): Number of parallel instances to compute.
vec_axis (0 or -1): The position of the vectorization dimension.
stop_by ('first', int >= 1, or None):
- If 'first', stop optimization as soon as one solution is
Expand Down Expand Up @@ -75,16 +75,19 @@ def factorize(
that represents all the solutions.
'''

tsrex_vec = vectorize(tsrex, nvec, append=append)

fac = ab.add_autograd(Factorization).from_tsrex(tsrex_vec, dtype=dtype)
assert vec_axis in [0, -1], "Vectorization axis must be either 0 or -1."
append = True if vec_axis == -1 else False

if dtype is None:
target = ab.tensor(target)
dtype = target.dtype
else:
target = ab.tensor(target, dtype=dtype)

fac = ab.add_autograd(Factorization).from_tsrex(
tsrex, dtype=dtype, vec_size=vec_size, vec_axis=vec_axis
)

if isinstance(optimizer, str):
try:
optimizer = getattr(funfact.optim, optimizer)
Expand Down Expand Up @@ -143,8 +146,8 @@ def loss_and_penalty(model, target, sum_vec=True):

# bookkeeping
best_factors = [np.zeros_like(ab.to_numpy(x)) for x in fac.factors]
best_loss = np.ones(nvec) * np.inf
converged = np.zeros(nvec, dtype=np.bool_)
best_loss = np.ones(vec_size) * np.inf
converged = np.zeros(vec_size, dtype=np.bool_)

for step in tqdm.trange(max_steps):
_, grad = loss_and_grad(fac, target)
Expand All @@ -169,16 +172,23 @@ def loss_and_penalty(model, target, sum_vec=True):
if np.count_nonzero(converged) >= stop_by:
break

best_fac = Factorization.from_tsrex(tsrex_vec, dtype=dtype)
best_fac.factors = [ab.tensor(x) for x in best_factors]
best_factors = [ab.tensor(x) for x in best_factors]

if returns == 'best':
return view(best_fac, tsrex, np.argmin(best_loss), append)
elif isinstance(returns, int):
return [
view(best_fac, tsrex, i, append) for i in
np.argsort(best_loss)[:returns]
]
elif returns == 'all':
return view(
best_factors,
Factorization.from_tsrex(tsrex, dtype=dtype),
np.argmin(best_loss), append
)
else:
if isinstance(returns, int):
instances = np.argsort(best_loss)[:returns]
elif returns == 'all':
instances = np.argsort(best_loss)
return [
view(best_fac, tsrex, i, append) for i in np.argsort(best_loss)
view(
best_factors,
Factorization.from_tsrex(tsrex, dtype=dtype),
i, append
) for i in instances
]
4 changes: 1 addition & 3 deletions funfact/backend/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def tree_flatten(self):

@classmethod
def tree_unflatten(cls, metadata, children):
unflatten = cls(*metadata, initialize=False)
unflatten.factors = children
return unflatten
return cls._from_jax_flatten(*metadata, children)

return register_pytree_node_class(AddAutoGrad)

Expand Down
81 changes: 58 additions & 23 deletions funfact/model/_factorization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from numbers import Integral
from funfact import active_backend as ab
from funfact.lang.interpreter import (
dfs_filter,
TypeDeducer,
Expand All @@ -11,7 +12,8 @@
ElementwiseEvaluator,
SlicingPropagator,
)
from funfact import active_backend as ab
from funfact.util.iterable import unique
from funfact.vectorization import vectorize


class Factorization:
Expand All @@ -36,26 +38,48 @@ class Factorization:
<funfact.model._factorization.Factorization object at 0x7f5838105ee0>
'''

def __init__(self, tsrex, **extra_attributes):
self._tsrex = (tsrex
| IndexnessAnalyzer()
| TypeDeducer()
| EinopCompiler())
def __init__(self, tsrex, _secret=None, **extra_attributes):
if _secret != '50A-2117':
raise RuntimeError(
'Please use one of the `from_*` methods to create a '
'factorization from a tensor expression'
)
self.tsrex = tsrex
self.__dict__.update(**extra_attributes)

@classmethod
def from_tsrex(cls, tsrex, dtype=None, initialize=True):
def from_tsrex(
cls, tsrex, dtype=None, vec_size=None, vec_axis=0, initialize=True
):
'''Construct a factorization model from a tensor expression.
Args:
tsrex (TsrEx): The tensor expression.
dtype: numerical data type, defaults to float32.
vec_size (int):
Number of parallel instances to vectorize the tensor expression
with. If None or 0, the expression is not vectorized.
vec_axis (0 or -1): The position of the vectorization dimension.
initialize (bool):
Whether or not to fill abstract tensors with actual data.
'''
if vec_size:
tsrex = vectorize(
tsrex, vec_size, append=True if vec_axis == -1 else False
)
tsrex = tsrex | IndexnessAnalyzer() | TypeDeducer() | EinopCompiler()
if initialize:
tsrex = tsrex | LeafInitializer(dtype)
return cls(tsrex)
return cls(tsrex, _secret='50A-2117')

@classmethod
def _from_jax_flatten(cls, tsrex, factors):
'''
'''
tsrex = tsrex | IndexnessAnalyzer()
fac = cls(tsrex, _secret='50A-2117')
fac.factors = factors
return fac

@property
def factors(self):
Expand All @@ -80,16 +104,18 @@ def factors(self):
'''
return self._NodeView(
'data',
list(dfs_filter(lambda n: n.name == 'tensor' and
n.decl.optimizable, self.tsrex.root))
list(unique(dfs_filter(
lambda n: n.name == 'tensor' and n.decl.optimizable,
self.tsrex.root
)))
)

@factors.setter
def factors(self, tensors):
for i, n in enumerate(
for i, n in enumerate(unique(
dfs_filter(lambda n: n.name == 'tensor' and
n.decl.optimizable, self.tsrex.root)
):
)):
n.data = tensors[i]

@property
Expand Down Expand Up @@ -117,14 +143,21 @@ def all_factors(self):
'''
return self._NodeView(
'data',
list(dfs_filter(lambda n: n.name == 'tensor', self.tsrex.root))
list(unique(dfs_filter(
lambda n: n.name == 'tensor', self.tsrex.root
)))
)

@property
def tsrex(self):
'''The underlying tensor expression.'''
return self._tsrex

@tsrex.setter
def tsrex(self, tsrex):
'''Setting the underlying tensor expression.'''
self._tsrex = tsrex

@property
def shape(self):
'''The shape of the result tensor.'''
Expand All @@ -135,24 +168,26 @@ def ndim(self):
'''The dimensionality of the result tensor.'''
return self.tsrex.ndim

def penalty(self, sum_leafs: bool = True, sum_vec=None):
def penalty(self, sum_leafs: bool = True, sum_vec=False):
'''The penalty of the result tensor.
Args:
sum_leafs (bool): sum the penalties over the leafs of the model.
sum_vec (bool): sum the penalties over the vectorization dimension.
'''

factors = list(dfs_filter(
factors = list(unique(dfs_filter(
lambda n: n.name == 'tensor' and n.decl.optimizable,
self.tsrex.root)
)
self.tsrex.root
)))
penalties = ab.stack(
[f.decl.prefer(f.data, sum_vec) for f in factors],
0 if sum_vec else -1
)
return ab.sum(penalties, 0 if sum_vec else -1) if sum_leafs else \
penalties
if sum_leafs:
return ab.sum(penalties, 0 if sum_vec else -1)
else:
return penalties

def __call__(self):
'''Shorthand for :py:meth:`forward`.'''
Expand Down Expand Up @@ -211,21 +246,21 @@ def __getitem__(self, idx):
'''Implements attribute-based access of factor tensors or output
elements.'''
if isinstance(idx, str):
for n in dfs_filter(
for n in unique(dfs_filter(
lambda n: n.name == 'tensor' and str(n.decl.symbol) == idx,
self.tsrex.root
):
)):
return n.data
raise AttributeError(f'No factor tensor named {idx}.')
else:
return self._get_elements(idx)

def __setitem__(self, name, data):
'''Implements attribute-based access of factor tensors.'''
for n in dfs_filter(
for n in unique(dfs_filter(
lambda n: n.name == 'tensor' and str(n.decl.symbol) == name,
self.tsrex.root
):
)):
return setattr(n, 'data', data)
raise AttributeError(f'No factor tensor named {name}.')

Expand Down
28 changes: 24 additions & 4 deletions funfact/model/test_factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from unittest.mock import MagicMock as M
from funfact import active_backend as ab
from funfact.lang._ast import Primitives as P
from ._factorization import Factorization


Expand All @@ -21,7 +22,7 @@ def _test_factorization_props(fac):


def test_init_factory():
fac = Factorization(M())
fac = Factorization(M(), _secret='50A-2117')
_test_factorization_props(fac)
fac = Factorization.from_tsrex(M(), dtype=M(), initialize=M())
_test_factorization_props(fac)
Expand Down Expand Up @@ -61,7 +62,7 @@ def test_as_slice(test_case):
def _gen_factors_mock(data, optimizable=True):
root = M(data=data, decl=M(optimizable=optimizable))
root.name = 'tensor'
return Factorization(M(), _tsrex=M(root=root))
return Factorization(M(root=root), _secret='50A-2117')


def test_factors():
Expand All @@ -85,14 +86,14 @@ def _prefer(data, *args):
root = M(data=ab.tensor([1, 2, 3]),
decl=M(optimizable=True, prefer=_prefer))
root.name = 'tensor'
fac = Factorization(M(), _tsrex=M(root=root))
fac = Factorization(M(root=root), _secret='50A-2117')
assert fac.penalty() == 6


def test_get_set_item():
root = M(data=1, decl=M(symbol='a', optimizable=True), ndim=1)
root.name = 'tensor'
fac = Factorization(M(), _tsrex=M(root=root))
fac = Factorization(M(root=root), _secret='50A-2117')
assert fac['a'] == 1
fac['a'] = 2
assert fac['a'] == 2
Expand All @@ -104,3 +105,22 @@ def test_get_set_item():
fac[0, 0]
with pytest.raises(IndexError):
fac[0, ...]


def test_duplicate_factors():
a = P.tensor(decl=M(symbol='a', optimizable=True, prefer=lambda *_: 0))
b = P.tensor(decl=M(symbol='b', optimizable=True, prefer=lambda *_: 0))

fac1 = Factorization(M(root=P.elem(a, a, 0, 'add')), _secret='50A-2117')
fac2 = Factorization(M(root=P.elem(a, b, 0, 'add')), _secret='50A-2117')

assert len(fac1.factors) == 1
assert len(fac2.factors) == 2

fac1.factors = [None]
with pytest.raises(IndexError):
fac2.factors = [None]
fac2.factors = [None, None]

assert len(fac1.penalty(sum_leafs=False)) == 1
assert len(fac2.penalty(sum_leafs=False)) == 2
12 changes: 6 additions & 6 deletions funfact/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def __call__(self, only_one):


@pytest.mark.parametrize('stop_by', ['first', 2, None])
@pytest.mark.parametrize('append', [True, False])
def test_kwargs(stop_by, append):
@pytest.mark.parametrize('vec_axis', [0, -1])
def test_kwargs(stop_by, vec_axis):

fac = factorize(
tensor(2), ab.ones(2), nvec=4, stop_by=stop_by, append=append,
tensor(2), ab.ones(2), vec_size=4, stop_by=stop_by, vec_axis=vec_axis,
max_steps=100
)

Expand All @@ -87,17 +87,17 @@ def test_kwargs(stop_by, append):
def test_returns():

fac = factorize(
tensor(2), ab.ones(2), nvec=4, max_steps=100, returns='best'
tensor(2), ab.ones(2), vec_size=4, max_steps=100, returns='best'
)
assert not isinstance(fac, list)

fac = factorize(
tensor(2), ab.ones(2), nvec=4, max_steps=100, returns=2
tensor(2), ab.ones(2), vec_size=4, max_steps=100, returns=2
)
assert isinstance(fac, list)

fac = factorize(
tensor(2), ab.ones(2), nvec=4, max_steps=100, returns='all'
tensor(2), ab.ones(2), vec_size=4, max_steps=100, returns='all'
)
assert isinstance(fac, list)

Expand Down
Loading

0 comments on commit 21b2f1b

Please sign in to comment.