Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidyn-A committed Nov 5, 2024
1 parent 0ad4df8 commit a7a49d1
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
1 change: 1 addition & 0 deletions gloo/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
find_package(OpenSSL 1.1 REQUIRED EXACT)

set(GLOO_TEST_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/abort_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allgatherv_test.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc"
Expand Down
145 changes: 145 additions & 0 deletions gloo/test/abort_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <functional>
#include <thread>
#include <vector>

#include "gloo/barrier_all_to_all.h"
#include "gloo/barrier_all_to_one.h"
#include "gloo/broadcast.h"
#include "gloo/test/base_test.h"

namespace gloo {
namespace test {
namespace {

// Function to instantiate and run algorithm.
using Func = void(std::shared_ptr<::gloo::Context>);

// Test parameterization.
using Param = std::tuple<Transport, int, std::function<Func>>;

// Test fixture.
class BarrierTest : public BaseTest,
public ::testing::WithParamInterface<Param> {};

TEST_P(BarrierTest, SinglePointer) {
const auto transport = std::get<0>(GetParam());
const auto contextSize = std::get<1>(GetParam());
const auto fn = std::get<2>(GetParam());

spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
fn(context);
});
}

static std::function<Func> barrierAllToAll =
[](std::shared_ptr<::gloo::Context> context) {
::gloo::BarrierAllToAll algorithm(context);
algorithm.run();
};

INSTANTIATE_TEST_CASE_P(
BarrierAllToAll,
BarrierTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForClassAlgorithms),
::testing::Range(2, 16),
::testing::Values(barrierAllToAll)));

static std::function<Func> barrierAllToOne =
[](std::shared_ptr<::gloo::Context> context) {
::gloo::BarrierAllToOne algorithm(context);
algorithm.run();
};

INSTANTIATE_TEST_CASE_P(
BarrierAllToOne,
BarrierTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForClassAlgorithms),
::testing::Range(2, 16),
::testing::Values(barrierAllToOne)));

// Synchronized version of std::chrono::clock::now().
// All processes participating in the specified context will
// see the same value.
template <typename clock>
std::chrono::time_point<clock> syncNow(std::shared_ptr<Context> context) {
const typename clock::time_point now = clock::now();
typename clock::duration::rep count = now.time_since_epoch().count();
BroadcastOptions opts(context);
opts.setRoot(0);
opts.setOutput(&count, 1);
broadcast(opts);
return typename clock::time_point(typename clock::duration(count));
}

using NewParam = std::tuple<Transport, int>;

class BarrierNewTest : public BaseTest,
public ::testing::WithParamInterface<NewParam> {};

TEST_P(BarrierNewTest, Default) {
const auto transport = std::get<0>(GetParam());
const auto contextSize = std::get<1>(GetParam());

spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
BarrierOptions opts(context);

// Run barrier to synchronize processes after starting.
barrier(opts);

// Take turns in sleeping for a bit and checking that all processes
// saw that artificial delay through the barrier.
auto singleProcessDelay = std::chrono::milliseconds(1000);
for (size_t i = 0; i < context->size; i++) {
const auto start = syncNow<std::chrono::high_resolution_clock>(context);
if (i == context->rank) {
/* sleep override */
std::this_thread::sleep_for(singleProcessDelay);
}

barrier(opts);
abort();

// Expect all processes to have taken less than the sleep, as abort was called
auto stop = std::chrono::high_resolution_clock::now();
auto delta = std::chrono::duration_cast<decltype(singleProcessDelay)>(
stop - start);
ASSERT_LE(delta.count(), singleProcessDelay.count());
}
});
}

INSTANTIATE_TEST_CASE_P(
BarrierNewDefault,
BarrierNewTest,
::testing::Combine(
::testing::ValuesIn(kTransportsForFunctionAlgorithms),
::testing::Values(1, 2, 4, 7)));

TEST_F(BarrierNewTest, TestTimeout) {
spawn(Transport::TCP, 2, [&](std::shared_ptr<Context> context) {
BarrierOptions opts(context);
opts.setTimeout(std::chrono::milliseconds(10));
if (context->rank == 0) {
try {
barrier(opts);
FAIL() << "Expected exception to be thrown";
} catch (::gloo::IoException& e) {
ASSERT_NE(std::string(e.what()).find("Timed out"), std::string::npos);
}
}
});
}

} // namespace
} // namespace test
} // namespace gloo

0 comments on commit a7a49d1

Please sign in to comment.