From 2239a6f737d2b9282e860820fa332e5e911d553e Mon Sep 17 00:00:00 2001 From: Andreas Borgen Longva Date: Fri, 17 Feb 2017 23:58:58 +0100 Subject: [PATCH] Rewrite forward_substitution and back_substitution (#152) * Add benchmarks for solve_?_triangular * Rewrite forward_substitution with utils::dot This lets us leverage the optimized implementation of dot, and the resulting code is a lot cleaner. Plus, we avoid an unnecessary allocation to boot. * Rewrite back_substitution with utils::dot * Update comments in back_substitution * Remove Any trait bounds for substitution * Remove Any import and restore utils import --- benches/lib.rs | 2 + benches/linalg/triangular.rs | 58 ++++++++++++++++ src/matrix/mod.rs | 124 +++++++++++++++++++++-------------- tests/mat/mod.rs | 20 +++--- 4 files changed, 144 insertions(+), 60 deletions(-) create mode 100644 benches/linalg/triangular.rs diff --git a/benches/lib.rs b/benches/lib.rs index 0699c93..be53bfd 100644 --- a/benches/lib.rs +++ b/benches/lib.rs @@ -1,5 +1,6 @@ #![feature(test)] +#[macro_use] extern crate rulinalg; extern crate num as libnum; extern crate test; @@ -11,6 +12,7 @@ pub mod linalg { mod svd; mod lu; mod norm; + mod triangular; mod permutation; pub mod util; } diff --git a/benches/linalg/triangular.rs b/benches/linalg/triangular.rs new file mode 100644 index 0000000..16648a4 --- /dev/null +++ b/benches/linalg/triangular.rs @@ -0,0 +1,58 @@ +use test::Bencher; +use rulinalg::matrix::Matrix; +use rulinalg::matrix::BaseMatrix; + +#[bench] +fn solve_l_triangular_100x100(b: &mut Bencher) { + let n = 100; + let x = Matrix::::identity(n); + b.iter(|| { + x.solve_l_triangular(vector![0.0; n]) + }); +} + +#[bench] +fn solve_l_triangular_1000x1000(b: &mut Bencher) { + let n = 1000; + let x = Matrix::::identity(n); + b.iter(|| { + x.solve_l_triangular(vector![0.0; n]) + }); +} + +#[bench] +fn solve_l_triangular_10000x10000(b: &mut Bencher) { + let n = 10000; + let x = Matrix::::identity(n); + b.iter(|| { + x.solve_l_triangular(vector![0.0; n]) + }); +} + +#[bench] +fn solve_u_triangular_100x100(b: &mut Bencher) { + let n = 100; + let x = Matrix::::identity(n); + b.iter(|| { + x.solve_u_triangular(vector![0.0; n]) + }); +} + +#[bench] +fn solve_u_triangular_1000x1000(b: &mut Bencher) { + let n = 1000; + let x = Matrix::::identity(n); + b.iter(|| { + x.solve_u_triangular(vector![0.0; n]) + }); +} + +#[bench] +fn solve_u_triangular_10000x10000(b: &mut Bencher) { + let n = 10000; + let x = Matrix::::identity(n); + b.iter(|| { + x.solve_u_triangular(vector![0.0; n]) + }); +} + diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index c74e93a..6c468af 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -7,13 +7,14 @@ //! via `BaseMatrix` and `BaseMatrixMut` trait. use std; -use std::any::Any; use std::marker::PhantomData; use libnum::Float; use error::{Error, ErrorKind}; use vector::Vector; +use utils; + pub mod decomposition; mod base; mod deref; @@ -310,63 +311,88 @@ pub struct SliceIterMut<'a, T: 'a> { _marker: PhantomData<&'a mut T>, } -/// Back substitution -fn back_substitution(m: &M, y: Vector) -> Result, Error> - where T: Any + Float, +/// Solves the system Ux = y by back substitution. +/// +/// Here U is an upper triangular matrix and y a vector +/// which is dimensionally compatible with U. +fn back_substitution(u: &M, y: Vector) -> Result, Error> + where T: Float, M: BaseMatrix { - if m.is_empty() { - return Err(Error::new(ErrorKind::InvalidArg, "Matrix is empty.")); - } - - let mut x = vec![T::zero(); y.size()]; - - unsafe { - for i in (0..y.size()).rev() { - let mut holding_u_sum = T::zero(); - for j in (i + 1..y.size()).rev() { - holding_u_sum = holding_u_sum + *m.get_unchecked([i, j]) * x[j]; - } - - let diag = *m.get_unchecked([i, i]); - if diag.abs() < T::min_positive_value() + T::min_positive_value() { - return Err(Error::new(ErrorKind::AlgebraFailure, - "Linear system cannot be solved (matrix is singular).")); - } - x[i] = (y[i] - holding_u_sum) / diag; + assert!(u.rows() == u.cols(), "Matrix U must be square."); + assert!(y.size() == u.rows(), + "Matrix and RHS vector must be dimensionally compatible."); + let mut x = y; + + let n = u.rows(); + for i in (0 .. n).rev() { + let row = u.row(i); + + // TODO: Remove unsafe once `get` is available in `BaseMatrix` + let divisor = unsafe { u.get_unchecked([i, i]).clone() }; + if divisor.abs() < T::epsilon() { + return Err(Error::new(ErrorKind::DivByZero, + "Lower triangular matrix is singular to working precision.")); } + + // We have + // u[i, i] x[i] = b[i] - sum_j { u[i, j] * x[j] } + // where j = i + 1, ..., (n - 1) + // + // Note that the right-hand side sum term can be rewritten as + // u[i, (i + 1) .. n] * x[(i + 1) .. n] + // where * denotes the dot product. + // This is handy, because we have a very efficient + // dot(., .) implementation! + let dot = { + let row_part = &row.raw_slice()[(i + 1) .. n]; + let x_part = &x.data()[(i + 1) .. n]; + utils::dot(row_part, x_part) + }; + + x[i] = (x[i] - dot) / divisor; } - Ok(Vector::new(x)) + Ok(x) } -/// forward substitution -fn forward_substitution(m: &M, y: Vector) -> Result, Error> - where T: Any + Float, +/// Solves the system Lx = y by forward substitution. +/// +/// Here, L is a square, lower triangular matrix and y +/// is a vector which is dimensionally compatible with L. +fn forward_substitution(l: &M, y: Vector) -> Result, Error> + where T: Float, M: BaseMatrix { - if m.is_empty() { - return Err(Error::new(ErrorKind::InvalidArg, "Matrix is empty.")); - } - - let mut x = Vec::with_capacity(y.size()); - - unsafe { - for (i, y_item) in y.data().iter().enumerate().take(y.size()) { - let mut holding_l_sum = T::zero(); - for (j, x_item) in x.iter().enumerate().take(i) { - holding_l_sum = holding_l_sum + *m.get_unchecked([i, j]) * *x_item; - } - - let diag = *m.get_unchecked([i, i]); - - if diag.abs() < T::min_positive_value() + T::min_positive_value() { - return Err(Error::new(ErrorKind::AlgebraFailure, - "Linear system cannot be solved (matrix is singular).")); - } - x.push((*y_item - holding_l_sum) / diag); + assert!(l.rows() == l.cols(), "Matrix L must be square."); + assert!(y.size() == l.rows(), + "Matrix and RHS vector must be dimensionally compatible."); + let mut x = y; + + for (i, row) in l.row_iter().enumerate() { + // TODO: Remove unsafe once `get` is available in `BaseMatrix` + let divisor = unsafe { l.get_unchecked([i, i]).clone() }; + if divisor.abs() < T::epsilon() { + return Err(Error::new(ErrorKind::DivByZero, + "Lower triangular matrix is singular to working precision.")); } - } - Ok(Vector::new(x)) + // We have + // l[i, i] x[i] = b[i] - sum_j { l[i, j] * x[j] } + // where j = 0, ..., i - 1 + // + // Note that the right-hand side sum term can be rewritten as + // l[i, 0 .. i] * x[0 .. i] + // where * denotes the dot product. + // This is handy, because we have a very efficient + // dot(., .) implementation! + let dot = { + let row_part = &row.raw_slice()[0 .. i]; + let x_part = &x.data()[0 .. i]; + utils::dot(row_part, x_part) + }; + + x[i] = (x[i] - dot) / divisor; + } + Ok(x) } diff --git a/tests/mat/mod.rs b/tests/mat/mod.rs index b71b860..3cc45ac 100644 --- a/tests/mat/mod.rs +++ b/tests/mat/mod.rs @@ -1,4 +1,4 @@ -use rulinalg::matrix::{BaseMatrix, Matrix}; +use rulinalg::matrix::{BaseMatrix}; #[test] fn test_solve() { @@ -15,26 +15,24 @@ fn test_solve() { let b = vector![-100.0, 0.0, 0.0, -100.0, 0.0, 0.0, -100.0, 0.0, 0.0]; let c = a.solve(b).unwrap(); - let true_solution = vec![42.85714286, 18.75, 7.14285714, 52.67857143, - 25.0, 9.82142857, 42.85714286, 18.75, 7.14285714]; - - assert!(c.into_iter().zip(true_solution.into_iter()).all(|(x, y)| (x-y) < 1e-5)); - + let true_solution = vector![42.85714286, 18.75, 7.14285714, 52.67857143, + 25.0, 9.82142857, 42.85714286, 18.75, 7.14285714]; + + // Note: the "true_solution" given here has way too few + // significant digits, and since I can't be bothered to enter + // it all into e.g. NumPy, I'm leaving a lower absolute + // tolerance in place. + assert_vector_eq!(c, true_solution, comp = abs, tol = 1e-8); } #[test] fn test_l_triangular_solve_errs() { - let a: Matrix = matrix![]; - assert!(a.solve_l_triangular(vector![]).is_err()); let a = matrix![0.0]; assert!(a.solve_l_triangular(vector![1.0]).is_err()); } #[test] fn test_u_triangular_solve_errs() { - let a: Matrix = matrix![]; - assert!(a.solve_u_triangular(vector![]).is_err());; - let a = matrix![0.0]; assert!(a.solve_u_triangular(vector![1.0]).is_err()); }