From a7a49d18383af95051cf9d128be6ff0f79e4fe92 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Tue, 5 Nov 2024 17:25:44 +0400 Subject: [PATCH] add test --- gloo/test/CMakeLists.txt | 1 + gloo/test/abort_test.cc | 145 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 gloo/test/abort_test.cc diff --git a/gloo/test/CMakeLists.txt b/gloo/test/CMakeLists.txt index 743e089ee..a0e060457 100644 --- a/gloo/test/CMakeLists.txt +++ b/gloo/test/CMakeLists.txt @@ -1,6 +1,7 @@ find_package(OpenSSL 1.1 REQUIRED EXACT) set(GLOO_TEST_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/abort_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allgatherv_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc" diff --git a/gloo/test/abort_test.cc b/gloo/test/abort_test.cc new file mode 100644 index 000000000..aec1ae35f --- /dev/null +++ b/gloo/test/abort_test.cc @@ -0,0 +1,145 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include "gloo/barrier_all_to_all.h" +#include "gloo/barrier_all_to_one.h" +#include "gloo/broadcast.h" +#include "gloo/test/base_test.h" + +namespace gloo { +namespace test { +namespace { + +// Function to instantiate and run algorithm. +using Func = void(std::shared_ptr<::gloo::Context>); + +// Test parameterization. +using Param = std::tuple>; + +// Test fixture. +class BarrierTest : public BaseTest, + public ::testing::WithParamInterface {}; + +TEST_P(BarrierTest, SinglePointer) { + const auto transport = std::get<0>(GetParam()); + const auto contextSize = std::get<1>(GetParam()); + const auto fn = std::get<2>(GetParam()); + + spawn(transport, contextSize, [&](std::shared_ptr context) { + fn(context); + }); +} + +static std::function barrierAllToAll = + [](std::shared_ptr<::gloo::Context> context) { + ::gloo::BarrierAllToAll algorithm(context); + algorithm.run(); + }; + +INSTANTIATE_TEST_CASE_P( + BarrierAllToAll, + BarrierTest, + ::testing::Combine( + ::testing::ValuesIn(kTransportsForClassAlgorithms), + ::testing::Range(2, 16), + ::testing::Values(barrierAllToAll))); + +static std::function barrierAllToOne = + [](std::shared_ptr<::gloo::Context> context) { + ::gloo::BarrierAllToOne algorithm(context); + algorithm.run(); + }; + +INSTANTIATE_TEST_CASE_P( + BarrierAllToOne, + BarrierTest, + ::testing::Combine( + ::testing::ValuesIn(kTransportsForClassAlgorithms), + ::testing::Range(2, 16), + ::testing::Values(barrierAllToOne))); + +// Synchronized version of std::chrono::clock::now(). +// All processes participating in the specified context will +// see the same value. +template +std::chrono::time_point syncNow(std::shared_ptr context) { + const typename clock::time_point now = clock::now(); + typename clock::duration::rep count = now.time_since_epoch().count(); + BroadcastOptions opts(context); + opts.setRoot(0); + opts.setOutput(&count, 1); + broadcast(opts); + return typename clock::time_point(typename clock::duration(count)); +} + +using NewParam = std::tuple; + +class BarrierNewTest : public BaseTest, + public ::testing::WithParamInterface {}; + +TEST_P(BarrierNewTest, Default) { + const auto transport = std::get<0>(GetParam()); + const auto contextSize = std::get<1>(GetParam()); + + spawn(transport, contextSize, [&](std::shared_ptr context) { + BarrierOptions opts(context); + + // Run barrier to synchronize processes after starting. + barrier(opts); + + // Take turns in sleeping for a bit and checking that all processes + // saw that artificial delay through the barrier. + auto singleProcessDelay = std::chrono::milliseconds(1000); + for (size_t i = 0; i < context->size; i++) { + const auto start = syncNow(context); + if (i == context->rank) { + /* sleep override */ + std::this_thread::sleep_for(singleProcessDelay); + } + + barrier(opts); + abort(); + + // Expect all processes to have taken less than the sleep, as abort was called + auto stop = std::chrono::high_resolution_clock::now(); + auto delta = std::chrono::duration_cast( + stop - start); + ASSERT_LE(delta.count(), singleProcessDelay.count()); + } + }); +} + +INSTANTIATE_TEST_CASE_P( + BarrierNewDefault, + BarrierNewTest, + ::testing::Combine( + ::testing::ValuesIn(kTransportsForFunctionAlgorithms), + ::testing::Values(1, 2, 4, 7))); + +TEST_F(BarrierNewTest, TestTimeout) { + spawn(Transport::TCP, 2, [&](std::shared_ptr context) { + BarrierOptions opts(context); + opts.setTimeout(std::chrono::milliseconds(10)); + if (context->rank == 0) { + try { + barrier(opts); + FAIL() << "Expected exception to be thrown"; + } catch (::gloo::IoException& e) { + ASSERT_NE(std::string(e.what()).find("Timed out"), std::string::npos); + } + } + }); +} + +} // namespace +} // namespace test +} // namespace gloo