Skip to content

Commit

Permalink
Merge pull request NVIDIA#1267 from NVIDIA/member-function-customization
Browse files Browse the repository at this point in the history
start work of porting stdexec to use member functions instead of `tag_invoke`
  • Loading branch information
ericniebler authored Mar 2, 2024
2 parents af0efa9 + acde355 commit af4e218
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ KeepEmptyLinesAtTheStartOfBlocks: true
LambdaBodyIndentation: Signature
LineEnding: LF
Macros: [
'STDEXEC_MEMFN_DECL(X)=X',
'STDEXEC_MEMFN_DECL(...)=__VA_ARGS__',
'STDEXEC_ATTRIBUTE(X)=__attribute__(X) //',
'STDEXEC_NO_UNIQUE_ADDRESS=[[no_unique_address]]',
'STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS=[[no_unique_address]]',
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci.cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- "member-function-customization"
- "pull-request/[0-9]+"

concurrency:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci.gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- "member-function-customization"
- "pull-request/[0-9]+"

concurrency:
Expand Down
10 changes: 6 additions & 4 deletions include/exec/__detail/__bwos_lifo_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,14 @@ namespace exec::bwos {
auto lifo_queue<Tp, Allocator>::block_type::takeover() noexcept -> takeover_result {
std::uint64_t spos = steal_tail_.exchange(block_size(), std::memory_order_relaxed);
if (spos == block_size()) [[unlikely]] {
return {static_cast<std::size_t>(head_.load(std::memory_order_relaxed)),
static_cast<std::size_t>(tail_.load(std::memory_order_relaxed))};
return {
static_cast<std::size_t>(head_.load(std::memory_order_relaxed)),
static_cast<std::size_t>(tail_.load(std::memory_order_relaxed))};
}
head_.store(spos, std::memory_order_relaxed);
return {static_cast<std::size_t>(spos),
static_cast<std::size_t>(tail_.load(std::memory_order_relaxed))};
return {
static_cast<std::size_t>(spos),
static_cast<std::size_t>(tail_.load(std::memory_order_relaxed))};
}

template <class Tp, class Allocator>
Expand Down
2 changes: 1 addition & 1 deletion include/exec/static_thread_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1446,7 +1446,7 @@ namespace exec {
using NextReceiver = stdexec::__t<next_receiver<Range, ReceiverId>>;
using ItemOperation = connect_result_t<NextSender, NextReceiver>;

using ItemAllocator = std::allocator_traits<Allocator>::template rebind_alloc<
using ItemAllocator = typename std::allocator_traits<Allocator>::template rebind_alloc<
__manual_lifetime<ItemOperation>>;

std::vector<__manual_lifetime<ItemOperation>, ItemAllocator> items_;
Expand Down
87 changes: 45 additions & 42 deletions include/stdexec/__detail/__basic_sender.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ namespace stdexec {
using __state_type_t =
__decay_t<__result_of<__sexpr_impl<_Tag>::get_state, _Sexpr, _Receiver&>>;

template <class _Tag, class _Index, class _Sexpr, class _Receiver>
template <class _Self, class _Tag, class _Index, class _Sexpr, class _Receiver>
using __env_type_t = __result_of<
__sexpr_impl<_Tag>::get_env,
__sexpr_impl<__meval<__msecond, _Self, _Tag>>::get_env,
_Index,
__state_type_t<_Tag, _Sexpr, _Receiver>&,
__state_type_t<__meval<__msecond, _Self, _Tag>, _Sexpr, _Receiver>&,
_Receiver&>;

template <class _Sexpr, class _Receiver>
Expand Down Expand Up @@ -245,18 +245,27 @@ namespace stdexec {
// return __t{__parent};
// }

template <__completion_tag _Tag, class... _Args>
template <class... _Args>
STDEXEC_ATTRIBUTE((always_inline))
friend void
tag_invoke(_Tag, __t&& __self, _Args&&... __args) noexcept {
__self.__op_->__complete(_Idx(), _Tag(), static_cast<_Args&&>(__args)...);
STDEXEC_MEMFN_DECL(void set_value)(this __t&& __self, _Args&&... __args) noexcept {
__self.__op_->__complete(_Idx(), stdexec::set_value, static_cast<_Args&&>(__args)...);
}

template <same_as<get_env_t> _Tag, class _SexprTag = __tag_t>
template <class _Error>
STDEXEC_ATTRIBUTE((always_inline))
friend auto
tag_invoke(_Tag, const __t& __self) noexcept
-> __env_type_t<_SexprTag, _Idx, _Sexpr, _Receiver> {
STDEXEC_MEMFN_DECL(void set_error)(this __t&& __self, _Error&& __err) noexcept {
__self.__op_->__complete(_Idx(), stdexec::set_error, static_cast<_Error&&>(__err));
}

STDEXEC_ATTRIBUTE((always_inline))
STDEXEC_MEMFN_DECL(void set_stopped)(this __t&& __self) noexcept {
__self.__op_->__complete(_Idx(), stdexec::set_stopped);
}

template <__same_as<__t> _Self>
STDEXEC_ATTRIBUTE((always_inline))
STDEXEC_MEMFN_DECL(auto get_env)(this const _Self& __self) noexcept
-> __env_type_t<_Self, __tag_t, _Idx, _Sexpr, _Receiver> {
return __self.__op_->__get_env(_Idx());
}
};
Expand All @@ -283,6 +292,10 @@ namespace stdexec {
auto __rcvr() & noexcept -> _Receiver& {
return __rcvr_;
}

auto __rcvr() const & noexcept -> const _Receiver& {
return __rcvr_;
}
};

// template <class _Sexpr, class _Receiver>
Expand Down Expand Up @@ -381,10 +394,8 @@ namespace stdexec {
__sexpr_apply(static_cast<_Sexpr&&>(__sexpr), __connect_fn<_Sexpr, _Receiver>{this})) {
}

template <same_as<start_t> _Tag2>
STDEXEC_ATTRIBUTE((always_inline))
friend void
tag_invoke(_Tag2, __op_state& __self) noexcept {
STDEXEC_MEMFN_DECL(void start)(this __op_state& __self) noexcept {
using __tag_t = typename __op_state::__tag_t;
auto&& __rcvr = __self.__rcvr();
__tup::__apply(
Expand All @@ -407,7 +418,8 @@ namespace stdexec {
template <class _Index>
STDEXEC_ATTRIBUTE((always_inline))
auto
__get_env(_Index) noexcept -> __env_type_t<__tag_t, _Index, _Sexpr, _Receiver> {
__get_env(_Index) const noexcept
-> __env_type_t<_Index, __tag_t, _Index, _Sexpr, _Receiver> {
const auto& __rcvr = this->__rcvr();
return __sexpr_impl<__tag_t>::get_env(_Index(), this->__state_, __rcvr);
}
Expand Down Expand Up @@ -516,43 +528,34 @@ namespace stdexec {
static_cast<_Child&&>(__child)...)) {
}

template <class _Tag>
using __impl = __sexpr_impl<__meval<__msecond, _Tag, __tag_t>>;
template <class _Self>
using __impl = __sexpr_impl<__meval<__msecond, _Self, __tag_t>>;

template <same_as<get_env_t> _Tag, same_as<__sexpr> _Self>
template <same_as<__sexpr> _Self>
STDEXEC_ATTRIBUTE((always_inline))
friend auto
tag_invoke(_Tag, const _Self& __self) noexcept //
-> __msecond<
__if_c<same_as<_Tag, get_env_t> && same_as<_Self, __sexpr>>, //
__result_of<__sexpr_apply, const _Self&, __get_attrs_fn<__tag_t>>> {
return __sexpr_apply(__self, __detail::__drop_front(__impl<_Tag>::get_attrs));
STDEXEC_MEMFN_DECL(auto get_env)(this const _Self& __self) noexcept //
-> __result_of<__sexpr_apply, const _Self&, __get_attrs_fn<__tag_t>> {
return __sexpr_apply(__self, __detail::__drop_front(__impl<_Self>::get_attrs));
}

template <same_as<get_completion_signatures_t> _Tag, __decays_to<__sexpr> _Self, class _Env>
template <__decays_to<__sexpr> _Self, class _Env>
STDEXEC_ATTRIBUTE((always_inline))
friend auto
tag_invoke(_Tag, _Self&& __self, _Env&& __env) noexcept //
STDEXEC_MEMFN_DECL(auto get_completion_signatures)(this _Self&& __self, _Env&& __env) noexcept
-> __msecond<
__if_c<same_as<_Tag, get_completion_signatures_t> && __decays_to<_Self, __sexpr>>,
__result_of<__impl<_Tag>::get_completion_signatures, _Self, _Env>> {
__if_c<__decays_to<_Self, __sexpr>>,
__result_of<__impl<_Self>::get_completion_signatures, _Self, _Env>> {
return {};
}

// BUGBUG fix receiver constraint here:
template <
same_as<connect_t> _Tag,
__decays_to<__sexpr> _Self,
/*receiver*/ class _Receiver>
template <__decays_to<__sexpr> _Self, /*receiver*/ class _Receiver>
STDEXEC_ATTRIBUTE((always_inline))
friend auto
tag_invoke(_Tag, _Self&& __self, _Receiver&& __rcvr) //
noexcept(noexcept(
__impl<_Tag>::connect(static_cast<_Self&&>(__self), static_cast<_Receiver&&>(__rcvr)))) //
STDEXEC_MEMFN_DECL(auto connect)(this _Self&& __self, _Receiver&& __rcvr) //
noexcept(__noexcept_of<__impl<_Self>::connect, _Self, _Receiver>) //
-> __msecond<
__if_c<same_as<_Tag, connect_t> && __decays_to<_Self, __sexpr>>,
__result_of<__impl<_Tag>::connect, _Self, _Receiver>> {
return __impl<_Tag>::connect(static_cast<_Self&&>(__self), static_cast<_Receiver&&>(__rcvr));
__if_c<__decays_to<_Self, __sexpr>>,
__result_of<__impl<_Self>::connect, _Self, _Receiver>> {
return __impl<_Self>::connect(static_cast<_Self&&>(__self), static_cast<_Receiver&&>(__rcvr));
}

template <class _Sender, class _ApplyFn>
Expand All @@ -566,8 +569,8 @@ namespace stdexec {

template <std::size_t _Idx, __decays_to_derived_from<__sexpr> _Self>
STDEXEC_ATTRIBUTE((always_inline))
friend decltype(auto)
get(_Self&& __self) noexcept
friend auto
get(_Self&& __self) noexcept -> decltype(auto)
requires __detail::__in_range<_Idx, __desc_t>
{
if constexpr (_Idx == 0) {
Expand Down
42 changes: 19 additions & 23 deletions include/stdexec/__detail/__cpo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
#include "__config.hpp"
#include "__execution_fwd.hpp"

#define STDEXEC_EAT_THIS_this
#define STDEXEC_EAT_AUTO_auto
#define STDEXEC_EAT_VOID_void

///////////////////////////////////////////////////////////////////////////////
/// To hook a customization point like stdexec::get_env, first bring the names
/// in stdexec::tags into scope:
Expand All @@ -38,31 +34,31 @@
/// }
/// @endcode
#define STDEXEC_MEMFN_DECL(...) \
friend STDEXEC_TAG_INVOKE(STDEXEC_IS_AUTO(__VA_ARGS__), __VA_ARGS__) STDEXEC_TAG_INVOKE_ARGS
friend STDEXEC_MEMFN_DECL_TAG_INVOKE( \
STDEXEC_CHECK(STDEXEC_CAT(STDEXEC_MEMFN_DECL_PROBE_, __VA_ARGS__)), __VA_ARGS__) \
STDEXEC_MEMFN_DECL_ARGS

#define STDEXEC_TAG_INVOKE(_ISAUTO, ...) \
STDEXEC_IIF(_ISAUTO, STDEXEC_RETURN_AUTO, STDEXEC_RETURN_TYPE)(__VA_ARGS__) \
tag_invoke( \
STDEXEC_IIF(_ISAUTO, STDEXEC_TAG_AUTO, STDEXEC_TAG_WHAT)(__VA_ARGS__)
#define STDEXEC_MEMFN_DECL_TAG_INVOKE(_WHICH, ...) \
STDEXEC_CAT(STDEXEC_MEMFN_DECL_RETURN_, _WHICH)(__VA_ARGS__) \
tag_invoke(const STDEXEC_CAT(STDEXEC_MEMFN_DECL_TAG_, _WHICH)(__VA_ARGS__),

#define STDEXEC_PROBE_AUTO_auto STDEXEC_PROBE(~)
#define STDEXEC_IS_AUTO(_TY, ...) STDEXEC_CHECK(STDEXEC_CAT(STDEXEC_PROBE_AUTO_, _TY))
#define STDEXEC_MEMFN_DECL_ARGS(...) \
STDEXEC_CAT(STDEXEC_EAT_THIS_, __VA_ARGS__))

#define STDEXEC_PROBE_VOID_void STDEXEC_PROBE(~)
#define STDEXEC_IS_VOID(_TY, ...) STDEXEC_CHECK(STDEXEC_CAT(STDEXEC_PROBE_VOID_, _TY))

#define STDEXEC_RETURN_AUTO(...) auto
#define STDEXEC_RETURN_TYPE(...) ::stdexec::__arg_type_t<void(__VA_ARGS__())>
#define STDEXEC_EAT_THIS_this
#define STDEXEC_EAT_AUTO_auto
#define STDEXEC_EAT_VOID_void

#define STDEXEC_TAG_AUTO(...) STDEXEC_CAT(STDEXEC_CAT(STDEXEC_EAT_AUTO_, __VA_ARGS__), _t)
#define STDEXEC_TAG_WHAT(...) \
STDEXEC_IIF(STDEXEC_IS_VOID(__VA_ARGS__), STDEXEC_TAG_VOID, STDEXEC_TAG_TYPE)(__VA_ARGS__)
#define STDEXEC_MEMFN_DECL_PROBE_auto STDEXEC_PROBE(~, 1)
#define STDEXEC_MEMFN_DECL_PROBE_void STDEXEC_PROBE(~, 2)

#define STDEXEC_TAG_VOID(...) STDEXEC_CAT(STDEXEC_CAT(STDEXEC_EAT_VOID_, __VA_ARGS__), _t)
#define STDEXEC_TAG_TYPE(...) ::stdexec::__tag_type_t<STDEXEC_CAT(__VA_ARGS__, _t::*)>
#define STDEXEC_MEMFN_DECL_RETURN_0(...) ::stdexec::__arg_type_t<void(__VA_ARGS__())>
#define STDEXEC_MEMFN_DECL_RETURN_1(...) auto
#define STDEXEC_MEMFN_DECL_RETURN_2(...) void

#define STDEXEC_TAG_INVOKE_ARGS(...) \
__VA_OPT__(,) STDEXEC_CAT(STDEXEC_EAT_THIS_, __VA_ARGS__))
#define STDEXEC_MEMFN_DECL_TAG_0(...) ::stdexec::__tag_type_t<STDEXEC_CAT(__VA_ARGS__, _t::*)>&
#define STDEXEC_MEMFN_DECL_TAG_1(...) STDEXEC_CAT(STDEXEC_CAT(STDEXEC_EAT_AUTO_, __VA_ARGS__), _t)&
#define STDEXEC_MEMFN_DECL_TAG_2(...) STDEXEC_CAT(STDEXEC_CAT(STDEXEC_EAT_VOID_, __VA_ARGS__), _t)&

#if STDEXEC_MSVC()
# pragma deprecated(STDEXEC_CUSTOM)
Expand Down
9 changes: 9 additions & 0 deletions include/stdexec/__detail/__meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,17 @@ namespace stdexec {
using __call_result_t = decltype(__declval<_Fun>()(__declval<_As>()...));
#endif

// BUGBUG TODO file this bug with nvc++
#if STDEXEC_NVHPC()
template <const auto& _Fun, class... _As>
using __result_of = __call_result_t<decltype(_Fun), _As...>;
#else
template <const auto& _Fun, class... _As>
using __result_of = decltype(_Fun(__declval<_As>()...));
#endif

template <const auto& _Fun, class... _As>
inline constexpr bool __noexcept_of = noexcept(_Fun(__declval<_As>()...));

// For working around clang's lack of support for CWG#2369:
// http://www.open-std.org/jtc1/sc22/wg21/docs/cwg_defects.html#2369
Expand Down
2 changes: 1 addition & 1 deletion include/stdexec/execution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4725,7 +4725,7 @@ namespace stdexec {
};

template <class _Env>
auto __mkenv(_Env&& __env, in_place_stop_source& __stop_source) noexcept {
auto __mkenv(_Env&& __env, const in_place_stop_source& __stop_source) noexcept {
return __env::__join(
__env::__with(__stop_source.get_token(), get_stop_token), static_cast<_Env&&>(__env));
}
Expand Down

0 comments on commit af4e218

Please sign in to comment.