From 008e912e2fa823bd0adc9601ce97987d5f36554d Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Mon, 1 Nov 2021 10:11:54 -0700
Subject: [PATCH 1/3] added template for data reader to pass conduit node from
 driver

added conduit to CMakeLists
fixed error with global_trainer_
added simple conduit datareader to hold conduit node
prototyping use of data_store to hold conduit nodes
fixing bug with input buffers not being sized correctly
fixed problem with unpacking conduit node
moving {trainer, dc, dr, ds setup} and {loading inference samples} to separate functions
extended core API for many different input types
removed old code from first lbann-core impl
added simple run script
Fix things that have drifted in LBANN
Get core-driver compiling again
clang-format batch_functional_inference_algorithm
Steps toward debugging the segfault in the inference algo test
The test no longer segfaults. Now it just fails.
Don't shuffle when setting up for inference
Fix a spacing issue
Updated CMake to install the core driver
Build the core-driver
---
 CMakeLists.txt                                 |   1 +
 cmake/configure_files/LBANNConfig.cmake.in     |   2 +-
 core-driver/CMakeLists.txt                     |  21 ++-
 core-driver/main.cpp                           | 107 +++++++++--
 core-driver/run.sh                             |  10 ++
 .../data_ingestion/readers/CMakeLists.txt      |   1 +
 .../readers/data_reader_conduit.hpp            |  72 ++++++++
 .../batch_functional_inference_algorithm.hpp   | 102 +----------
 include/lbann/layers/io/input_layer.hpp        |   9 -
 include/lbann/utils/lbann_library.hpp          |  79 ++++++--
 src/data_ingestion/readers/CMakeLists.txt      |   1 +
 .../readers/data_reader_conduit.cpp            |  53 ++++++
 src/execution_algorithms/CMakeLists.txt        |   1 +
 .../batch_functional_inference_algorithm.cpp   |  90 ++++++++++
 .../unit_test/inference_algorithm_test.cpp     |  64 +++++--
 src/layers/io/input_layer.cpp                  |  24 ++-
 src/models/unit_test/model_test.cpp            |  12 +-
 src/trainers/trainer.cpp                       |   6 +-
 src/utils/lbann_library.cpp                    | 170 ++++++++++++++----
 19 files changed, 607 insertions(+), 218 deletions(-)
 create mode 100755 core-driver/run.sh
 create mode 100644 include/lbann/data_ingestion/readers/data_reader_conduit.hpp
 create mode 100644 src/data_ingestion/readers/data_reader_conduit.cpp
 create mode 100644 src/execution_algorithms/batch_functional_inference_algorithm.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3030cde812d..1252a8dc7a0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -920,6 +920,7 @@ add_subdirectory(applications/CANDLE/pilot2/tools)
 add_subdirectory(applications/ATOM/utils)
 add_subdirectory(tests)
 add_subdirectory(scripts)
+add_subdirectory(core-driver)
 
 ################################################################
 # Install LBANN
diff --git a/cmake/configure_files/LBANNConfig.cmake.in b/cmake/configure_files/LBANNConfig.cmake.in
index 1eb95a448af..ec2e95c7b30 100644
--- a/cmake/configure_files/LBANNConfig.cmake.in
+++ b/cmake/configure_files/LBANNConfig.cmake.in
@@ -74,7 +74,7 @@ set(LBANN_HAS_DIHYDROGEN @LBANN_HAS_DIHYDROGEN@)
 set(LBANN_HAS_DISTCONV @LBANN_HAS_DISTCONV@)
 set(LBANN_HAS_DOXYGEN @LBANN_HAS_DOXYGEN@)
 set(LBANN_HAS_EMBEDDED_PYTHON @LBANN_HAS_EMBEDDED_PYTHON@)
-set(LBANN_HAS_FFTW @LBANN_HAS_FFTW@
+set(LBANN_HAS_FFTW @LBANN_HAS_FFTW@)
 set(LBANN_HAS_FFTW_FLOAT @LBANN_HAS_FFTW_FLOAT@)
 set(LBANN_HAS_FFTW_DOUBLE @LBANN_HAS_FFTW_DOUBLE@)
 set(LBANN_HAS_GPU_FP16 @LBANN_HAS_GPU_FP16@)
diff --git a/core-driver/CMakeLists.txt b/core-driver/CMakeLists.txt
index f960f0e9fae..82a8f6a56dd 100644
--- a/core-driver/CMakeLists.txt
+++ b/core-driver/CMakeLists.txt
@@ -1,5 +1,18 @@
-cmake_minimum_required(VERSION 3.18.0)
-project(my_lbann_test
-             C CXX)
+cmake_minimum_required(VERSION 3.21.0)
+project(my_lbann_test CXX)
 find_package(LBANN 0.102.0 REQUIRED)
-add_executable(Main main.cpp)
-target_link_libraries(Main PRIVATE LBANN::lbann)
+find_package(Conduit CONFIG REQUIRED)
+add_executable(lbann-core main.cpp)
+target_link_libraries(lbann-core PRIVATE LBANN::lbann)
+
+#target_link_libraries(lbann-bin lbann)
+set_target_properties(lbann-core
+  PROPERTIES
+  OUTPUT_NAME lbann-core-driver
+  RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+#list(APPEND LBANN_EXE_TGTS lbann-core)
+
+install(TARGETS lbann-core
+  EXPORT LBANNTargets
+  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
diff --git a/core-driver/main.cpp b/core-driver/main.cpp
index b419dc79ada..a4a2cd7a104 100644
--- a/core-driver/main.cpp
+++ b/core-driver/main.cpp
@@ -29,8 +29,11 @@
 #include
 #include
 
+// Add test-specific options
 void construct_opts(int argc, char **argv) {
   auto& arg_parser = lbann::global_argument_parser();
+  lbann::construct_std_options();
+  lbann::construct_datastore_options();
   arg_parser.add_option("samples",
                         {"-n"},
                         "Number of samples to run inference on",
@@ -52,20 +55,76 @@
                         "Number of labels in dataset",
                         10);
   arg_parser.add_option("minibatchsize",
-                        {"-mbs"},
+                        {"--mbs"},
                         "Number of samples in a mini-batch",
                         16);
+  arg_parser.add_flag("use_conduit",
+                      {"--conduit"},
+                      "Use Conduit node samples (default is a non-distributed matrix)");
+  arg_parser.add_flag("use_dist_matrix",
+                      {"--dist"},
+                      "Use a Hydrogen distributed matrix (default is a non-distributed matrix)");
   arg_parser.add_required_argument<std::string>
     ("model", "Directory containing checkpointed model");
   arg_parser.parse(argc, argv);
 }
 
-El::DistMatrix<float, El::STAR, El::STAR>
-random_samples(El::Grid const& g, int n, int c, int h, int w) {
+// Generates random samples and labels for MNIST-shaped data in a Hydrogen matrix
+std::map<
+  std::string,
+  El::Matrix<float>>
+mat_mnist_samples(int n, int c, int h, int w)
+{
+  El::Matrix<float>
+    samples(c * h * w, n);
+  El::MakeUniform(samples);
+  El::Matrix<float>
+    labels(1, n);
+  El::MakeUniform(labels);
+  std::map<
+    std::string,
+    El::Matrix<float>>
+    samples_map = {{"data/samples", samples}, {"data/labels", labels}};
+  return samples_map;
+}
+
+// Generates random samples and labels for MNIST-shaped data in a Hydrogen
+// distributed matrix
+std::map<
+  std::string,
+  El::DistMatrix<float, El::STAR, El::STAR>>
+distmat_mnist_samples(El::Grid const& g, int n, int c, int h, int w)
+{
   El::DistMatrix<float, El::STAR, El::STAR>
-    samples(n, c * h * w, g);
+    samples(c * h * w, n, g);
   El::MakeUniform(samples);
-  return samples;
-}
+  El::DistMatrix<float, El::STAR, El::STAR>
+    labels(1, n, g);
+  El::MakeUniform(labels);
+  std::map<
+    std::string,
+    El::DistMatrix<float, El::STAR, El::STAR>>
+    samples_map = {{"data/samples", samples}, {"data/labels", labels}};
+  return samples_map;
+}
+
+// Fills an array with random values in [0, 1)
+void random_fill(float *arr, int size, int max_val=255) {
+  for (int i = 0; i < size; i++) {
+    arr[i] = (float)(std::rand() % max_val) / (float)max_val;
+  }
+}
+
+// Generates random samples and labels for MNIST-shaped data in a vector of
+// Conduit nodes, one node per sample
+std::vector<conduit::Node> conduit_mnist_samples(int n, int c, int h, int w) {
+  std::vector<conduit::Node> samples(n);
+  int sample_size = c * h * w;
+  float this_sample[sample_size];
+  float this_label;
+  for (int i = 0; i < n; i++) {
+    random_fill(this_sample, sample_size);
+    samples[i]["data/samples"].set(this_sample, sample_size);
+    random_fill(&this_label, 1);
+    samples[i]["data/labels"].set(&this_label, 1);
+  }
+  return samples;
+}
 
 int main(int argc, char **argv) {
   construct_opts(argc, argv);
   auto& arg_parser = lbann::global_argument_parser();
 
+  if (arg_parser.get<bool>("use_conduit") && arg_parser.get<bool>("use_dist_matrix")) {
+    LBANN_ERROR("Cannot use conduit node and distributed matrix together, choose one: --conduit --dist");
+  }
   std::stringstream msg;
   msg << "Model: " << arg_parser.get<std::string>("model") << std::endl;
   msg << "{ N, c, h, w } = { " << arg_parser.get<int>("samples") << ", ";
@@ -94,8 +156,8 @@ int main(int argc, char **argv) {
     std::cout << msg.str();
   }
 
-  // Load model and run inference on samples
   auto lbann_comm = lbann::initialize_lbann(MPI_COMM_WORLD);
+
   auto m = lbann::load_inference_model(lbann_comm.get(),
                                        arg_parser.get<std::string>("model"),
                                        arg_parser.get<int>("minibatchsize"),
@@ -105,14 +167,31 @@ int main(int argc, char **argv) {
                                         arg_parser.get<int>("height"),
                                         arg_parser.get<int>("width")
                                        },
                                        {arg_parser.get<int>("labels")});
-  auto samples = random_samples(lbann_comm->get_trainer_grid(),
-                                arg_parser.get<int>("samples"),
-                                arg_parser.get<int>("channels"),
-                                arg_parser.get<int>("height"),
-                                arg_parser.get<int>("width"));
-  auto labels = lbann::infer(m.get(),
-                             samples,
-                             arg_parser.get<int>("minibatchsize"));
+
+  // Three options for sample generation
+  if (arg_parser.get<bool>("use_conduit")) {
+    auto samples = conduit_mnist_samples(arg_parser.get<int>("samples"),
+                                         arg_parser.get<int>("channels"),
+                                         arg_parser.get<int>("height"),
+                                         arg_parser.get<int>("width"));
+    lbann::set_inference_samples(samples);
+  } else if (arg_parser.get<bool>("use_dist_matrix")) {
+    auto samples = distmat_mnist_samples(lbann_comm->get_trainer_grid(),
+                                         arg_parser.get<int>("samples"),
+                                         arg_parser.get<int>("channels"),
+                                         arg_parser.get<int>("height"),
+                                         arg_parser.get<int>("width"));
+    lbann::set_inference_samples(samples);
+  } else {
+    auto samples = mat_mnist_samples(
+      arg_parser.get<int>("samples"),
+      arg_parser.get<int>("channels"),
+      arg_parser.get<int>("height"),
+      arg_parser.get<int>("width"));
+    lbann::set_inference_samples(samples);
+  }
+
+  auto labels = lbann::inference(m.get());
 
   // Print inference results
   if (lbann_comm->am_world_master()) {
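For quick reference, the driver's calling sequence can be distilled into the sketch below. This is illustrative only: the field names ("data/samples", "data/labels") and MNIST-shaped dims follow the code above, the model path is a placeholder (see run.sh next), initialize_lbann is assumed to handle MPI setup as it does here, and error handling and finalization are omitted.

    #include <mpi.h>

    #include <conduit/conduit.hpp>

    #include <vector>

    #include "lbann/utils/lbann_library.hpp"

    int main(int argc, char** argv) {
      auto lbann_comm = lbann::initialize_lbann(MPI_COMM_WORLD);

      // Restore the checkpointed model and stand up the inference-only
      // trainer/data-coordinator environment.
      auto m = lbann::load_inference_model(lbann_comm.get(),
                                           "path/to/checkpointed/model",
                                           /*mbs=*/16,
                                           /*input_dims=*/{1, 28, 28},
                                           /*output_dims=*/{10});

      // Stage 16 blank MNIST-shaped samples: one conduit::Node per sample.
      std::vector<conduit::Node> samples(16);
      std::vector<float> image(1 * 28 * 28, 0.f);
      float label = 0.f;
      for (auto& s : samples) {
        s["data/samples"].set(image.data(), image.size());
        s["data/labels"].set(&label, 1);
      }
      lbann::set_inference_samples(samples);

      // Run forward prop; one predicted class index per sample comes back.
      El::Matrix<int> labels = lbann::inference(m.get());
      return 0;
    }
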
diff --git a/core-driver/run.sh b/core-driver/run.sh
new file mode 100755
index 00000000000..d5cdd9f1577
--- /dev/null
+++ b/core-driver/run.sh
@@ -0,0 +1,10 @@
+export AL_PROGRESS_RANKS_PER_NUMA_NODE=2
+export OMP_NUM_THREADS=8
+export MV2_USE_RDMA_CM=0
+
+# This should be a checkpointed LeNet model
+MODEL_LOC="path/to/checkpointed/model"
+
+./lbann-core-driver "$MODEL_LOC"
+./lbann-core-driver "$MODEL_LOC" --dist
+./lbann-core-driver "$MODEL_LOC" --conduit
diff --git a/include/lbann/data_ingestion/readers/CMakeLists.txt b/include/lbann/data_ingestion/readers/CMakeLists.txt
index 0de2d3b52a3..f6b5538af84 100644
--- a/include/lbann/data_ingestion/readers/CMakeLists.txt
+++ b/include/lbann/data_ingestion/readers/CMakeLists.txt
@@ -29,6 +29,7 @@ set_full_path(THIS_DIR_HEADERS
   metadata.hpp
   # Data readers
   data_reader_cifar10.hpp
+  data_reader_conduit.hpp
   data_reader_csv.hpp
   data_reader_image.hpp
   data_reader_HDF5.hpp
diff --git a/include/lbann/data_ingestion/readers/data_reader_conduit.hpp b/include/lbann/data_ingestion/readers/data_reader_conduit.hpp
new file mode 100644
index 00000000000..5103ec40cbb
--- /dev/null
+++ b/include/lbann/data_ingestion/readers/data_reader_conduit.hpp
@@ -0,0 +1,72 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2021, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "Licensee"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the license.
+////////////////////////////////////////////////////////////////////////////////
+
+#ifndef LBANN_DATA_READER_CONDUIT_HPP
+#define LBANN_DATA_READER_CONDUIT_HPP
+
+#include "lbann/data_ingestion/data_reader.hpp"
+#include "lbann/data_store/data_store_conduit.hpp"
+
+namespace lbann {
+/**
+ * A generalized data reader for passed-in Conduit nodes.
+ */
+class conduit_data_reader : public generic_data_reader
+{
+public:
+  conduit_data_reader* copy() const override { return new conduit_data_reader(*this); }
+  bool has_conduit_output() override { return true; }
+  void load() override;
+  bool fetch_conduit_node(conduit::Node& sample, int data_id) override;
+
+  void set_data_dims(std::vector<int> dims);
+  void set_label_dims(std::vector<int> dims);
+
+  std::string get_type() const override { return "conduit_data_reader"; }
+  int get_linearized_data_size() const override {
+    int data_size = 1;
+    for(int i : m_data_dims) {
+      data_size *= i;
+    }
+    return data_size;
+  }
+  int get_linearized_label_size() const override {
+    int label_size = 1;
+    for(int i : m_label_dims) {
+      label_size *= i;
+    }
+    return label_size;
+  }
+
+protected:
+  std::vector<int> m_data_dims;
+  std::vector<int> m_label_dims;
+
+}; // END: class conduit_data_reader
+
+} // namespace lbann
+
+#endif // LBANN_DATA_READER_CONDUIT_HPP
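The per-sample node layout this reader consumes is small enough to sketch. The field names below follow the core driver and tests in this patch, and conduit's Node::set copies its input buffers:

    #include <conduit/conduit.hpp>

    #include <vector>

    // One sample = one conduit::Node holding a flattened data field and a
    // label field.
    conduit::Node make_mnist_sample(std::vector<float> const& image, float label) {
      conduit::Node sample;
      sample["data/samples"].set(image.data(), image.size()); // copies c*h*w floats
      sample["data/labels"].set(&label, 1);                   // copies the label
      return sample;
    }
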
diff --git a/include/lbann/execution_algorithms/batch_functional_inference_algorithm.hpp b/include/lbann/execution_algorithms/batch_functional_inference_algorithm.hpp
index 01556aa669b..806589b04ce 100644
--- a/include/lbann/execution_algorithms/batch_functional_inference_algorithm.hpp
+++ b/include/lbann/execution_algorithms/batch_functional_inference_algorithm.hpp
@@ -40,8 +40,7 @@ namespace lbann {
  *
  * This execution algorithm is meant for running inference using a trained
  * model and samples passed by the user from an external application. The
- * algorithm currently assumes that there is only 1 input layer in the model,
- * and the output layer is a softmax layer.
+ * algorithm currently assumes that the output layer is a softmax layer.
  */
 class batch_functional_inference_algorithm
 {
@@ -73,111 +72,16 @@ class batch_functional_inference_algorithm
 
   /** @brief Run model inference on samples and return predicted categories.
    * @param[in] model A trained model
-   * @param[in] samples A distributed matrix containing samples for model input
-   * @param[in] mbs The max mini-batch size
    * @return Matrix of predicted labels (by index)
    */
-  template <typename DataT, El::Dist CDist, El::Dist RDist, El::DistWrap DistView, El::Device D>
-  El::Matrix<int>
-  infer(observer_ptr<model> model,
-        El::DistMatrix<DataT, CDist, RDist, DistView, D> const& samples,
-        size_t mbs)
-  {
-    if (mbs <= 0) {
-      LBANN_ERROR("mini-batch size must be larger than 0");
-    }
-
-    // Make matrix for returning predicted labels
-    size_t samples_size = samples.Height();
-    El::Matrix<int> labels(samples_size, 1);
-
-    // BVE FIXME
-    // Create an SGD_execution_context so that layer.forward_prop can get the
-    // mini_batch_size - This should be fixed in the future, when SGD is not so
-    // hard-coded into the model & layers
-    auto c = SGDExecutionContext(execution_mode::inference);
-    model->reset_mode(c, execution_mode::inference);
-    // Explicitly set the size of the mini-batch that the model is executing
-    model->set_current_mini_batch_size(mbs);
-
-    // Infer on mini batches
-    for (size_t i = 0; i < samples_size; i += mbs) {
-      size_t mb_idx = std::min(i + mbs, samples_size);
-      auto mb_range = El::IR(i, mb_idx);
-      auto mb_samples = El::LockedView(samples, mb_range, El::ALL);
-      auto mb_labels = El::View(labels, mb_range, El::ALL);
-
-      infer_mini_batch(*model, mb_samples);
-      get_labels(*model, mb_labels);
-    }
-
-    return labels;
-  }
+  El::Matrix<int> infer(observer_ptr<model> model);
 
 protected:
-  /** @brief Run model inference on a single mini-batch of samples
-   * This method takes a mini-batch of samples, inserts them into the input
-   * layer of the model, and runs forward prop on the model.
-   * @param[in] model A trained model
-   * @param[in] samples A distributed matrix containing samples for model input
-   */
-  template <typename DataT, El::Dist CDist, El::Dist RDist, El::DistWrap DistView, El::Device D>
-  void infer_mini_batch(
-    model& model,
-    El::DistMatrix<DataT, CDist, RDist, DistView, D> const& samples)
-  {
-    for (int i = 0; i < model.get_num_layers(); i++) {
-      auto& l = model.get_layer(i);
-      // Insert samples into the input layer
-      if (l.get_type() == "input") {
-        auto& il = dynamic_cast<input_layer<DataT>&>(l);
-        il.set_samples(samples);
-      }
-    }
-    model.forward_prop(execution_mode::inference);
-  }
-
   /** @brief Finds the predicted category in a model's softmax layer
    * @param[in] model A model that has been used for inference
    * @param[in] labels A matrix to place predicted category labels
    */
-  void get_labels(model& model, El::Matrix<int>& labels)
-  {
-    int pred_label = 0;
-    float max, col_value;
-
-    for (const auto* l : model.get_layers()) {
-      // Find the output layer
-      if (l->get_type() == "softmax") {
-        auto const& dtl =
-          dynamic_cast<data_type_layer<float> const&>(*l);
-        const auto& outputs = dtl.get_activations();
-
-        // Find the prediction for each sample
-        int col_count = outputs.Width();
-        int row_count = outputs.Height();
-        for (int i = 0; i < col_count; i++) {
-          max = 0;
-          for (int j = 0; j < row_count; j++) {
-            col_value = outputs.Get(i, j);
-            if (col_value > max) {
-              max = col_value;
-              pred_label = j;
-            }
-          }
-          labels(i) = pred_label;
-        }
-      }
-    }
-  }
+  void get_labels(model& model, El::Matrix<int>& labels);
 };
 
 } // namespace lbann
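With the matrix plumbing removed from the header, a call site reduces to the sketch below. This assumes samples were already staged via set_inference_samples, and it mirrors the lbann::inference() helper added to src/utils/lbann_library.cpp later in this patch:

    #include "lbann/execution_algorithms/batch_functional_inference_algorithm.hpp"

    // Thin wrapper around the refactored algorithm.
    El::Matrix<int> predict(lbann::observer_ptr<lbann::model> model) {
      auto inf_alg = lbann::batch_functional_inference_algorithm();
      return inf_alg.infer(model);
    }
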
diff --git a/include/lbann/layers/io/input_layer.hpp b/include/lbann/layers/io/input_layer.hpp
index 02d322ef6f5..7af8b1e4829 100644
--- a/include/lbann/layers/io/input_layer.hpp
+++ b/include/lbann/layers/io/input_layer.hpp
@@ -151,11 +151,6 @@ class input_layer : public data_type_layer<TensorDataType>
 
   void fp_compute() override;
 
-  /** @brief Places samples in input tensors
-   * @param samples Distributed Matrix of samples
-   */
-  void set_samples(const El::AbstractDistMatrix<TensorDataType>& samples);
-
   /**
    * Get the dimensions of the underlying data.
   */
@@ -178,10 +173,6 @@ class input_layer : public data_type_layer<TensorDataType>
   friend cereal::access;
   input_layer() : input_layer(nullptr) {}
 
-  // This is to track if samples are loaded with set_samples(), if so the
-  // fp_compute() sample loading is no longer necessary
-  bool m_samples_loaded = false;
-
   data_field_type m_data_field;
 
 #ifdef LBANN_HAS_DISTCONV
diff --git a/include/lbann/utils/lbann_library.hpp b/include/lbann/utils/lbann_library.hpp
index d49fdc3ac73..74b0bac3ed0 100644
--- a/include/lbann/utils/lbann_library.hpp
+++ b/include/lbann/utils/lbann_library.hpp
@@ -31,10 +31,70 @@
 #include "lbann/models/model.hpp"
 #include "lbann/proto/proto_common.hpp"
 
+#include <conduit/conduit.hpp>
+
 namespace lbann {
 
 const int lbann_default_random_seed = 42;
 
+/** @brief Places conduit samples into data store for inference
+ * @param[in] samples vector of Conduit nodes, 1 node per sample
+ */
+void set_inference_samples(std::vector<conduit::Node>& samples);
+
+/** @brief Places samples into data store for inference
+ * @param[in] samples_map map of data-field name to distributed sample matrix
+ */
+template <typename DataT, El::Dist CDist, El::Dist RDist, El::DistWrap DistView, El::Device D>
+void set_inference_samples(
+  const std::map<std::string, El::DistMatrix<DataT, CDist, RDist, DistView, D>>&
+    samples_map)
+{
+  size_t const sample_n = samples_map.begin()->second.Width();
+  std::vector<conduit::Node> conduit_samples(sample_n);
+  for (const auto& kv : samples_map) {
+    El::DistMatrixReadProxy<DataT, DataT, El::STAR, El::STAR>
+      samples_proxy(kv.second);
+    auto const& samples =
+      samples_proxy.GetLocked(); //< DistMatrix<DataT, STAR, STAR> const&
+    size_t const sample_size = samples.Height();
+
+    for (size_t i = 0; i < sample_n; i++) {
+      DataT const* data =
+        samples.LockedBuffer() + i * samples.LDim(); // 1 column = 1 sample
+      conduit_samples[i][kv.first].set(data, sample_size);
+    }
+  }
+  set_inference_samples(conduit_samples);
+}
+
+/** @brief Places samples into data store for inference
+ * @param[in] samples_map map of data-field name to local sample matrix
+ */
+template <typename DataT>
+void set_inference_samples(
+  const std::map<std::string, El::Matrix<DataT>>& samples_map)
+{
+  size_t const sample_n = samples_map.begin()->second.Width();
+  std::vector<conduit::Node> conduit_samples(sample_n);
+  for (const auto& kv : samples_map) {
+    auto const& samples = kv.second;
+    size_t const sample_size = samples.Height();
+
+    for (size_t i = 0; i < sample_n; i++) {
+      DataT const* data =
+        samples.LockedBuffer() + i * samples.LDim(); // 1 column = 1 sample
+      conduit_samples[i][kv.first].set(data, sample_size);
+    }
+  }
+  set_inference_samples(conduit_samples);
+}
+
 /** @brief Loads a trained model from checkpoint for inference only
  * @param[in] lc An LBANN Communicator
  * @param[in] cp_dir The model checkpoint directory
@@ -49,25 +109,12 @@ std::unique_ptr<model> load_inference_model(lbann_comm* lc,
                                             std::string cp_dir,
                                             int mbs,
                                             std::vector<int> input_dims,
                                             std::vector<int> output_dims);
 
-/** @brief Creates execution algorithm and infers on samples using a model
+/** @brief Creates execution algorithm and runs inference on model
  * @param[in] model A trained model
- * @param[in] samples A distributed matrix containing samples for model input
- * @param[in] mbs The max mini-batch size
  * @return Matrix of predicted labels
  */
-template <typename DataT, El::Dist CDist, El::Dist RDist, El::DistWrap DistView, El::Device D>
-El::Matrix<int>
-infer(observer_ptr<model> model,
-      El::DistMatrix<DataT, CDist, RDist, DistView, D> const& samples,
-      size_t mbs)
-{
-  auto inf_alg = batch_functional_inference_algorithm();
-  return inf_alg.infer(model, samples, mbs);
-}
+El::Matrix<int>
+inference(observer_ptr<model> model);
 
 int allocate_trainer_resources(lbann_comm* comm);
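Both overloads assume column-major packing: matrix height is the flattened sample size and each column is one sample. A usage sketch of the local-matrix overload — assuming the trainer and data store environment from load_inference_model/setup_inference_env is already in place, with the field names used elsewhere in this patch:

    #include "lbann/utils/lbann_library.hpp"

    #include <map>
    #include <string>

    // Stages n random MNIST-shaped samples through the local-matrix overload.
    void stage_random_batch(int sample_size, int n) {
      El::Matrix<float> samples(sample_size, n); // one column per sample
      El::MakeUniform(samples);
      El::Matrix<float> labels(1, n);
      El::MakeUniform(labels);
      std::map<std::string, El::Matrix<float>> batch =
        {{"data/samples", samples}, {"data/labels", labels}};
      lbann::set_inference_samples(batch);
    }
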
diff --git a/src/data_ingestion/readers/CMakeLists.txt b/src/data_ingestion/readers/CMakeLists.txt
index 6350f152191..b2610fea21f 100644
--- a/src/data_ingestion/readers/CMakeLists.txt
+++ b/src/data_ingestion/readers/CMakeLists.txt
@@ -27,6 +27,7 @@ set_full_path(THIS_DIR_SOURCES
   metadata.cpp
   data_reader_cifar10.cpp
+  data_reader_conduit.cpp
   data_reader_csv.cpp
   data_reader_image.cpp
   data_reader_jag_conduit.cpp
diff --git a/src/data_ingestion/readers/data_reader_conduit.cpp b/src/data_ingestion/readers/data_reader_conduit.cpp
new file mode 100644
index 00000000000..f691bbffe2e
--- /dev/null
+++ b/src/data_ingestion/readers/data_reader_conduit.cpp
@@ -0,0 +1,53 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2021, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "Licensee"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the license.
+//
+// data_reader_conduit .hpp .cpp - Data reader for passed-in Conduit nodes
+////////////////////////////////////////////////////////////////////////////////
+
+#include "lbann/data_ingestion/data_reader.hpp"
+#include "lbann/data_ingestion/readers/data_reader_conduit.hpp"
+
+namespace lbann {
+
+void conduit_data_reader::load() {
+  // Nothing to load from disk: samples are injected into the data store
+  // by the caller (see set_inference_samples in lbann_library.cpp).
+}
+
+bool conduit_data_reader::fetch_conduit_node(conduit::Node& sample, int data_id)
+{
+  // Copy the requested sample out of the data store
+  const conduit::Node& node = get_data_store().get_conduit_node(data_id);
+  sample = node;
+  return true;
+}
+
+void conduit_data_reader::set_data_dims(std::vector<int> dims) {
+  m_data_dims = dims;
+}
+
+void conduit_data_reader::set_label_dims(std::vector<int> dims) {
+  m_label_dims = dims;
+}
+
+} // namespace lbann
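For reference, a sketch of how this reader is meant to be configured — the same calls setup_inference_env makes in src/utils/lbann_library.cpp later in this patch, shown here with MNIST-shaped dims:

    #include "lbann/data_ingestion/readers/data_reader_conduit.hpp"

    // Configures a conduit_data_reader for MNIST-shaped inference samples.
    void configure_reader(lbann::conduit_data_reader& reader, lbann::lbann_comm* lc) {
      reader.set_comm(lc);
      reader.set_data_dims({1, 28, 28}); // get_linearized_data_size() -> 784
      reader.set_label_dims({10});       // get_linearized_label_size() -> 10
      reader.set_shuffle(false);         // keep sample order for inference
    }
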
diff --git a/src/execution_algorithms/CMakeLists.txt b/src/execution_algorithms/CMakeLists.txt
index dcb4f506b9d..725f8e5adeb 100644
--- a/src/execution_algorithms/CMakeLists.txt
+++ b/src/execution_algorithms/CMakeLists.txt
@@ -25,6 +25,7 @@
 ################################################################################
 # Add the source files for this directory
 set_full_path(THIS_DIR_SOURCES
+  batch_functional_inference_algorithm.cpp
   execution_context.cpp
   factory.cpp
   kfac.cpp
diff --git a/src/execution_algorithms/batch_functional_inference_algorithm.cpp b/src/execution_algorithms/batch_functional_inference_algorithm.cpp
new file mode 100644
index 00000000000..22112664d24
--- /dev/null
+++ b/src/execution_algorithms/batch_functional_inference_algorithm.cpp
@@ -0,0 +1,90 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2021, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file.
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "Licensee"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the license.
+////////////////////////////////////////////////////////////////////////////////
+
+#include "lbann/execution_algorithms/batch_functional_inference_algorithm.hpp"
+
+#include "lbann/execution_algorithms/sgd_execution_context.hpp"
+#include "lbann/trainers/trainer.hpp"
+#include "lbann/utils/lbann_library.hpp"
+
+namespace lbann {
+
+El::Matrix<int>
+batch_functional_inference_algorithm::infer(observer_ptr<model> model)
+{
+  size_t const mbs = get_trainer().get_max_mini_batch_size();
+  El::Matrix<int> labels(mbs, 1);
+
+  auto c = SGDExecutionContext(execution_mode::inference, mbs);
+  model->reset_mode(c, execution_mode::inference);
+  get_trainer().get_data_coordinator().reset_mode(c);
+
+  get_trainer().get_data_coordinator().fetch_data(execution_mode::inference);
+  model->forward_prop(execution_mode::inference);
+  get_labels(*model, labels);
+
+  return labels;
+}
+
+void batch_functional_inference_algorithm::get_labels(
+  model& model,
+  El::Matrix<int>& labels)
+{
+  Layer const* softmax = nullptr;
+  auto const layer_list = model.get_layers();
+  for (auto const* const l_tmp : layer_list) {
+    if (l_tmp->get_type() == "softmax") {
+      softmax = l_tmp;
+      break;
+    }
+  }
+  if (!softmax)
+    LBANN_ERROR("get_labels only supported when model contains a softmax. This "
+                "is a known limitation and we're working on it.");
+  try {
+    auto const& dtl =
+      dynamic_cast<data_type_layer<float> const&>(*softmax);
+    const auto& outputs = dtl.get_activations();
+
+    // Find the prediction for each sample
+    El::Int const col_count = outputs.Width();
+    El::Int const row_count = outputs.Height();
+    LBANN_ASSERT(col_count == labels.Height());
+    for (El::Int col = 0; col < col_count; ++col) {
+      float max = 0.f;
+      El::Int pred_label = 0;
+      for (El::Int row = 0; row < row_count; ++row) {
+        float const col_value = outputs.Get(row, col);
+        if (col_value > max) {
+          max = col_value;
+          pred_label = row;
+        }
+      }
+      labels(col, 0) = pred_label;
+    }
+  }
+  catch (std::bad_cast const&) {
+    LBANN_ERROR("Softmax layer does not have data type \"float\"");
+  }
+}
+
+} // namespace lbann
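The substantive fix over the deleted header implementation is the index order: activations are indexed (class row, sample column), and the old code transposed them. The same argmax as a standalone sketch, for clarity:

    #include <vector>

    // Argmax over each column of a column-major (rows x cols) activation
    // buffer; returns one class index per sample, as get_labels does.
    std::vector<int> argmax_per_column(std::vector<float> const& acts,
                                       int rows, int cols) {
      std::vector<int> labels(cols, 0);
      for (int col = 0; col < cols; ++col) {
        float max = 0.f; // softmax outputs are non-negative
        for (int row = 0; row < rows; ++row) {
          float const v = acts[col * rows + row];
          if (v > max) {
            max = v;
            labels[col] = row;
          }
        }
      }
      return labels;
    }
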
diff --git a/src/execution_algorithms/unit_test/inference_algorithm_test.cpp b/src/execution_algorithms/unit_test/inference_algorithm_test.cpp
index cfe64fdff77..5d3ef334a47 100644
--- a/src/execution_algorithms/unit_test/inference_algorithm_test.cpp
+++ b/src/execution_algorithms/unit_test/inference_algorithm_test.cpp
@@ -37,12 +37,14 @@
 #include "lbann/proto/lbann.pb.h"
 
 #include
+#include
 
 namespace pb = ::google::protobuf;
 
 namespace {
-// This model is just an input layer into a softmax layer, so we can verify the
-// output is correct for a simple input (e.g., a matrix filled with 1.0)
+// This model is just an input layer into a softmax layer, so we can
+// verify the output is correct for a simple input (e.g., a matrix
+// filled with 1.0)
 std::string const model_prototext = R"ptext(
 model {
   layer {
     input {
     }
   }
   layer {
     softmax {
     }
   }
 }
 )ptext";
 
-template <typename DataType>
-auto make_model(lbann::lbann_comm& comm, int class_n)
+lbann_data::LbannPB get_model_protobuf()
 {
   lbann_data::LbannPB my_proto;
   if (!pb::TextFormat::ParseFromString(model_prototext, &my_proto))
     throw "Parsing protobuf failed.";
+  return my_proto;
+}
 
+auto make_model(lbann::lbann_comm& comm,
+                lbann_data::LbannPB& my_proto,
+                int class_n)
+{
   // Construct a trainer so that the model can register the input layer
   auto& trainer =
     lbann::construct_trainer(&comm, my_proto.mutable_trainer(), my_proto);
@@ -81,31 +86,49 @@
 } // namespace
 
+namespace lbann {
+void setup_inference_env(lbann_comm* lc,
+                         int mbs,
+                         std::vector<int> input_dims,
+                         std::vector<int> output_dims);
+} // namespace lbann
+
 TEST_CASE("Test batch_function_inference_algorithm", "[inference]")
 {
   using DataType = float;
-  DataType one = 1.;
-  DataType zero = 0.;
-  int mbs_class_n = 4;
+  using DMat =
+    El::DistMatrix<DataType, El::STAR, El::STAR>;
+  DataType constexpr one = 1.;
+  DataType constexpr zero = 0.;
+  int const mbs_class_n = 4;
 
   auto& comm = unit_test::utilities::current_world_comm();
-  std::unique_ptr<lbann::model> model = make_model<DataType>(comm, mbs_class_n);
   auto const& g = comm.get_trainer_grid();
-  El::DistMatrix<DataType, El::STAR, El::STAR>
-    data(mbs_class_n, mbs_class_n, g);
-  auto inf_alg = lbann::batch_functional_inference_algorithm();
+  auto my_proto = get_model_protobuf(); // the pbuf msg is a global string
+
+  // Construct a trainer so that the model can register the input layer
+  lbann::construct_trainer(&comm, my_proto.mutable_trainer(), my_proto);
+
+  lbann::setup_inference_env(&comm, mbs_class_n, {mbs_class_n}, {mbs_class_n});
+  auto model = make_model(comm, my_proto, mbs_class_n);
 
+  auto inf_alg = lbann::batch_functional_inference_algorithm();
   SECTION("Model data insert and forward prop")
   {
+    DMat data(mbs_class_n, mbs_class_n, g);
     El::Fill(data, one);
+    std::map<std::string, DMat> samples;
+    samples["data/samples"] = std::move(data);
 
-    inf_alg.infer(model.get(), data, mbs_class_n);
-    const auto* l = model->get_layers()[1];
-    auto const& dtl = dynamic_cast<lbann::data_type_layer<DataType> const&>(*l);
-    const auto& output = dtl.get_activations();
+    lbann::set_inference_samples(samples);
+    inf_alg.infer(model.get());
 
-    for (int i = 0; i < output.Height(); i++) {
-      for (int j = 0; j < output.Width(); j++) {
+    auto const& l = model->get_layer(1);
+    auto const& dtl = dynamic_cast<lbann::data_type_layer<DataType> const&>(l);
+    auto const& output = dtl.get_activations();
+
+    for (El::Int i = 0; i < output.Height(); i++) {
+      for (El::Int j = 0; j < output.Width(); j++) {
         REQUIRE(output.Get(i, j) == Approx(1.0 / mbs_class_n));
       }
     }
@@ -113,10 +136,15 @@
 
   SECTION("Verify inference label correctness")
   {
+    DMat data(mbs_class_n, mbs_class_n, g);
     El::Fill(data, zero);
     El::FillDiagonal(data, one);
+    std::map<std::string, DMat> samples;
+    samples["data/samples"] = std::move(data);
 
-    auto labels = inf_alg.infer(model.get(), data, mbs_class_n);
+    lbann::set_inference_samples(samples);
+
+    auto labels = inf_alg.infer(model.get());
 
     for (int i = 0; i < labels.Height(); i++) {
       REQUIRE(labels(i) == i);
diff --git a/src/layers/io/input_layer.cpp b/src/layers/io/input_layer.cpp
index c501b8e352a..34289876316 100644
--- a/src/layers/io/input_layer.cpp
+++ b/src/layers/io/input_layer.cpp
@@ -134,23 +134,21 @@ void input_layer<TensorDataType, T_layout, Dev>::fp_setup_outputs()
 
 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
 void input_layer<TensorDataType, T_layout, Dev>::fp_compute()
 {
-  if (!this->m_samples_loaded) {
-    execution_mode const mode =
-      this->m_model->get_execution_context().get_execution_mode();
-    buffered_data_coordinator<TensorDataType>& dc =
-      static_cast<buffered_data_coordinator<TensorDataType>&>(
-        get_trainer().get_data_coordinator());
+  execution_mode const mode =
+    this->m_model->get_execution_context().get_execution_mode();
+  buffered_data_coordinator<TensorDataType>& dc =
+    static_cast<buffered_data_coordinator<TensorDataType>&>(
+      get_trainer().get_data_coordinator());
 
-    dc.distribute_from_local_matrix(mode,
-                                    m_data_field,
-                                    this->get_activations(0));
+  dc.distribute_from_local_matrix(mode,
+                                  m_data_field,
+                                  this->get_activations(0));
 
 #ifdef LBANN_HAS_DISTCONV
-  if (this->distconv_enabled()) {
-    get_distconv_adapter().fp_compute();
-  }
-#endif // LBANN_HAS_DISTCONV
+  if (this->distconv_enabled()) {
+    get_distconv_adapter().fp_compute();
   }
+#endif // LBANN_HAS_DISTCONV
 }
diff --git a/src/models/unit_test/model_test.cpp b/src/models/unit_test/model_test.cpp
index eb3649a4925..e0f4b682e68 100644
--- a/src/models/unit_test/model_test.cpp
+++ b/src/models/unit_test/model_test.cpp
@@ -48,6 +48,16 @@ namespace {
 // model_prototext string is defined here as a "const std::string".
 #include "lenet.prototext.inc"
 
+auto mock_datareader_metadata()
+{
+  lbann::DataReaderMetaData md;
+  auto& md_dims = md.data_dims;
+  // This is all that should be needed for this test.
+  md_dims[lbann::data_reader_target_mode::CLASSIFICATION] = {10};
+  md_dims[lbann::data_reader_target_mode::INPUT] = {1, 28, 28};
+  return md;
+}
+
 template <typename DataType>
 auto make_model(lbann::lbann_comm& comm,
                 const std::string& model_contents = model_prototext)
@@ -72,8 +82,6 @@ auto make_model(lbann::lbann_comm& comm,
 using unit_test::utilities::IsValidPtr;
 TEST_CASE("Serializing models", "[mpi][model][serialize]")
 {
-  using DataType = float;
-
   auto& comm = unit_test::utilities::current_world_comm();
   auto& g = comm.get_trainer_grid();
diff --git a/src/trainers/trainer.cpp b/src/trainers/trainer.cpp
index 336bcfdc58e..c0e8dc61d59 100644
--- a/src/trainers/trainer.cpp
+++ b/src/trainers/trainer.cpp
@@ -120,9 +120,9 @@ void trainer::setup(std::unique_ptr<thread_pool> io_thread_pool,
   // layer depends on having a properly initialized thread pool)
   m_io_thread_pool = std::move(io_thread_pool);
 
-  m_data_coordinator.get()->setup(*m_io_thread_pool.get(),
-                                  get_max_mini_batch_size(),
-                                  data_readers);
+  m_data_coordinator->setup(*m_io_thread_pool.get(),
+                            get_max_mini_batch_size(),
+                            data_readers);
 
   for (auto& [mode, reader] : data_readers) {
     if (!reader->supports_background_io()) {
diff --git a/src/utils/lbann_library.cpp b/src/utils/lbann_library.cpp
index b5ad85e8a08..8bf1ba0172a 100644
--- a/src/utils/lbann_library.cpp
+++ b/src/utils/lbann_library.cpp
@@ -46,21 +46,138 @@
 #include "lbann/proto/lbann.pb.h"
 #include "lbann/proto/model.pb.h"
 
+#include
+
 #include
 #include
 
 namespace lbann {
 
-// Loads a model from checkpoint and sets up model for inference
-std::unique_ptr<model> load_inference_model(lbann_comm* lc,
-                                            std::string cp_dir,
-                                            int mbs,
-                                            std::vector<int> input_dims,
-                                            std::vector<int> output_dims)
+namespace {
+
+std::unique_ptr<trainer> global_trainer_;
+
+void cleanup_trainer_atexit() { global_trainer_ = nullptr; }
+
+} // namespace
+
+trainer& get_trainer()
+{
+  LBANN_ASSERT(global_trainer_);
+  return *global_trainer_;
+}
+trainer const& get_const_trainer()
+{
+  LBANN_ASSERT(global_trainer_);
+  return *global_trainer_;
+}
+
+bool trainer_exists() { return global_trainer_ != nullptr; }
+
+void finalize_trainer() { global_trainer_.reset(); }
+
+// Creates data reader metadata to get around the need for an actual
+// data reader in inference-only mode
+auto mock_dr_metadata(std::vector<int> input_dims,
+                      std::vector<int> output_dims)
+{
+  DataReaderMetaData drmd;
+  auto& md_dims = drmd.data_dims;
+  md_dims[data_reader_target_mode::INPUT] = input_dims;
+  md_dims[data_reader_target_mode::CLASSIFICATION] = output_dims;
+  return drmd;
+}
+
+// Sets data samples for lbann-core inference
+void set_inference_samples(std::vector<conduit::Node>& samples)
+{
+  data_coordinator& dc = global_trainer_->get_data_coordinator();
+  generic_data_reader* dr = dc.get_data_reader(execution_mode::inference);
+  data_store_conduit& ds = dr->get_data_store();
+
+  // Push samples into the data store
+  std::vector<int> indices(samples.size());
+  std::iota(indices.begin(), indices.end(), 0);
+  dr->set_shuffled_indices(indices);
+  int data_id = 0;
+  for (auto& node : samples) {
+    ds.set_conduit_node(data_id, node);
+    ++data_id;
+  }
+
+  // Call setup on the data coordinator for proper init of input buffers
+  std::map<execution_mode, generic_data_reader*> data_readers = {
+    {execution_mode::inference, dr}};
+  dc.setup(global_trainer_->get_io_thread_pool(),
+           global_trainer_->get_max_mini_batch_size(),
+           data_readers);
+}
+
+void setup_inference_env(lbann_comm* lc,
+                         int mbs,
+                         std::vector<int> input_dims,
+                         std::vector<int> output_dims)
+{
+  // Set up the data reader
+  auto reader = std::make_unique<conduit_data_reader>();
+  reader->set_comm(lc);
+  reader->set_num_parallel_readers(lc->get_procs_per_trainer());
+  reader->set_mini_batch_size(mbs);
+  reader->set_data_dims(input_dims);
+  reader->set_label_dims(output_dims);
+  reader->set_shuffle(false);
+
+  // Add a data store
+  reader->set_data_store(new lbann::data_store_conduit(reader.get()));
+
+  // Set up the data coordinator and trainer
+  std::unique_ptr<thread_pool> io_thread_pool =
+    construct_io_thread_pool(lc, false);
+  auto io_threads_per_process = io_thread_pool->get_num_threads();
+  std::unique_ptr<data_coordinator> dc =
+    lbann::make_unique<buffered_data_coordinator<DataType>>(lc);
+  global_trainer_ =
+    lbann::make_unique<trainer>(lc, std::move(dc), mbs, nullptr);
+  int root_random_seed = lbann_default_random_seed;
+  int random_seed = root_random_seed;
+  int data_seq_random_seed = root_random_seed;
+  init_random(random_seed, io_threads_per_process);
+  init_data_seq_random(data_seq_random_seed);
+  init_ltfb_random(root_random_seed);
+  global_trainer_->set_random_seeds(root_random_seed,
+                                    random_seed,
+                                    data_seq_random_seed);
+
+  // FIXME: The global trainer holds a data coordinator that holds
+  // this map. When the aforementioned data coordinator is destroyed,
+  // it deletes the data readers that it's holding. Therefore, we
+  // release the pointer that we hold and hope no exceptions occur
+  // before the data coordinator takes control of the pointer.
+  std::map<execution_mode, generic_data_reader*> data_readers = {
+    {execution_mode::inference, reader.release()}};
+  global_trainer_->setup(std::move(io_thread_pool), data_readers);
+}
+
+// Loads a model from checkpoint and sets up model for inference
+std::unique_ptr<model>
+load_inference_model(lbann_comm* lc,
+                     std::string cp_dir,
+                     int mbs,
+                     std::vector<int> input_dims,
+                     std::vector<int> output_dims)
 {
+  // Set up the trainer, data reader, data coordinator, and data store
+  setup_inference_env(lc, mbs, input_dims, output_dims);
+
+  // Load the checkpoint
   persist p;
   p.open_restart(cp_dir.c_str());
-  auto m = std::make_unique<model>(lc, nullptr, nullptr);
+  auto m = make_unique<model>(lc, nullptr, nullptr);
   m->load_from_checkpoint_shared(p);
   p.close_restart();
@@ -69,6 +186,12 @@ std::unique_ptr<model> load_inference_model(lbann_comm* lc,
   return m;
 }
 
+El::Matrix<int>
+inference(observer_ptr<model> model)
+{
+  auto inf_alg = batch_functional_inference_algorithm();
+  return inf_alg.infer(model);
+}
+
 /// Split the MPI communicator into trainers
 /// Return the number of processes per trainer
 int allocate_trainer_resources(lbann_comm* comm)
@@ -109,37 +232,6 @@ int allocate_trainer_resources(lbann_comm* comm)
   return procs_per_trainer;
 }
 
-namespace {
-
-std::unique_ptr<trainer> global_trainer_;
-
-void cleanup_trainer_atexit() { global_trainer_ = nullptr; }
-
-} // namespace
-
-trainer& get_trainer()
-{
-  LBANN_ASSERT(global_trainer_);
-  return *global_trainer_;
-}
-trainer const& get_const_trainer()
-{
-  LBANN_ASSERT(global_trainer_);
-  return *global_trainer_;
-}
-
-bool trainer_exists()
-{
-  if (global_trainer_ == nullptr) {
-    return false;
-  }
-  else {
-    return true;
-  }
-}
-
-void finalize_trainer() { global_trainer_.reset(); }
-
 /// Construct a trainer that contains a lbann comm object and threadpool
 trainer& construct_trainer(lbann_comm* comm,
                            lbann_data::Trainer* pb_trainer,
@@ -601,7 +693,7 @@ void print_lbann_configuration(lbann_comm* comm,
 #else
   std::cout << "NOT detected" << std::endl;
 #endif // LBANN_HAS_ALUMINUM
-  std::cout << " GPU : ";
+  std::cout << "  GPU : ";
 #ifdef LBANN_HAS_GPU
   std::cout << "detected" << std::endl;
 #else

From 5acdc2549031c616beb720962b255c2b087f889d Mon Sep 17 00:00:00 2001
From: Brian Van Essen
Date: Tue, 24 Sep 2024 11:17:00 -0700
Subject: [PATCH 2/3] Apply suggestions from code review

Co-authored-by: Tom Benson
---
 CMakeLists.txt                                                | 1 -
 core-driver/CMakeLists.txt                                    | 3 +--
 include/lbann/data_ingestion/readers/data_reader_conduit.hpp  | 2 +-
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1252a8dc7a0..3030cde812d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -920,7 +920,6 @@ add_subdirectory(applications/CANDLE/pilot2/tools)
 add_subdirectory(applications/ATOM/utils)
 add_subdirectory(tests)
 add_subdirectory(scripts)
-add_subdirectory(core-driver)
 
 ################################################################
 # Install LBANN
diff --git a/core-driver/CMakeLists.txt b/core-driver/CMakeLists.txt
index 82a8f6a56dd..d60d5106356 100644
--- a/core-driver/CMakeLists.txt
+++ b/core-driver/CMakeLists.txt
@@ -1,7 +1,6 @@
 cmake_minimum_required(VERSION 3.21.0)
-project(my_lbann_test CXX)
+project(lbann-test-driver CXX)
 find_package(LBANN 0.102.0 REQUIRED)
-find_package(Conduit CONFIG REQUIRED)
 add_executable(lbann-core main.cpp)
 target_link_libraries(lbann-core PRIVATE LBANN::lbann)
 
diff --git a/include/lbann/data_ingestion/readers/data_reader_conduit.hpp b/include/lbann/data_ingestion/readers/data_reader_conduit.hpp
index 5103ec40cbb..269ad483dea 100644
--- a/include/lbann/data_ingestion/readers/data_reader_conduit.hpp
+++ b/include/lbann/data_ingestion/readers/data_reader_conduit.hpp
@@ -61,7 +61,7 @@ class conduit_data_reader : public generic_data_reader
     return label_size;
   }
 
-protected:
+private:
   std::vector<int> m_data_dims;
   std::vector<int> m_label_dims;
 

From 6f95b722b8185308dfa93f1075853b0b45086ab8 Mon Sep 17 00:00:00 2001
From: Brian Van Essen
Date: Tue, 24 Sep 2024 11:17:24 -0700
Subject: [PATCH 3/3] Update core-driver/CMakeLists.txt

Co-authored-by: Tom Benson
---
 core-driver/CMakeLists.txt | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/core-driver/CMakeLists.txt b/core-driver/CMakeLists.txt
index d60d5106356..0eb7bd5a4f4 100644
--- a/core-driver/CMakeLists.txt
+++ b/core-driver/CMakeLists.txt
@@ -1,17 +1,9 @@
 cmake_minimum_required(VERSION 3.21.0)
 project(lbann-test-driver CXX)
 find_package(LBANN 0.102.0 REQUIRED)
-add_executable(lbann-core main.cpp)
-target_link_libraries(lbann-core PRIVATE LBANN::lbann)
+add_executable(lbann-test-driver main.cpp)
+target_link_libraries(lbann-test-driver PRIVATE LBANN::lbann)
 
-#target_link_libraries(lbann-bin lbann)
-set_target_properties(lbann-core
+set_target_properties(lbann-test-driver
   PROPERTIES
-  OUTPUT_NAME lbann-core-driver
   RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-
-#list(APPEND LBANN_EXE_TGTS lbann-core)
-
-install(TARGETS lbann-core
-  EXPORT LBANNTargets
-  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})