diff --git a/velox/exec/tests/ExchangeClientTest.cpp b/velox/exec/tests/ExchangeClientTest.cpp index a00bcb0fa3b52..03b4d390f5b76 100644 --- a/velox/exec/tests/ExchangeClientTest.cpp +++ b/velox/exec/tests/ExchangeClientTest.cpp @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/Exchange.h" #include "velox/exec/OutputBufferManager.h" @@ -59,12 +61,19 @@ class ExchangeClientTest : public testing::Test, std::shared_ptr makeTask( const std::string& taskId, - const core::PlanNodePtr& planNode) { - auto queryCtx = std::make_shared(executor_.get()); + const std::optional maxOutputBufferSizeInBytes = {}) { + std::unordered_map config; + if (maxOutputBufferSizeInBytes.has_value()) { + config[core::QueryConfig::kMaxOutputBufferSize] = + std::to_string(maxOutputBufferSizeInBytes.value()); + } + auto queryCtx = std::make_shared( + executor_.get(), core::QueryConfig{std::move(config)}); queryCtx->testingOverrideMemoryPool( memory::memoryManager()->addRootPool(queryCtx->queryId())); + auto plan = test::PlanBuilder().values({}).planNode(); return Task::create( - taskId, core::PlanFragment{planNode}, 0, std::move(queryCtx)); + taskId, core::PlanFragment{plan}, 0, std::move(queryCtx)); } int32_t enqueue( @@ -125,7 +134,7 @@ class ExchangeClientTest : public testing::Test, static std::unique_ptr makePage(uint64_t size) { auto ioBuf = folly::IOBuf::create(size); ioBuf->append(size); - return std::make_unique(std::move(ioBuf)); + return std::make_unique(std::move(ioBuf), nullptr, 1); } folly::Executor* executor() const { @@ -166,12 +175,8 @@ TEST_F(ExchangeClientTest, stats) { makeRowVector({makeFlatVector({1, 2})}), }; - auto plan = test::PlanBuilder() - .values(data) - .partitionedOutput({"c0"}, 100) - .planNode(); auto taskId = "local://t1"; - auto task = makeTask(taskId, plan); + auto task = makeTask(taskId); bufferManager_->initializeTask( task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); @@ -219,15 +224,11 @@ TEST_F(ExchangeClientTest, flowControl) { auto client = std::make_shared( "flow.control", 17, page->size() * 3.5, pool(), executor()); - auto plan = test::PlanBuilder() - .values({data}) - .partitionedOutput({"c0"}, 100) - .planNode(); // Make 10 tasks. std::vector> tasks; for (auto i = 0; i < 10; ++i) { auto taskId = fmt::format("local://t{}", i); - auto task = makeTask(taskId, plan); + auto task = makeTask(taskId); bufferManager_->initializeTask( task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); @@ -263,11 +264,7 @@ TEST_F(ExchangeClientTest, largeSinglePage) { }; auto client = std::make_shared("test", 1, 1000, pool(), executor()); - auto plan = test::PlanBuilder() - .values({data}) - .partitionedOutputArbitrary() - .planNode(); - auto task = makeTask("local://producer", plan); + auto task = makeTask("local://producer"); bufferManager_->initializeTask( task, core::PartitionedOutputNode::Kind::kArbitrary, 1, 1); for (auto& batch : data) { @@ -419,5 +416,115 @@ TEST_F(ExchangeClientTest, sourceTimeout) { test::testingShutdownLocalExchangeSource(); } +TEST_F(ExchangeClientTest, acknowledge) { + const int64_t pageSize = 1024; + const int64_t clientBufferSize = pageSize; + const int64_t serverBufferSize = 2 * pageSize; + + const auto sourceTaskId = "local://test-acknowledge-source-task"; + const auto task = makeTask(sourceTaskId, serverBufferSize); + auto taskRemoveGuard = + folly::makeGuard([bufferManager = bufferManager_, task]() { + task->requestCancel(); + bufferManager->removeTask(task->taskId()); + }); + + bufferManager_->initializeTask( + task, core::PartitionedOutputNode::Kind::kPartitioned, 2, 1); + + auto client = std::make_shared( + "local://test-acknowledge-client-task", + 1, + clientBufferSize, + pool(), + executor()); + auto clientCloseGuard = folly::makeGuard([client]() { client->close(); }); + + client->addRemoteTaskId(sourceTaskId); + client->noMoreRemoteTasks(); + + { + // adding the first page should not block as there is enough space in + // the output buffer for two pages + ContinueFuture future; + bufferManager_->enqueue(sourceTaskId, 1, makePage(pageSize), &future); + ASSERT_TRUE(future.isReady()); + } + + { + // adding the second page may block but will get unblocked once the + // client fetches a single page + ContinueFuture future; + bufferManager_->enqueue(sourceTaskId, 1, makePage(pageSize), &future); + ASSERT_NO_THROW( + std::move(future).via(executor()).wait(std::chrono::seconds{10}).get()); + } + + { + // adding the third page should block (one page is in the exchange queue, + // another two pages are in the output buffer) + ContinueFuture enqueueDetachedFuture; + bufferManager_->enqueue( + sourceTaskId, 1, makePage(pageSize), &enqueueDetachedFuture); + ASSERT_FALSE(enqueueDetachedFuture.isReady()); + + auto enqueueFuture = std::move(enqueueDetachedFuture) + .via(executor()) + .wait(std::chrono::milliseconds{100}); + ASSERT_FALSE(enqueueFuture.isReady()); + + // removing one page from the exchange queue should trigger a fetch and + // a subsequent acknowledge to release the output buffer memory + bool atEnd; + ContinueFuture dequeueDetachedFuture; + auto pages = client->next(1, &atEnd, &dequeueDetachedFuture); + ASSERT_EQ(1, pages.size()); + ASSERT_FALSE(atEnd); + ASSERT_TRUE(dequeueDetachedFuture.isReady()); + + ASSERT_NO_THROW( + std::move(enqueueFuture).wait(std::chrono::seconds{10}).value()); + } + + // one page is still in the buffer at this point + ASSERT_EQ(bufferManager_->getUtilization(sourceTaskId), 0.5); + + auto pages = fetchPages(*client, 1); + ASSERT_EQ(1, pages.size()); + + { + // at this point the output buffer is expected to be empty + int attempts = 10; + bool outputBuffersEmpty; + while (attempts > 0) { + attempts--; + outputBuffersEmpty = bufferManager_->getUtilization(sourceTaskId) == 0; + if (outputBuffersEmpty) { + break; + } + std::this_thread::sleep_for(std::chrono::seconds{1}); + } + ASSERT_TRUE(outputBuffersEmpty); + } + + pages = fetchPages(*client, 1); + ASSERT_EQ(1, pages.size()); + + bufferManager_->noMoreData(sourceTaskId); + + bool atEnd; + ContinueFuture dequeueEndOfDataFuture; + pages = client->next(1, &atEnd, &dequeueEndOfDataFuture); + ASSERT_EQ(0, pages.size()); + + ASSERT_NO_THROW(std::move(dequeueEndOfDataFuture) + .via(executor()) + .wait(std::chrono::seconds{10}) + .value()); + pages = client->next(1, &atEnd, &dequeueEndOfDataFuture); + ASSERT_EQ(0, pages.size()); + ASSERT_TRUE(atEnd); +} + } // namespace } // namespace facebook::velox::exec