diff --git a/torchsparsegradutils/tests/test_linear_cg.py b/torchsparsegradutils/tests/test_linear_cg.py
index 2cf1f6d..c9e881d 100644
--- a/torchsparsegradutils/tests/test_linear_cg.py
+++ b/torchsparsegradutils/tests/test_linear_cg.py
@@ -126,6 +126,26 @@ def test_batch_cg_with_tridiag(self):
                 approx_eigs = torch.linalg.eigvalsh(t_mats[j, i])
                 self.assertTrue(torch.allclose(eigs, approx_eigs, atol=1e-3, rtol=1e-4))
 
+    def test_batch_cg_init(self):
+        batch = 5
+        size = 100
+        matrix = torch.randn(batch, size, size, dtype=torch.float64)
+        matrix = matrix.matmul(matrix.mT)
+        matrix.div_(matrix.norm())
+        matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))
+
+        # Initial solve
+        rhs = torch.randn(batch, size, 50, dtype=torch.float64)
+        solves = linear_cg(matrix.matmul, rhs=rhs, max_iter=size, max_tridiag_iter=0)
+
+        # Initialize with solve
+        solves_with_init = linear_cg(matrix.matmul, rhs=rhs, max_iter=1, initial_guess=solves, max_tridiag_iter=0)
+
+        # Check cg
+        matrix_chol = torch.linalg.cholesky(matrix)
+        actual = torch.cholesky_solve(rhs, matrix_chol)
+        self.assertTrue(torch.allclose(solves_with_init, actual, atol=1e-3, rtol=1e-4))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchsparsegradutils/utils/linear_cg.py b/torchsparsegradutils/utils/linear_cg.py
index 6f3f4d9..31ae54d 100644
--- a/torchsparsegradutils/utils/linear_cg.py
+++ b/torchsparsegradutils/utils/linear_cg.py
@@ -184,6 +184,7 @@ def linear_cg(
 
     # Let's normalize. We'll un-normalize afterwards
    rhs = rhs.div(rhs_norm)
+    initial_guess = initial_guess.div(rhs_norm)
 
     # residual: residual_{0} = b_vec - lhs x_{0}
     residual = rhs - matmul_closure(initial_guess)
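
Note on the one-line fix in linear_cg.py: the solver rescales the right-hand side by rhs_norm before iterating, so a warm start passed as initial_guess must be rescaled the same way, otherwise the first residual mixes a scaled rhs with an unscaled guess. Below is a minimal standalone sketch of that effect (illustrative only, not part of the patch; it assumes only PyTorch and uses a toy SPD system instead of the library's linear_cg):

import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
A = A @ A.mT + 0.1 * torch.eye(4, dtype=torch.float64)  # small SPD system
b = torch.randn(4, 1, dtype=torch.float64)
x0 = torch.linalg.solve(A, b)  # exact solution used as the warm start

rhs_norm = b.norm()
rhs = b / rhs_norm  # mirrors "rhs = rhs.div(rhs_norm)" in linear_cg

# Unscaled guess: residual of the scaled system is b/||b|| - b, whose norm is
# |1 - ||b|||, generally nonzero even though x0 already solves A x = b.
residual_unscaled_guess = rhs - A @ x0

# Scaled guess (the patched behaviour): residual is ~0, so CG started from the
# warm start converges immediately, which is what test_batch_cg_init checks
# with max_iter=1.
residual_scaled_guess = rhs - A @ (x0 / rhs_norm)

print(residual_unscaled_guess.norm().item())  # nonzero in general
print(residual_scaled_guess.norm().item())    # close to machine precision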