Skip to content

Commit

Permalink
WaitPermute only allow move
Browse files Browse the repository at this point in the history
  • Loading branch information
AllanZyne committed Nov 19, 2024
1 parent 4520d97 commit d9eb123
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions src/idtr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
if (isStrided) {
unpack(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
oDataPtr);
delete[](char *) rBuff;
delete[] (char *)rBuff;
}
};
assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&
Expand Down Expand Up @@ -735,18 +735,28 @@ template <typename T> class WaitPermute {
SHARPY::rank_type cRank, SHARPY::rank_type nRanks,
std::vector<Parts> &&parts, std::vector<int64_t> &&axes,
std::vector<int64_t> oGShape, ndarray<T> &&input,
ndarray<T> &&output, std::vector<T> &&receiveBuffer,
std::vector<int> &&receiveOffsets,
ndarray<T> &&output, std::vector<T> &&sendBuffer,
std::vector<int> &&sendOffsets, std::vector<int> &&sendSizes,
std::vector<T> &&receiveBuffer, std::vector<int> &&receiveOffsets,
std::vector<int> &&receiveSizes)
: tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)),
axes(std::move(axes)), oGShape(std::move(oGShape)),
input(std::move(input)), output(std::move(output)),
sendBuffer(std::move(sendBuffer)), sendOffsets(std::move(sendOffsets)),
sendSizes(std::move(sendSizes)),
receiveBuffer(std::move(receiveBuffer)),
receiveOffsets(std::move(receiveOffsets)),
receiveSizes(std::move(receiveSizes)) {}

// Only allow move
WaitPermute(const WaitPermute &) = delete;
WaitPermute &operator=(const WaitPermute &) = delete;
WaitPermute(WaitPermute &&) = default;
WaitPermute &operator=(WaitPermute &&) = default;

void operator()() {
tc->wait(hdl);

std::vector<std::vector<T>> receiveRankBuffer(nRanks);
for (size_t rank = 0; rank < nRanks; ++rank) {
auto &rankBuffer = receiveRankBuffer[rank];
Expand All @@ -755,6 +765,7 @@ template <typename T> class WaitPermute {
receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]);
}

// FIXME: very low efficiency, need to improve
std::vector<size_t> receiveRankBufferCount(nRanks, 0);
input.globalIndices([&](const id &inputIndex) {
id outputIndex = inputIndex.permute(axes);
Expand All @@ -777,6 +788,9 @@ template <typename T> class WaitPermute {
std::vector<int64_t> oGShape;
ndarray<T> input;
ndarray<T> output;
std::vector<T> sendBuffer;
std::vector<int> sendOffsets;
std::vector<int> sendSizes;
std::vector<T> receiveBuffer;
std::vector<int> receiveOffsets;
std::vector<int> receiveSizes;
Expand Down Expand Up @@ -870,6 +884,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
for (auto i = 0ul; i < nRanks; ++i) {
dspl[i] = 4 * i;
}

tc->gather(parts.data(), counts.data(), dspl.data(), SHARPY::INT64,
SHARPY::REPLICATED);

Expand Down Expand Up @@ -919,10 +934,12 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
sendOffsets.data(), sharpytype, receiveBuffer.data(),
receiveSizes.data(), receiveOffsets.data());

auto wait = WaitPermute(tc, hdl, cRank, nRanks, std::move(parts),
std::move(axes), std::move(oGShape), std::move(input),
std::move(output), std::move(receiveBuffer),
std::move(receiveOffsets), std::move(receiveSizes));
auto wait =
WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), std::move(axes),
std::move(oGShape), std::move(input), std::move(output),
std::move(sendBuffer), std::move(sendOffsets),
std::move(sendSizes), std::move(receiveBuffer),
std::move(receiveOffsets), std::move(receiveSizes));

assert(parts.empty() && axes.empty() && receiveBuffer.empty() &&
receiveOffsets.empty() && receiveSizes.empty());
Expand Down

0 comments on commit d9eb123

Please sign in to comment.