-
Notifications
You must be signed in to change notification settings - Fork 4
/
matrixFactory.h
88 lines (71 loc) · 2.46 KB
/
matrixFactory.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#ifndef LINALG_MATRIXFACTORY_H
#define LINALG_MATRIXFACTORY_H
#include <array>
#include <cstddef>
#include <exception>
#include <functional>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include "matrix.h"
namespace MatrixFactory {
/**
 * @brief Builds a square matrix with ones on the main diagonal.
 *
 * The elements are staged in a zero-initialised buffer of size*size values,
 * the diagonal entries are set to one, and the buffer is passed to the
 * Matrix<T>(rows, columns, data) constructor.
 *
 * @tparam T    Arithmetic element type (enforced at compile time).
 * @param size  Number of rows and of columns; must be non-zero.
 * @return The matrix constructed from the staged buffer.
 * @throws std::domain_error if size is zero, or if size*size does not fit in
 *         size_t.
 */
template <typename T>
Matrix<T> IdentityMatrix(size_t size) {
    static_assert(std::is_arithmetic<T>::value, "T must be numeric");
    if(size == 0) {
        throw std::domain_error("Invalid defined matrix.");
    }
    // Unsigned multiplication wraps silently. A wrapped element count would
    // allocate a buffer that is too small, and the diagonal writes below
    // would then land outside of it.
    const size_t elementCount = size * size;
    if(elementCount / size != size) {
        throw std::domain_error("Matrix size is too large.");
    }
    // The sized vector constructor value-initialises, so every entry is zero.
    std::vector<T> values(elementCount);
    for(size_t diagonal = 0; diagonal < size; ++diagonal) {
        // Row-major position of element (diagonal, diagonal).
        values[diagonal * size + diagonal] = static_cast<T>(1);
    }
    return Matrix<T>(size, size, values.data());
}
/**
 * @brief Describes the shape and the value bounds of a random matrix.
 *
 * Consumed by RandomMatrix(), which rejects a zero row or column count and
 * passes `from` and `to` to the random distribution as its bounds.
 *
 * Every member has a default initialiser, so a default-constructed Range
 * never carries indeterminate values.
 *
 * @tparam T Arithmetic type of the bounds (enforced at compile time).
 */
template<typename T>
struct Range {
    static_assert(std::is_arithmetic<T>::value, "T must be numeric");
    size_t rows{0};     // number of matrix rows
    size_t columns{0};  // number of matrix columns
    T from{};           // lower bound of the generated values
    T to{};             // upper bound of the generated values
};
/**
 * @brief Builds a matrix whose entries are uniformly distributed random values.
 *
 * Integral element types are drawn with std::uniform_int_distribution over
 * the closed interval [range.from, range.to]; floating-point element types
 * are drawn with std::uniform_real_distribution<T> over the half-open
 * interval [range.from, range.to). The engine is a std::mt19937 seeded from
 * std::random_device on every call.
 *
 * @tparam T     Arithmetic element type (enforced at compile time).
 * @param range  Matrix dimensions and value bounds.
 * @return A range.rows x range.columns matrix with every entry assigned a
 *         freshly drawn value.
 * @throws std::domain_error if range.rows or range.columns is zero, or if
 *         range.from <= range.to does not hold.
 */
template <typename T>
Matrix<T> RandomMatrix(const Range<T> &range) {
    static_assert(std::is_arithmetic<T>::value, "T must be numeric");
    if(range.rows == 0 || range.columns == 0) {
        throw std::domain_error("Invalid defined matrix.");
    }
    // Both standard distributions require from <= to. The negated comparison
    // also rejects NaN bounds, for which every ordering test is false.
    if(!(range.from <= range.to)) {
        throw std::domain_error("Invalid defined range.");
    }
    // std::uniform_int_distribution is only specified for the standard
    // short/int/long/long long families, so narrower integral types
    // (bool, char, ...) are drawn through the widest type of matching
    // signedness and cast back afterwards.
    using WideInteger = typename std::conditional<
        std::is_signed<T>::value, long long, unsigned long long>::type;
    // Only selected for floating-point T; double is a placeholder that keeps
    // the alias well-formed when T is integral.
    using Real = typename std::conditional<
        std::is_floating_point<T>::value, T, double>::type;
    using Distribution = typename std::conditional<
        std::is_integral<T>::value,
        std::uniform_int_distribution<WideInteger>,
        std::uniform_real_distribution<Real>>::type;

    Matrix<T> matrix(range.rows, range.columns);
    std::mt19937 engine(std::random_device{}());
    Distribution distribution(range.from, range.to);
    for(size_t row = 0; row < range.rows; ++row) {
        for(size_t column = 0; column < range.columns; ++column) {
            matrix(row, column) = static_cast<T>(distribution(engine));
        }
    }
    return matrix;
}
}
TEST_SUITE("MatrixFactory test suite") {
TEST_CASE ("Identity Matrix") {
    // Reference values for a 4x4 identity, listed row by row:
    // |1 0 0 0|
    // |0 1 0 0|
    // |0 0 1 0|
    // |0 0 0 1|
    std::array<int, 16> referenceValues{1, 0, 0, 0,
                                        0, 1, 0, 0,
                                        0, 0, 1, 0,
                                        0, 0, 0, 1};
    Matrix<int> expected{4, 4, referenceValues.data()};
    Matrix<int> identity = MatrixFactory::IdentityMatrix<int>(4);
    CHECK(TestUtils::CompareMatrix(identity, expected));
}
TEST_CASE ("Random Matrix") {
    // Every entry of a 4x4 matrix requested with bounds 1 and 5 must lie
    // between those bounds.
    Matrix<double> A = MatrixFactory::RandomMatrix<double>({4, 4, 1, 5});
    for(size_t r = 0; r < A.rows(); ++r) {
        for(size_t c = 0; c < A.columns(); ++c) {
            const bool notBelowLower = A(r, c) >= 1;
            const bool isValid = notBelowLower && A(r, c) <= 5;
            CHECK(isValid);
        }
    }
}
}
#endif //LINALG_MATRIXFACTORY_H