Skip to content
This repository has been archived by the owner on Nov 7, 2024. It is now read-only.

Implemented Cholesky decomposition to numpy and tensorflow backends #890

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
016d654
Added power to Jax backend and test
LuiGiovanni Dec 5, 2020
62a06c0
Added power function to the Jax backend and test.
LuiGiovanni Dec 5, 2020
43952dd
Fixed white space error.
LuiGiovanni Dec 5, 2020
b0f579e
Re-made the power test function.
LuiGiovanni Dec 5, 2020
e9f9869
Fixed typo
LuiGiovanni Dec 5, 2020
9e16a9f
Testing out numpy square.
LuiGiovanni Dec 5, 2020
81b4d8b
Fixed issues with the assertion in test, should work now.
LuiGiovanni Dec 5, 2020
195f691
Added NotImplementedError function and it's respective test for Chols…
LuiGiovanni Dec 6, 2020
737f071
Fixed line too long error
LuiGiovanni Dec 6, 2020
0fbe8fb
Removing changes for different branches.
LuiGiovanni Dec 6, 2020
8c37656
Revert "Fixed issues with the assertion in test, should work now."
LuiGiovanni Dec 6, 2020
a894960
Fixed a few pylint issues, should work now
LuiGiovanni Dec 6, 2020
3ba747f
Fixed typo in the error message
LuiGiovanni Dec 6, 2020
a112ebe
Merge branch 'master' into CholeskyDecomposition
alewis Dec 7, 2020
f9fb40d
Renamed cholesky function, to a shorter name
LuiGiovanni Dec 7, 2020
4b4ece6
Merge branch 'CholeskyDecomposition' of https://github.com/LuiGiovann…
LuiGiovanni Dec 7, 2020
0c3817d
Renamed cholesky function from cholesky_decomposition to cholesky
LuiGiovanni Dec 7, 2020
aedfe2a
Implemented Cholesky to numpy and tensorflow with their respective tests
LuiGiovanni Dec 23, 2020
13026b5
Fixed pylint issues
LuiGiovanni Dec 23, 2020
11eba6a
Implemented Cholesky decomposition to tensorflow, numpy & pytorch bac…
LuiGiovanni Dec 25, 2020
4da6eac
Fixed minor pylint issues
LuiGiovanni Dec 25, 2020
6548989
Merge branch 'master' into CholeskyDecomposition
mganahl Dec 28, 2020
66b5b94
Merge branch 'master' into CholeskyDecomposition
mganahl Dec 28, 2020
dab2060
requested changes made removed unnecessary arguments and better matrix
LuiGiovanni Dec 28, 2020
e6b9c08
Merge branch 'CholeskyDecomposition' of https://github.com/LuiGiovann…
LuiGiovanni Dec 28, 2020
367ce3d
Fixed commented testing functions
LuiGiovanni Dec 28, 2020
92c931a
Merge branch 'master' into CholeskyDecomposition
mganahl Jan 4, 2021
a51cb9a
Merge branch 'master' into CholeskyDecomposition
mganahl Jan 5, 2021
2b269e0
Merge branch 'master' into CholeskyDecomposition
mganahl Jan 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tensornetwork/backends/numpy/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,22 @@ def rq(
r = np.reshape(r, list(left_dims) + [center_dim])
q = np.reshape(q, [center_dim] + list(right_dims))
return r, q


def cholesky(
np: Any,
tensor: Tensor,
pivot_axis: int
) -> Tuple[Tensor, Tensor]:
"""
Computes the Cholesky decomposition of a tensor

See tensornetwork.backends.tensorflow.decompositions for details.
"""
left_dims = np.shape(tensor)[:pivot_axis]
right_dims = np.shape(tensor)[pivot_axis:]
tensor = np.reshape(tensor,
[np.reduce_prod(left_dims),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np is the numpy module. numpy does not have reduce_prod, that's a tensorflow routine. This only passes the test because in decompositions_test.py you are erroneously passing the tensorflow module instead of the numpy module to numpy.decompositions.cholesky. pls fix this

np.reduce_prod(right_dims)])
L = np.linalg.cholesky(tensor)
return L
9 changes: 9 additions & 0 deletions tensornetwork/backends/numpy/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import numpy as np
import tensorflow as tf
from tensornetwork.backends.numpy import decompositions
import pytest

np_dtypes = [np.float64, np.complex128]

class DecompositionsTest(tf.test.TestCase):

Expand Down Expand Up @@ -52,6 +54,13 @@ def test_qr(self):
q, r = decompositions.qr(np, random_matrix, 1, non_negative_diagonal)
self.assertAllClose(q.dot(r), random_matrix)

def test_cholesky(self):
#Assured positive-definite hermitian matrixs
random_matrix = np.random.rand(10, 10)
random_matrix = random_matrix @ random_matrix.T.conj()
L = decompositions.cholesky(tf, random_matrix, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are passing the tensorflow module here, but it should be numpy

self.assertAllClose(np.linalg.cholesky(random_matrix), L)

def test_max_singular_values(self):
random_matrix = np.random.rand(10, 10)
unitary1, _, unitary2 = np.linalg.svd(random_matrix)
Expand Down
17 changes: 17 additions & 0 deletions tensornetwork/backends/pytorch/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,23 @@ def svd(
return u, s, vh, s_rest


def cholesky(
torch: Any,
tensor: Tensor,
pivot_axis: int
) -> Tuple[Tensor, Tensor]:
"""
Computes the Cholesky decomposition of a tensor

See tensornetwork.backends.tensorflow.decompositions for details.
"""
left_dims = list(tensor.shape)[:pivot_axis]
right_dims = list(tensor.shape)[pivot_axis:]

tensor = torch.reshape(tensor, (np.prod(left_dims), np.prod(right_dims)))
L = np.linalg.cholesky(tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you have to call the torch version of choleksy here, not numpy.

return L

def qr(
torch: Any,
tensor: Tensor,
Expand Down
10 changes: 10 additions & 0 deletions tensornetwork/backends/pytorch/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import numpy as np
import torch
from tensornetwork.backends.pytorch import decompositions
import pytest

np_dtypes = [np.float64, np.complex128]

def test_expected_shapes():
val = torch.zeros((2, 3, 4, 5))
Expand All @@ -42,6 +44,14 @@ def test_expected_shapes_rq():
assert r.shape == (2, 3, 6)
assert q.shape == (6, 4, 5)

# @pytest.mark.parametrize("dtype", np_dtypes)
def test_cholesky():
#Assured positive-definite hermitian matrix
random_matrix = np.random.rand(10, 10)
random_matrix = random_matrix @ random_matrix.T.conj()
random_matrix = torch.from_numpy(random_matrix)
L = decompositions.cholesky(torch, random_matrix, 1)
np.testing.assert_allclose(torch.cholesky(random_matrix), L)

def test_rq():
random_matrix = torch.rand([10, 10], dtype=torch.float64)
Expand Down
20 changes: 20 additions & 0 deletions tensornetwork/backends/tensorflow/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,23 @@ def rq(
r = tf.reshape(r, tf.concat([left_dims, [center_dim]], axis=-1))
q = tf.reshape(q, tf.concat([[center_dim], right_dims], axis=-1))
return r, q


def cholesky(
tf: Any,
tensor: Tensor,
pivot_axis: int
) -> Tuple[Tensor, Tensor]:
""" Computes de cholesky decomposition of a tensor.

Returns the Cholesky decomposition of a tensor which we treat as a
square matrix
"""
left_dims = tf.shape(tensor)[:pivot_axis]
right_dims = tf.shape(tensor)[pivot_axis:]
tensor = tf.reshape(tensor,
[tf.reduce_prod(left_dims),
tf.reduce_prod(right_dims)])
L = tf.linalg.cholesky(tensor)
return L

10 changes: 10 additions & 0 deletions tensornetwork/backends/tensorflow/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import numpy as np
import tensorflow as tf
from tensornetwork.backends.tensorflow import decompositions
import pytest

np_dtypes = [np.float64, np.complex128]

class DecompositionsTest(tf.test.TestCase):

Expand Down Expand Up @@ -54,6 +56,14 @@ def test_qr(self):
q, r = decompositions.qr(tf, random_matrix, 1, non_negative_diagonal)
self.assertAllClose(tf.tensordot(q, r, ([1], [0])), random_matrix)

# @pytest.mark.parametrize("dtype", np_dtypes)
def test_cholesky(self):
#Assured positive-definite hermitian matrix
random_matrix = np.random.rand(10, 10)
random_matrix = random_matrix @ random_matrix.T.conj()
L = decompositions.cholesky(tf, random_matrix, 1)
self.assertAllClose(np.linalg.cholesky(random_matrix), L)

def test_rq_defun(self):
random_matrix = np.random.rand(10, 10)
for non_negative_diagonal in [True, False]:
Expand Down