Skip to content

Commit

Permalink
Rewrite forward_substitution and back_substitution (#152)
Browse files Browse the repository at this point in the history
* Add benchmarks for solve_?_triangular

* Rewrite forward_substitution with utils::dot

This lets us leverage the optimized implementation
of dot, and the resulting code is a lot cleaner.
Plus, we avoid an unnecessary allocation to boot.

* Rewrite back_substitution with utils::dot

* Update comments in back_substitution

* Remove Any trait bounds for substitution

* Remove Any import and restore utils import
  • Loading branch information
Andlon authored and AtheMathmo committed Feb 17, 2017
1 parent bd67aa6 commit 2239a6f
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 60 deletions.
2 changes: 2 additions & 0 deletions benches/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(test)]

#[macro_use]
extern crate rulinalg;
extern crate num as libnum;
extern crate test;
Expand All @@ -11,6 +12,7 @@ pub mod linalg {
mod svd;
mod lu;
mod norm;
mod triangular;
mod permutation;
pub mod util;
}
58 changes: 58 additions & 0 deletions benches/linalg/triangular.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use test::Bencher;
use rulinalg::matrix::Matrix;
use rulinalg::matrix::BaseMatrix;

#[bench]
fn solve_l_triangular_100x100(b: &mut Bencher) {
let n = 100;
let x = Matrix::<f64>::identity(n);
b.iter(|| {
x.solve_l_triangular(vector![0.0; n])
});
}

#[bench]
fn solve_l_triangular_1000x1000(b: &mut Bencher) {
let n = 1000;
let x = Matrix::<f64>::identity(n);
b.iter(|| {
x.solve_l_triangular(vector![0.0; n])
});
}

#[bench]
fn solve_l_triangular_10000x10000(b: &mut Bencher) {
let n = 10000;
let x = Matrix::<f64>::identity(n);
b.iter(|| {
x.solve_l_triangular(vector![0.0; n])
});
}

#[bench]
fn solve_u_triangular_100x100(b: &mut Bencher) {
let n = 100;
let x = Matrix::<f64>::identity(n);
b.iter(|| {
x.solve_u_triangular(vector![0.0; n])
});
}

#[bench]
fn solve_u_triangular_1000x1000(b: &mut Bencher) {
let n = 1000;
let x = Matrix::<f64>::identity(n);
b.iter(|| {
x.solve_u_triangular(vector![0.0; n])
});
}

#[bench]
fn solve_u_triangular_10000x10000(b: &mut Bencher) {
let n = 10000;
let x = Matrix::<f64>::identity(n);
b.iter(|| {
x.solve_u_triangular(vector![0.0; n])
});
}

124 changes: 75 additions & 49 deletions src/matrix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
//! via `BaseMatrix` and `BaseMatrixMut` trait.
use std;
use std::any::Any;
use std::marker::PhantomData;
use libnum::Float;

use error::{Error, ErrorKind};
use vector::Vector;

use utils;

pub mod decomposition;
mod base;
mod deref;
Expand Down Expand Up @@ -310,63 +311,88 @@ pub struct SliceIterMut<'a, T: 'a> {
_marker: PhantomData<&'a mut T>,
}

/// Back substitution
fn back_substitution<T, M>(m: &M, y: Vector<T>) -> Result<Vector<T>, Error>
where T: Any + Float,
/// Solves the system Ux = y by back substitution.
///
/// Here U is an upper triangular matrix and y a vector
/// which is dimensionally compatible with U.
fn back_substitution<T, M>(u: &M, y: Vector<T>) -> Result<Vector<T>, Error>
where T: Float,
M: BaseMatrix<T>
{
if m.is_empty() {
return Err(Error::new(ErrorKind::InvalidArg, "Matrix is empty."));
}

let mut x = vec![T::zero(); y.size()];

unsafe {
for i in (0..y.size()).rev() {
let mut holding_u_sum = T::zero();
for j in (i + 1..y.size()).rev() {
holding_u_sum = holding_u_sum + *m.get_unchecked([i, j]) * x[j];
}

let diag = *m.get_unchecked([i, i]);
if diag.abs() < T::min_positive_value() + T::min_positive_value() {
return Err(Error::new(ErrorKind::AlgebraFailure,
"Linear system cannot be solved (matrix is singular)."));
}
x[i] = (y[i] - holding_u_sum) / diag;
assert!(u.rows() == u.cols(), "Matrix U must be square.");
assert!(y.size() == u.rows(),
"Matrix and RHS vector must be dimensionally compatible.");
let mut x = y;

let n = u.rows();
for i in (0 .. n).rev() {
let row = u.row(i);

// TODO: Remove unsafe once `get` is available in `BaseMatrix`
let divisor = unsafe { u.get_unchecked([i, i]).clone() };
if divisor.abs() < T::epsilon() {
return Err(Error::new(ErrorKind::DivByZero,
"Lower triangular matrix is singular to working precision."));
}

// We have
// u[i, i] x[i] = b[i] - sum_j { u[i, j] * x[j] }
// where j = i + 1, ..., (n - 1)
//
// Note that the right-hand side sum term can be rewritten as
// u[i, (i + 1) .. n] * x[(i + 1) .. n]
// where * denotes the dot product.
// This is handy, because we have a very efficient
// dot(., .) implementation!
let dot = {
let row_part = &row.raw_slice()[(i + 1) .. n];
let x_part = &x.data()[(i + 1) .. n];
utils::dot(row_part, x_part)
};

x[i] = (x[i] - dot) / divisor;
}

Ok(Vector::new(x))
Ok(x)
}

/// forward substitution
fn forward_substitution<T, M>(m: &M, y: Vector<T>) -> Result<Vector<T>, Error>
where T: Any + Float,
/// Solves the system Lx = y by forward substitution.
///
/// Here, L is a square, lower triangular matrix and y
/// is a vector which is dimensionally compatible with L.
fn forward_substitution<T, M>(l: &M, y: Vector<T>) -> Result<Vector<T>, Error>
where T: Float,
M: BaseMatrix<T>
{
if m.is_empty() {
return Err(Error::new(ErrorKind::InvalidArg, "Matrix is empty."));
}

let mut x = Vec::with_capacity(y.size());

unsafe {
for (i, y_item) in y.data().iter().enumerate().take(y.size()) {
let mut holding_l_sum = T::zero();
for (j, x_item) in x.iter().enumerate().take(i) {
holding_l_sum = holding_l_sum + *m.get_unchecked([i, j]) * *x_item;
}

let diag = *m.get_unchecked([i, i]);

if diag.abs() < T::min_positive_value() + T::min_positive_value() {
return Err(Error::new(ErrorKind::AlgebraFailure,
"Linear system cannot be solved (matrix is singular)."));
}
x.push((*y_item - holding_l_sum) / diag);
assert!(l.rows() == l.cols(), "Matrix L must be square.");
assert!(y.size() == l.rows(),
"Matrix and RHS vector must be dimensionally compatible.");
let mut x = y;

for (i, row) in l.row_iter().enumerate() {
// TODO: Remove unsafe once `get` is available in `BaseMatrix`
let divisor = unsafe { l.get_unchecked([i, i]).clone() };
if divisor.abs() < T::epsilon() {
return Err(Error::new(ErrorKind::DivByZero,
"Lower triangular matrix is singular to working precision."));
}
}

Ok(Vector::new(x))
// We have
// l[i, i] x[i] = b[i] - sum_j { l[i, j] * x[j] }
// where j = 0, ..., i - 1
//
// Note that the right-hand side sum term can be rewritten as
// l[i, 0 .. i] * x[0 .. i]
// where * denotes the dot product.
// This is handy, because we have a very efficient
// dot(., .) implementation!
let dot = {
let row_part = &row.raw_slice()[0 .. i];
let x_part = &x.data()[0 .. i];
utils::dot(row_part, x_part)
};

x[i] = (x[i] - dot) / divisor;
}
Ok(x)
}
20 changes: 9 additions & 11 deletions tests/mat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rulinalg::matrix::{BaseMatrix, Matrix};
use rulinalg::matrix::{BaseMatrix};

#[test]
fn test_solve() {
Expand All @@ -15,26 +15,24 @@ fn test_solve() {
let b = vector![-100.0, 0.0, 0.0, -100.0, 0.0, 0.0, -100.0, 0.0, 0.0];

let c = a.solve(b).unwrap();
let true_solution = vec![42.85714286, 18.75, 7.14285714, 52.67857143,
25.0, 9.82142857, 42.85714286, 18.75, 7.14285714];

assert!(c.into_iter().zip(true_solution.into_iter()).all(|(x, y)| (x-y) < 1e-5));

let true_solution = vector![42.85714286, 18.75, 7.14285714, 52.67857143,
25.0, 9.82142857, 42.85714286, 18.75, 7.14285714];

// Note: the "true_solution" given here has way too few
// significant digits, and since I can't be bothered to enter
// it all into e.g. NumPy, I'm leaving a lower absolute
// tolerance in place.
assert_vector_eq!(c, true_solution, comp = abs, tol = 1e-8);
}

#[test]
fn test_l_triangular_solve_errs() {
let a: Matrix<f64> = matrix![];
assert!(a.solve_l_triangular(vector![]).is_err());
let a = matrix![0.0];
assert!(a.solve_l_triangular(vector![1.0]).is_err());
}

#[test]
fn test_u_triangular_solve_errs() {
let a: Matrix<f64> = matrix![];
assert!(a.solve_u_triangular(vector![]).is_err());;

let a = matrix![0.0];
assert!(a.solve_u_triangular(vector![1.0]).is_err());
}
Expand Down

0 comments on commit 2239a6f

Please sign in to comment.