Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Dario Izzo committed Oct 31, 2023
1 parent 67b8b44 commit 79ac7f9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
20 changes: 16 additions & 4 deletions src/model/ffnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,30 @@ std::vector<expression> compute_layer(std::uint32_t layer_id, const std::vector<
auto n_neurons_curr_layer = n_neurons[layer_id];

std::vector<expression> retval(n_neurons_curr_layer, 0_dbl);
fmt::print("nneurons: {}", n_neurons_prev_layer);
fmt::print("net_wb: {}\n", net_wb.size());
std::cout << std::endl;

for (std::uint32_t i = 0u; i < n_neurons_curr_layer; ++i) {
for (std::uint32_t j = 0u; j < n_neurons_prev_layer; ++j) {
fmt::print("layer, i, j: {}, {}, {}", layer_id, i, j);
fmt::print("layer, i, j, idx: {}, {}, {}\n", layer_id, i, j, flattenw(i, j, n_neurons, layer_id));
std::cout << std::endl;
retval[i] += net_wb[flattenw(i, j, n_neurons, layer_id)] * inputs[j];
retval[i] += 1_dbl;//net_wb[flattenw(i, j, n_neurons, layer_id)] * inputs[j];
}
retval[i] += net_wb[flattenb(i, n_neurons, layer_id, n_net_w)];
fmt::print("idxb {}\n", flattenb(i, n_neurons, layer_id, n_net_w));
std::cout << std::endl;

retval[i]+= 1_dbl; //net_wb[flattenb(i, n_neurons, layer_id, n_net_w)];

fmt::print("\n{}\n", retval[i]);
fmt::print("Here1");

std::cout << std::endl;
retval[i] = activation(retval[i]);
fmt::print("Here2");
std::cout << std::endl;
}
fmt::print("Here3");

return retval;
}
} // namespace detail
Expand Down
3 changes: 2 additions & 1 deletion test/model_ffnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ using namespace heyoka;

TEST_CASE("impl")
{
auto linear = [](expression ret) -> expression { return ret; };
auto [x] = make_vars("x");
auto my_net = model::ffnn_impl({x}, 2, {2, 2}, {heyoka::tanh, heyoka::tanh, [](auto ret) { return ret; }},
auto my_net = model::ffnn_impl({x}, 2, {2, 2}, {heyoka::tanh, heyoka::tanh, heyoka::tanh},
{1_dbl, 2_dbl, 3_dbl, 4_dbl, 5_dbl, 6_dbl, 7_dbl, 8_dbl, 9_dbl, 0_dbl, 1_dbl, 2_dbl,
3_dbl, 4_dbl, 5_dbl, 6_dbl});
}

0 comments on commit 79ac7f9

Please sign in to comment.