Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
darioizzo committed Oct 31, 2023
1 parent 1d7f851 commit 0effd03
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 8 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ set(HEYOKA_SRC_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/mascon.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/vsop2013.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/cr3bp.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/ffnn.cpp"
# Math functions.
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/kepE.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/kepF.cpp"
Expand Down
1 change: 1 addition & 0 deletions include/heyoka/model/ffnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace model
// The expression will contain the weights and biases of the neural network flattened into `pars` with the following conventions:
//
// from the left to right layer of parameters: [flattened weights1, flattened weights2, ... , biases1, bises2, ...]
//
HEYOKA_DLL_PUBLIC std::vector<expression>
ffnn_impl(const std::vector<expression> &, std::uint32_t, const std::vector<std::uint32_t> &,
const std::vector<std::function<expression(const expression &)>> &,
Expand Down
55 changes: 47 additions & 8 deletions src/model/ffnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

#include <fmt/core.h>
#include <fmt/ranges.h>

#include <heyoka/config.hpp>
#include <heyoka/expression.hpp>
#include <heyoka/model/ffnn.hpp>
Expand Down Expand Up @@ -65,8 +70,13 @@ std::vector<expression> compute_layer(std::uint32_t layer_id, const std::vector<
auto n_neurons_curr_layer = n_neurons[layer_id];

std::vector<expression> retval(n_neurons_curr_layer, 0_dbl);
fmt::print("nneurons: {}", n_neurons_prev_layer);
std::cout << std::endl;

for (std::uint32_t i = 0u; i < n_neurons_curr_layer; ++i) {
for (std::uint32_t j = 0u; j < n_neurons_prev_layer; ++j) {
fmt::print("layer, i, j: {}, {}, {}", layer_id, i, j);
std::cout << std::endl;
retval[i] += net_wb[flattenw(i, j, n_neurons, layer_id)] * inputs[j];
}
retval[i] += net_wb[flattenb(i, n_neurons, layer_id, n_net_w)];
Expand All @@ -76,38 +86,67 @@ std::vector<expression> compute_layer(std::uint32_t layer_id, const std::vector<
}
} // namespace detail

HEYOKA_DLL_PUBLIC std::vector<expression> ffnn_impl(
std::vector<expression> ffnn_impl(
// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
const std::vector<expression> &in, std::uint32_t n_out,
const std::vector<std::uint32_t> &n_neurons_per_hidden_layer,
const std::vector<std::function<expression(const expression &)>> &activations,
const std::vector<expression> &net_wb)
{
// Sanity check (should be a throw check?)
assert(n_neurons_per_hidden_layer.size() + 1 == activations.size());
// Sanity checks
if (n_neurons_per_hidden_layer.size() + 1 != activations.size()) {
throw std::invalid_argument(fmt::format(
"The number of hidden layers, as detected from the inputs, was {}, while"
"the number of activation function supplied was {}. A FFNN needs exactly one more activation function "
"than the number of hidden layers.",
n_neurons_per_hidden_layer.size(), activations.size()));
}
if (in.empty()) {
throw std::invalid_argument("The inputs provided to the ffnn seem to be an empty vector.");
}
if (n_out == 0) {
throw std::invalid_argument("The number of network outputs cannot be zero.");
}
if (!std::all_of(n_neurons_per_hidden_layer.begin(), n_neurons_per_hidden_layer.end(),
[](std::uint32_t item) { return item > 0; })) {
throw std::invalid_argument("The number of neurons for each hidden layer must be greater than zero!");
}
if (n_neurons_per_hidden_layer.empty()) { // TODO(darioizzo): maybe this is actually a wanted corner case, remove?
throw std::invalid_argument("The number of hidden layers cannot be zero.");
}

// Number of hidden layers (defined as all neuronal columns that are nor input nor output neurons)
auto n_hidden_layers = boost::numeric_cast<std::uint32_t>(n_neurons_per_hidden_layer.size());
// Number of neuronal layers (counting input and output)
auto n_layers = n_hidden_layers + 2;
// Number o
// Number of inputs
auto n_in = boost::numeric_cast<std::uint32_t>(in.size());
// Number of neurons per neuronal layer
std::vector<std::uint32_t> n_neurons = n_neurons_per_hidden_layer;
n_neurons.insert(n_neurons.begin(), n_in);
n_neurons.insert(n_neurons.end(), n_out);
// Number of network parameters
// Number of network parameters (wb: weights and biases, w: only weights)
std::uint32_t n_net_wb = 0u;
std::uint32_t n_net_w = 0u;
for (std::uint32_t i = 1u; i < n_layers; ++i) {
n_net_wb += n_neurons[i - 1] * n_neurons[i];
n_net_w += n_neurons[i - 1] * n_neurons[i];
n_net_wb += n_neurons[i];
}
// Sanity check (should be a throw check?)
assert(net_wb.size() == n_net_wb);
std::vector<expression> retval{};
// Sanity check
if (net_wb.size() != n_net_wb) {
throw std::invalid_argument(fmt::format(
"The number of network parameters, detected from its structure to be {}, does not match the size of"
"the corresponding expressions {} ",
n_net_wb, net_wb.size()));
}

// Now we build the expressions recursively going from layer to layer (L = f(Wx+b)))

std::vector<expression> retval = in;
for (std::uint32_t i = 1u; i < n_layers; ++i) {
fmt::print("{},{}", i, n_neurons[i]);
std::cout << std::endl;
retval = detail::compute_layer(i, retval, n_neurons, activations[i], net_wb, n_net_w);
}
return retval;
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ ADD_HEYOKA_TESTCASE(model_fixed_centres)
ADD_HEYOKA_TESTCASE(model_rotating)
ADD_HEYOKA_TESTCASE(model_mascon)
ADD_HEYOKA_TESTCASE(model_cr3bp)
ADD_HEYOKA_TESTCASE(model_ffnn)
ADD_HEYOKA_TESTCASE(step_callback)
ADD_HEYOKA_TESTCASE(llvm_state_mem_cache)

Expand Down

0 comments on commit 0effd03

Please sign in to comment.