Skip to content

Commit

Permalink
add a repeat_n algorithm; fix gcc TSAN errors (NVIDIA#1202)
Browse files Browse the repository at this point in the history
add a `repeat_n` algorithm; fix gcc TSAN errors

---------

Co-authored-by: Paul Taylor <[email protected]>
  • Loading branch information
ericniebler and trxcllnt authored Jan 9, 2024
1 parent 6306beb commit 98f1f7a
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 13 deletions.
37 changes: 24 additions & 13 deletions include/exec/repeat_effect_until.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
#pragma once

#include "../stdexec/execution.hpp"
#include "exec/trampoline_scheduler.hpp"
#include "exec/on.hpp"
#include "__detail/__manual_lifetime.hpp"
#include "stdexec/__detail/__meta.hpp"
#include "stdexec/__detail/__basic_sender.hpp"
#include "stdexec/concepts.hpp"
#include "stdexec/functional.hpp"
#include "../stdexec/concepts.hpp"
#include "../stdexec/functional.hpp"
#include "../stdexec/__detail/__meta.hpp"
#include "../stdexec/__detail/__basic_sender.hpp"

#include "on.hpp"
#include "trampoline_scheduler.hpp"
#include "__detail/__manual_lifetime.hpp"

#include <atomic>
#include <concepts>

namespace exec {
Expand Down Expand Up @@ -55,6 +57,9 @@ namespace exec {
};
};

STDEXEC_PRAGMA_PUSH()
STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan")

template <class _Sender, class _Receiver>
struct __repeat_effect_state : stdexec::__enable_receiver_from_this<_Sender, _Receiver> {
using __child_t = __decay_t<__data_of<_Sender>>;
Expand All @@ -63,7 +68,7 @@ namespace exec {
using __child_op_t = stdexec::connect_result_t<__child_on_sched_sender_t, __receiver_t>;

__child_t __child_;
bool __started_ = false;
std::atomic_flag __started_{};
__manual_lifetime<__child_op_t> __child_op_;
trampoline_scheduler __sched_;

Expand All @@ -73,7 +78,11 @@ namespace exec {
}

~__repeat_effect_state() {
if (!__started_) {
if (!__started_.test(std::memory_order_acquire)) {
std::atomic_thread_fence(std::memory_order_release);
// TSan does not support std::atomic_thread_fence, so we
// need to use the TSan-specific __tsan_release instead:
STDEXEC_TSAN(__tsan_release(&__started_));
__child_op_.__destroy();
}
}
Expand All @@ -85,9 +94,10 @@ namespace exec {
}

void __start() noexcept {
STDEXEC_ASSERT(!__started_);
const bool __already_started [[maybe_unused]] =
__started_.test_and_set(std::memory_order_relaxed);
STDEXEC_ASSERT(!__already_started);
stdexec::start(__child_op_.__get());
__started_ = true;
}

template <class _Tag, class... _Args>
Expand All @@ -112,6 +122,8 @@ namespace exec {
}
};

STDEXEC_PRAGMA_POP()

template <
__mstring _Where = "In repeat_effect_until: "__csz,
__mstring _What = "The input sender must send a single value that is convertible to bool"__csz>
Expand Down Expand Up @@ -175,8 +187,7 @@ namespace exec {
});
}
};

} // namespace __repeat_effect
} // namespace __repeat_effect_until

using __repeat_effect_until::repeat_effect_until_t;
inline constexpr repeat_effect_until_t repeat_effect_until{};
Expand Down
215 changes: 215 additions & 0 deletions include/exec/repeat_n.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
/*
* Copyright (c) 2023 Runner-2019
* Copyright (c) 2023 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "../stdexec/execution.hpp"
#include "../stdexec/concepts.hpp"
#include "../stdexec/functional.hpp"
#include "../stdexec/__detail/__meta.hpp"
#include "../stdexec/__detail/__basic_sender.hpp"

#include "on.hpp"
#include "trampoline_scheduler.hpp"
#include "__detail/__manual_lifetime.hpp"

#include <atomic>
#include <concepts>

namespace exec {
namespace __repeat_n {
using namespace stdexec;

template <class _Sender, class _Receiver>
struct __repeat_n_state;

template <class _SenderId, class _ReceiverId>
struct __receiver {
using _Sender = stdexec::__t<_SenderId>;
using _Receiver = stdexec::__t<_ReceiverId>;

struct __t {
using __id = __receiver;
using receiver_concept = stdexec::receiver_t;
__repeat_n_state<_Sender, _Receiver> *__state_;

template <__completion_tag _Tag, class... _Args>
friend void tag_invoke(_Tag, __t &&__self, _Args &&...__args) noexcept {
__self.__state_->__complete(_Tag(), (_Args &&) __args...);
}

friend env_of_t<_Receiver> tag_invoke(get_env_t, const __t &__self) noexcept {
return get_env(__self.__state_->__receiver());
}
};
};

template <class _Child>
struct __child_count_pair {
_Child __child_;
std::size_t __count_;
};

template <class _Child>
__child_count_pair(_Child, std::size_t) -> __child_count_pair<_Child>;

STDEXEC_PRAGMA_PUSH()
STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan")

template <class _Sender, class _Receiver>
struct __repeat_n_state : stdexec::__enable_receiver_from_this<_Sender, _Receiver> {
using __child_count_pair_t = __decay_t<__data_of<_Sender>>;
using __child_t = decltype(__child_count_pair_t::__child_);
using __receiver_t = stdexec::__t<__receiver<__id<_Sender>, __id<_Receiver>>>;
using __child_on_sched_sender_t = __result_of<stdexec::on, trampoline_scheduler, __child_t &>;
using __child_op_t = stdexec::connect_result_t<__child_on_sched_sender_t, __receiver_t>;

__child_count_pair<__child_t> __pair_;
std::atomic_flag __started_{};
__manual_lifetime<__child_op_t> __child_op_;
trampoline_scheduler __sched_;

__repeat_n_state(_Sender &&__sndr, _Receiver &)
: __pair_(__sexpr_apply((_Sender &&) __sndr, __detail::__get_data())) {
// Q: should we skip __connect() if __count_ == 0?
__connect();
}

~__repeat_n_state() {
if (!__started_.test(std::memory_order_acquire)) {
std::atomic_thread_fence(std::memory_order_release);
// TSan does not support std::atomic_thread_fence, so we
// need to use the TSan-specific __tsan_release instead:
STDEXEC_TSAN(__tsan_release(&__started_));
__child_op_.__destroy();
}
}

void __connect() {
__child_op_.__construct_with([this] {
return stdexec::connect(stdexec::on(__sched_, __pair_.__child_), __receiver_t{this});
});
}

void __start() noexcept {
if (__pair_.__count_ == 0) {
stdexec::set_value((_Receiver &&) this->__receiver());
} else {
const bool __already_started [[maybe_unused]] =
__started_.test_and_set(std::memory_order_relaxed);
STDEXEC_ASSERT(!__already_started);
stdexec::start(__child_op_.__get());
}
}

template <class _Tag, class... _Args>
void __complete(_Tag, _Args &&...__args) noexcept {
STDEXEC_ASSERT(__pair_.__count_ > 0);
__child_op_.__destroy();
if constexpr (same_as<_Tag, set_value_t>) {
try {
if (--__pair_.__count_ == 0) {
stdexec::set_value((_Receiver &&) this->__receiver());
} else {
__connect();
stdexec::start(__child_op_.__get());
}
} catch (...) {
stdexec::set_error((_Receiver &&) this->__receiver(), std::current_exception());
}
} else {
_Tag()((_Receiver &&) this->__receiver(), (_Args &&) __args...);
}
}
};

STDEXEC_PRAGMA_POP()

template <
__mstring _Where = "In repeat_n: "__csz,
__mstring _What = "The input sender must be a sender of void"__csz>
struct _INVALID_ARGUMENT_TO_REPEAT_N_ { };

template <class _Sender, class... _Args>
using __values_t = //
// There's something funny going on with __if_c here. Use std::conditional_t instead. :-(
std::conditional_t<
(sizeof...(_Args) == 0),
completion_signatures<>,
__mexception<_INVALID_ARGUMENT_TO_REPEAT_N_<>, _WITH_SENDER_<_Sender>>>;

template <class _Pair, class _Env>
using __completions_t = //
stdexec::__try_make_completion_signatures<
decltype(__decay_t<_Pair>::__child_) &,
_Env,
stdexec::__try_make_completion_signatures<
stdexec::schedule_result_t<exec::trampoline_scheduler>,
_Env,
__with_exception_ptr>,
__mbind_front_q<__values_t, decltype(__decay_t<_Pair>::__child_)>>;

struct __repeat_n_tag { };

struct __repeat_n_impl : __sexpr_defaults {
static constexpr auto get_completion_signatures = //
[]<class _Sender, class _Env>(_Sender &&, _Env &&) noexcept {
return __completions_t<__data_of<_Sender>, _Env>{};
};

static constexpr auto get_state = //
[]<class _Sender, class _Receiver>(_Sender &&__sndr, _Receiver &__rcvr) {
return __repeat_n_state{std::move(__sndr), __rcvr};
};

static constexpr auto start = //
[](auto &__state, __ignore) noexcept -> void {
__state.__start();
};
};

struct repeat_n_t {
template <sender _Sender>
auto operator()(_Sender &&__sndr, std::size_t __count) const {
auto __domain = __get_early_domain(__sndr);
return stdexec::transform_sender(
__domain, __make_sexpr<repeat_n_t>(__count, (_Sender &&) __sndr));
}

constexpr auto operator()(std::size_t __count) const
-> __binder_back<repeat_n_t, std::size_t> {
return {{}, {}, {__count}};
}

template <class _Sender>
auto transform_sender(_Sender &&__sndr, __ignore) {
return __sexpr_apply(
(_Sender &&) __sndr, []<class _Child>(__ignore, std::size_t __count, _Child __child) {
return __make_sexpr<__repeat_n_tag>(__child_count_pair{std::move(__child), __count});
});
}
};
} // namespace __repeat_n

using __repeat_n::repeat_n_t;
inline constexpr repeat_n_t repeat_n{};
} // namespace exec

namespace stdexec {
template <>
struct __sexpr_impl<exec::__repeat_n::__repeat_n_tag>
: exec::__repeat_n::__repeat_n_impl { };
}
9 changes: 9 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ set(stdexec_test_sources
exec/test_on2.cpp
exec/test_on3.cpp
exec/test_repeat_effect_until.cpp
exec/test_repeat_n.cpp
exec/async_scope/test_dtor.cpp
exec/async_scope/test_spawn.cpp
exec/async_scope/test_spawn_future.cpp
Expand Down Expand Up @@ -155,6 +156,14 @@ icm_add_build_failure_test(
FOLDER test
)

icm_add_build_failure_test(
NAME test_repeat_n_fail
TARGET test_repeat_n_fail
SOURCES PARSE exec/test_repeat_n_fail.cpp
LIBRARIES stdexec
FOLDER test
)

# # Adding multiple tests with a glob
# icm_glob_build_failure_tests(
# PATTERN *_fail*.cpp
Expand Down
Loading

0 comments on commit 98f1f7a

Please sign in to comment.