diff --git a/include/exec/repeat_effect_until.hpp b/include/exec/repeat_effect_until.hpp index 2b256a72e..c88d4d5c9 100644 --- a/include/exec/repeat_effect_until.hpp +++ b/include/exec/repeat_effect_until.hpp @@ -17,14 +17,16 @@ #pragma once #include "../stdexec/execution.hpp" -#include "exec/trampoline_scheduler.hpp" -#include "exec/on.hpp" -#include "__detail/__manual_lifetime.hpp" -#include "stdexec/__detail/__meta.hpp" -#include "stdexec/__detail/__basic_sender.hpp" -#include "stdexec/concepts.hpp" -#include "stdexec/functional.hpp" +#include "../stdexec/concepts.hpp" +#include "../stdexec/functional.hpp" +#include "../stdexec/__detail/__meta.hpp" +#include "../stdexec/__detail/__basic_sender.hpp" + +#include "on.hpp" #include "trampoline_scheduler.hpp" +#include "__detail/__manual_lifetime.hpp" + +#include #include namespace exec { @@ -55,6 +57,9 @@ namespace exec { }; }; + STDEXEC_PRAGMA_PUSH() + STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan") + template struct __repeat_effect_state : stdexec::__enable_receiver_from_this<_Sender, _Receiver> { using __child_t = __decay_t<__data_of<_Sender>>; @@ -63,7 +68,7 @@ namespace exec { using __child_op_t = stdexec::connect_result_t<__child_on_sched_sender_t, __receiver_t>; __child_t __child_; - bool __started_ = false; + std::atomic_flag __started_{}; __manual_lifetime<__child_op_t> __child_op_; trampoline_scheduler __sched_; @@ -73,7 +78,11 @@ namespace exec { } ~__repeat_effect_state() { - if (!__started_) { + if (!__started_.test(std::memory_order_acquire)) { + std::atomic_thread_fence(std::memory_order_release); + // TSan does not support std::atomic_thread_fence, so we + // need to use the TSan-specific __tsan_release instead: + STDEXEC_TSAN(__tsan_release(&__started_)); __child_op_.__destroy(); } } @@ -85,9 +94,10 @@ namespace exec { } void __start() noexcept { - STDEXEC_ASSERT(!__started_); + const bool __already_started [[maybe_unused]] = + __started_.test_and_set(std::memory_order_relaxed); + STDEXEC_ASSERT(!__already_started); stdexec::start(__child_op_.__get()); - __started_ = true; } template @@ -112,6 +122,8 @@ namespace exec { } }; + STDEXEC_PRAGMA_POP() + template < __mstring _Where = "In repeat_effect_until: "__csz, __mstring _What = "The input sender must send a single value that is convertible to bool"__csz> @@ -175,8 +187,7 @@ namespace exec { }); } }; - - } // namespace __repeat_effect + } // namespace __repeat_effect_until using __repeat_effect_until::repeat_effect_until_t; inline constexpr repeat_effect_until_t repeat_effect_until{}; diff --git a/include/exec/repeat_n.hpp b/include/exec/repeat_n.hpp new file mode 100644 index 000000000..a4ea5be4e --- /dev/null +++ b/include/exec/repeat_n.hpp @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2023 Runner-2019 + * Copyright (c) 2023 NVIDIA Corporation + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../stdexec/execution.hpp" +#include "../stdexec/concepts.hpp" +#include "../stdexec/functional.hpp" +#include "../stdexec/__detail/__meta.hpp" +#include "../stdexec/__detail/__basic_sender.hpp" + +#include "on.hpp" +#include "trampoline_scheduler.hpp" +#include "__detail/__manual_lifetime.hpp" + +#include +#include + +namespace exec { + namespace __repeat_n { + using namespace stdexec; + + template + struct __repeat_n_state; + + template + struct __receiver { + using _Sender = stdexec::__t<_SenderId>; + using _Receiver = stdexec::__t<_ReceiverId>; + + struct __t { + using __id = __receiver; + using receiver_concept = stdexec::receiver_t; + __repeat_n_state<_Sender, _Receiver> *__state_; + + template <__completion_tag _Tag, class... _Args> + friend void tag_invoke(_Tag, __t &&__self, _Args &&...__args) noexcept { + __self.__state_->__complete(_Tag(), (_Args &&) __args...); + } + + friend env_of_t<_Receiver> tag_invoke(get_env_t, const __t &__self) noexcept { + return get_env(__self.__state_->__receiver()); + } + }; + }; + + template + struct __child_count_pair { + _Child __child_; + std::size_t __count_; + }; + + template + __child_count_pair(_Child, std::size_t) -> __child_count_pair<_Child>; + + STDEXEC_PRAGMA_PUSH() + STDEXEC_PRAGMA_IGNORE_GNU("-Wtsan") + + template + struct __repeat_n_state : stdexec::__enable_receiver_from_this<_Sender, _Receiver> { + using __child_count_pair_t = __decay_t<__data_of<_Sender>>; + using __child_t = decltype(__child_count_pair_t::__child_); + using __receiver_t = stdexec::__t<__receiver<__id<_Sender>, __id<_Receiver>>>; + using __child_on_sched_sender_t = __result_of; + using __child_op_t = stdexec::connect_result_t<__child_on_sched_sender_t, __receiver_t>; + + __child_count_pair<__child_t> __pair_; + std::atomic_flag __started_{}; + __manual_lifetime<__child_op_t> __child_op_; + trampoline_scheduler __sched_; + + __repeat_n_state(_Sender &&__sndr, _Receiver &) + : __pair_(__sexpr_apply((_Sender &&) __sndr, __detail::__get_data())) { + // Q: should we skip __connect() if __count_ == 0? + __connect(); + } + + ~__repeat_n_state() { + if (!__started_.test(std::memory_order_acquire)) { + std::atomic_thread_fence(std::memory_order_release); + // TSan does not support std::atomic_thread_fence, so we + // need to use the TSan-specific __tsan_release instead: + STDEXEC_TSAN(__tsan_release(&__started_)); + __child_op_.__destroy(); + } + } + + void __connect() { + __child_op_.__construct_with([this] { + return stdexec::connect(stdexec::on(__sched_, __pair_.__child_), __receiver_t{this}); + }); + } + + void __start() noexcept { + if (__pair_.__count_ == 0) { + stdexec::set_value((_Receiver &&) this->__receiver()); + } else { + const bool __already_started [[maybe_unused]] = + __started_.test_and_set(std::memory_order_relaxed); + STDEXEC_ASSERT(!__already_started); + stdexec::start(__child_op_.__get()); + } + } + + template + void __complete(_Tag, _Args &&...__args) noexcept { + STDEXEC_ASSERT(__pair_.__count_ > 0); + __child_op_.__destroy(); + if constexpr (same_as<_Tag, set_value_t>) { + try { + if (--__pair_.__count_ == 0) { + stdexec::set_value((_Receiver &&) this->__receiver()); + } else { + __connect(); + stdexec::start(__child_op_.__get()); + } + } catch (...) { + stdexec::set_error((_Receiver &&) this->__receiver(), std::current_exception()); + } + } else { + _Tag()((_Receiver &&) this->__receiver(), (_Args &&) __args...); + } + } + }; + + STDEXEC_PRAGMA_POP() + + template < + __mstring _Where = "In repeat_n: "__csz, + __mstring _What = "The input sender must be a sender of void"__csz> + struct _INVALID_ARGUMENT_TO_REPEAT_N_ { }; + + template + using __values_t = // + // There's something funny going on with __if_c here. Use std::conditional_t instead. :-( + std::conditional_t< + (sizeof...(_Args) == 0), + completion_signatures<>, + __mexception<_INVALID_ARGUMENT_TO_REPEAT_N_<>, _WITH_SENDER_<_Sender>>>; + + template + using __completions_t = // + stdexec::__try_make_completion_signatures< + decltype(__decay_t<_Pair>::__child_) &, + _Env, + stdexec::__try_make_completion_signatures< + stdexec::schedule_result_t, + _Env, + __with_exception_ptr>, + __mbind_front_q<__values_t, decltype(__decay_t<_Pair>::__child_)>>; + + struct __repeat_n_tag { }; + + struct __repeat_n_impl : __sexpr_defaults { + static constexpr auto get_completion_signatures = // + [](_Sender &&, _Env &&) noexcept { + return __completions_t<__data_of<_Sender>, _Env>{}; + }; + + static constexpr auto get_state = // + [](_Sender &&__sndr, _Receiver &__rcvr) { + return __repeat_n_state{std::move(__sndr), __rcvr}; + }; + + static constexpr auto start = // + [](auto &__state, __ignore) noexcept -> void { + __state.__start(); + }; + }; + + struct repeat_n_t { + template + auto operator()(_Sender &&__sndr, std::size_t __count) const { + auto __domain = __get_early_domain(__sndr); + return stdexec::transform_sender( + __domain, __make_sexpr(__count, (_Sender &&) __sndr)); + } + + constexpr auto operator()(std::size_t __count) const + -> __binder_back { + return {{}, {}, {__count}}; + } + + template + auto transform_sender(_Sender &&__sndr, __ignore) { + return __sexpr_apply( + (_Sender &&) __sndr, [](__ignore, std::size_t __count, _Child __child) { + return __make_sexpr<__repeat_n_tag>(__child_count_pair{std::move(__child), __count}); + }); + } + }; + } // namespace __repeat_n + + using __repeat_n::repeat_n_t; + inline constexpr repeat_n_t repeat_n{}; +} // namespace exec + +namespace stdexec { + template <> + struct __sexpr_impl + : exec::__repeat_n::__repeat_n_impl { }; +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4477efacc..6741c7964 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -71,6 +71,7 @@ set(stdexec_test_sources exec/test_on2.cpp exec/test_on3.cpp exec/test_repeat_effect_until.cpp + exec/test_repeat_n.cpp exec/async_scope/test_dtor.cpp exec/async_scope/test_spawn.cpp exec/async_scope/test_spawn_future.cpp @@ -155,6 +156,14 @@ icm_add_build_failure_test( FOLDER test ) +icm_add_build_failure_test( + NAME test_repeat_n_fail + TARGET test_repeat_n_fail + SOURCES PARSE exec/test_repeat_n_fail.cpp + LIBRARIES stdexec + FOLDER test +) + # # Adding multiple tests with a glob # icm_glob_build_failure_tests( # PATTERN *_fail*.cpp diff --git a/test/exec/test_repeat_n.cpp b/test/exec/test_repeat_n.cpp new file mode 100644 index 000000000..7a67ce3f9 --- /dev/null +++ b/test/exec/test_repeat_n.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2023 Runner-2019 + * Copyright (c) 2023 NVIDIA Corporation + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "exec/repeat_n.hpp" +#include "exec/on.hpp" +#include "exec/trampoline_scheduler.hpp" +#include "exec/static_thread_pool.hpp" +#include "stdexec/concepts.hpp" +#include "stdexec/execution.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace stdexec; + +namespace { + TEST_CASE("repeat_n returns a sender", "[adaptors][repeat_n]") { + auto snd = exec::repeat_n(ex::just() | then([] { }), 10); + static_assert(ex::sender); + (void) snd; + } + + TEST_CASE( + "repeat_n with environment returns a sender", + "[adaptors][repeat_n]") { + auto snd = exec::repeat_n(just() | then([] { }), 10); + static_assert(ex::sender_in); + (void) snd; + } + + TEST_CASE( + "repeat_n produces void value to downstream receiver", + "[adaptors][repeat_n]") { + sender auto source = just(1) | then([](int n) { }); + sender auto snd = exec::repeat_n(std::move(source), 10); + // The receiver checks if we receive the void value + auto op = stdexec::connect(std::move(snd), expect_void_receiver{}); + start(op); + } + + TEST_CASE("simple example for repeat_n", "[adaptors][repeat_n]") { + sender auto snd = exec::repeat_n(just(), 2); + stdexec::sync_wait(std::move(snd)); + } + + TEST_CASE("repeat_n works with with zero repetitions", "[adaptors][repeat_n]") { + std::size_t count = 0; + ex::sender auto snd = just() // + | then([&count] { ++count; }) + | exec::repeat_n(0) + | then([] { return 1; }); + wait_for_value(std::move(snd), 1); + CHECK(count == 0); + } + + TEST_CASE("repeat_n works with a single repetition", "[adaptors][repeat_n]") { + std::size_t count = 0; + ex::sender auto snd = just() // + | then([&count] { ++count; }) + | exec::repeat_n(1) + | then([] { return 1; }); + wait_for_value(std::move(snd), 1); + CHECK(count == 1); + } + + TEST_CASE("repeat_n works with multiple repetitions", "[adaptors][repeat_n]") { + std::size_t count = 0; + ex::sender auto snd = just() // + | then([&count] { ++count; }) + | exec::repeat_n(3) + | then([] { return 1; }); + wait_for_value(std::move(snd), 1); + CHECK(count == 3); + } + + TEST_CASE( + "repeat_n forwards set_error calls of other types", + "[adaptors][repeat_n]") { + int count = 0; + auto snd = let_value(just(), [&] { ++count; return just_error(std::string("error")); }) + | exec::repeat_n(10); + auto op = ex::connect(std::move(snd), expect_error_receiver{std::string("error")}); + start(op); + CHECK(count == 1); + } + + TEST_CASE("repeat_n forwards set_stopped calls", "[adaptors][repeat_n]") { + int count = 0; + auto snd = let_value(just(), [&] { ++count; return just_stopped(); }) + | exec::repeat_n(10); + auto op = ex::connect(std::move(snd), expect_stopped_receiver{}); + start(op); + CHECK(count == 1); + } + + TEST_CASE( + "running deeply recursing algo on repeat_n doesn't blow the stack", + "[adaptors][repeat_n]") { + int n = 0; + sender auto snd = exec::repeat_n(just() | then([&n] { ++n; }), 1'000'000); + stdexec::sync_wait(std::move(snd)); + CHECK(n == 1'000'000); + } + + TEST_CASE("repeat_n works when changing threads", "[adaptors][repeat_n]") { + exec::static_thread_pool pool{2}; + bool called{false}; + sender auto snd = exec::on( + pool.get_scheduler(), // + ex::just() // + | ex::then([&] { + called = true; + }) + | exec::repeat_n(10)); + stdexec::sync_wait(std::move(snd)); + REQUIRE(called); + } +} diff --git a/test/exec/test_repeat_n_fail.cpp b/test/exec/test_repeat_n_fail.cpp new file mode 100644 index 000000000..d2c1a7697 --- /dev/null +++ b/test/exec/test_repeat_n_fail.cpp @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023 NVIDIA Corporation + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace ex = stdexec; + +int main() { + ex::sender auto snd = ex::just(42) | exec::repeat_n(10); + // build error: _INVALID_ARGUMENT_TO_REPEAT_N_ + stdexec::sync_wait(std::move(snd)); +}