diff --git a/include/exec/function.hpp b/include/exec/function.hpp index 90375c8dd..0fbfc9955 100644 --- a/include/exec/function.hpp +++ b/include/exec/function.hpp @@ -17,6 +17,7 @@ #include "../stdexec/__detail/__completion_signatures.hpp" #include "../stdexec/__detail/__concepts.hpp" +#include "../stdexec/__detail/__domain.hpp" #include "../stdexec/__detail/__meta.hpp" #include "../stdexec/__detail/__read_env.hpp" #include "../stdexec/__detail/__receivers.hpp" @@ -55,6 +56,11 @@ // queries to pick the frame allocator from the environment without relying on TLS. namespace experimental::execution { + // for specifying required sender attributes in exec::function + template <_query::_query_signature... Sigs> + struct attrs + {}; + namespace __func { using namespace STDEXEC; @@ -106,13 +112,13 @@ namespace experimental::execution {} constexpr auto get_env() const noexcept // - -> __join_env_t<__prop_t, env_of_t<_Receiver>> + -> __join_env_t<__prop_t const &, env_of_t<_Receiver>> { return __env::__join(*__env_, STDEXEC::get_env(*static_cast<_Receiver const *>(this))); } private: - __prop_t *__env_; + __prop_t const *__env_; }; template @@ -192,7 +198,74 @@ namespace experimental::execution } }; - template + template + struct __make_domain_impl + {}; + + template + struct __make_domain_impl<_Tag, _Domain(get_completion_domain_t<_Tag>) noexcept> + { + constexpr _Domain operator()() const noexcept + { + return _Domain(); + } + }; + + //! get_completion_domain<> is a special case; its type parameter is void + //! and it's equivalent to get_completion_domain. + template + struct __make_domain_impl) noexcept> + : __make_domain_impl) noexcept> + {}; + + //! get_completion_domain ought to be no-throw, so make it optional to specify + //! noexcept on the signature provided with attrs<...> + template + struct __make_domain_impl<_Tag1, _Domain(get_completion_domain_t<_Tag2>)> + : __make_domain_impl<_Tag1, _Domain(get_completion_domain_t<_Tag2>) noexcept> + {}; + + template + inline constexpr auto __make_domain = __first_callable<__make_domain_impl<_Tag, _Attrs>...>(); + + template + struct __attrs + { + template + constexpr auto query(get_completion_domain_t<_Tag>, _Env &&...) const noexcept + -> decltype(__make_domain<_Tag, _Attrs...>()) + { + return __make_domain<_Tag, _Attrs...>(); + } + }; + + template + using __completion_domain_t = __call_result_or_t< + get_completion_domain_t<_Tag>, + __call_result_or_t, indeterminate_domain<>, _Attrs>, + _Attrs, + _Env const &...>; + + template + concept __completion_domain_matches_impl = + __same_as<_ExpectedDomain, __common_domain_t<_ActualDomain, _ExpectedDomain>>; + + template + concept __completion_domain_matches = + __completion_domain_matches_impl<__completion_domain_t<_Tag, _ActualAttrs, _Env...>, + __completion_domain_t<_Tag, _ExpectedAttrs, _Env...>>; + + template + concept __completion_domains_match_impl = + __completion_domain_matches + && __completion_domain_matches + && __completion_domain_matches; + + template + concept __completion_domains_match = + __completion_domains_match_impl, env_of_t<_Expected>, _Env...>; + + template class __function; //! the main implementation of the type-erasing sender function<...> @@ -206,8 +279,8 @@ namespace experimental::execution //! not, as appropriate //! //! \tparam _Args The argument types used to construct the erased sender - template - class __function<_Sigs, queries<_Queries...>, _Args...> + template + class __function<_Sigs, queries<_Queries...>, attrs<_Attrs...>, _Args...> { using __receiver_t = __receiver_wrapper<__any_receiver_ref<_Sigs, queries<_Queries...>>>; @@ -220,8 +293,9 @@ namespace experimental::execution -> _any::_any_opstate_base { auto &__make_sender = *__std::start_lifetime_as<_Factory>(__storage); - using __alloc_t = decltype(__choose_frame_allocator(get_env(__rcvr))); - auto __alloc = __frame_allocator_t<__alloc_t>(__choose_frame_allocator(get_env(__rcvr))); + using __alloc_t = decltype(__choose_frame_allocator(STDEXEC::get_env(__rcvr))); + auto __alloc = __frame_allocator_t<__alloc_t>( + __choose_frame_allocator(STDEXEC::get_env(__rcvr))); return _any::_any_opstate_base(__in_place_from, std::allocator_arg, __alloc, @@ -257,6 +331,9 @@ namespace experimental::execution && (STDEXEC_IS_TRIVIALLY_COPYABLE(_Factory)) // && (sizeof(_Factory) <= sizeof(__make_sender_)) // && sender_to<__invoke_result_t<_Factory, _Args...>, __receiver_t> + && __completion_domains_match<__invoke_result_t<_Factory, _Args...>, + __function, + env_of_t<__receiver_t>> constexpr explicit __function(_Args &&...__args, _Factory __factory) noexcept(__nothrow_move_constructible<_Args...>) : __args_(static_cast<_Args &&>(__args)...) @@ -279,6 +356,11 @@ namespace experimental::execution return _Sigs(); } + constexpr __attrs<_Attrs...> get_env() const noexcept + { + return {}; + } + template constexpr auto connect(_Receiver __rcvr) && // -> __opstate_t<_Receiver> @@ -342,6 +424,46 @@ namespace experimental::execution completion_signatures<__single_value_sig_t<_Return>, set_stopped_t()>, __eptr_completion_unless_t<__mbool<_NoExcept>>>>; + //! maps a completion signature to the default completion domain query + struct __domain_query_from_sig + { + template + consteval auto operator()(_Tag (*)(_Args...)) const noexcept // + -> default_domain (*)(get_completion_domain_t<_Tag>) noexcept + { + return nullptr; + } + }; + + //! maps a pack of domain queries produced by __domain_query_from_sig to the + //! corresponding attrs<_Attrs...> type + class __attrs_from_domain_queries + { + template + using __query_sig = default_domain (*)(get_completion_domain_t<_Tag>) noexcept; + + public: + template + consteval auto operator()(__query_sig<_Tag>...) const noexcept // + -> __canonical_t) noexcept...>> + { + return {}; + } + }; + + //! computes the set of get_completion_domain queries that must be supported by any + //! sender that might be erased by the corresponding function + //! + //! we should support get_completion_domain only if _Sigs contains a completion + //! of type Tag + //! + //! the query form should be + //! + //! default_domain(get_completion_domain_t) + template + using __default_attrs = decltype(_Sigs::__transform_reduce(__domain_query_from_sig(), + __attrs_from_domain_queries())); + //! Map a variety of function<...> specifications into the canonical type-erased //! contract represented by the user-provided specification. //! @@ -362,48 +484,88 @@ namespace experimental::execution //! The order of Args... is obviously important, but Sigs... and Queries... are both //! canonicalized into a sorted and uniqued list to ensure order is irrelevant. template - struct __make_function; + class __make_function; template - struct __make_function<_Return(_Args...)> + class __make_function<_Return(_Args...)> { - using type = __function<__sigs_from_t<_Return, false>, queries<>, _Args...>; + using __sigs = __sigs_from_t<_Return, false>; + using __queries = queries<>; + using __attrs = __default_attrs<__sigs>; + + public: + using type = __function<__sigs, __queries, __attrs, _Args...>; }; template - struct __make_function<_Return(_Args...) noexcept> + class __make_function<_Return(_Args...) noexcept> { - using type = __function<__sigs_from_t<_Return, true>, queries<>, _Args...>; + using __sigs = __sigs_from_t<_Return, true>; + using __queries = queries<>; + using __attrs = __default_attrs<__sigs>; + + public: + using type = __function<__sigs, __queries, __attrs, _Args...>; }; template - struct __make_function> + class __make_function> { - using type = __function<__canonical_t>, queries<>, _Args...>; + using __sigs = __canonical_t>; + using __queries = queries<>; + using __attrs = __default_attrs<__sigs>; + + public: + using type = __function<__sigs, __queries, __attrs, _Args...>; }; template - struct __make_function<_Return(_Args...), queries<_Queries...>> + class __make_function<_Return(_Args...), queries<_Queries...>> { - using type = - __function<__sigs_from_t<_Return, false>, __canonical_t>, _Args...>; + using __sigs = __sigs_from_t<_Return, false>; + using __queries = __canonical_t>; + using __attrs = __default_attrs<__sigs>; + + public: + using type = __function<__sigs, __queries, __attrs, _Args...>; }; template - struct __make_function<_Return(_Args...) noexcept, queries<_Queries...>> + class __make_function<_Return(_Args...) noexcept, queries<_Queries...>> { - using type = - __function<__sigs_from_t<_Return, true>, __canonical_t>, _Args...>; + using __sigs = __sigs_from_t<_Return, true>; + using __queries = __canonical_t>; + using __attrs = __default_attrs<__sigs>; + + public: + using type = __function<__sigs, __queries, __attrs, _Args...>; }; template - struct __make_function, - queries<_Queries...>> + class __make_function, + queries<_Queries...>> + { + using __sigs = __canonical_t>; + using __queries = __canonical_t>; + using __attrs = __default_attrs<__sigs>; + + public: + using type = __function<__sigs, __queries, __attrs, _Args...>; + }; + + template + class __make_function, + queries<_Queries...>, + attrs<_Attrs...>> { - using type = __function<__canonical_t>, - __canonical_t>, - _Args...>; + using __sigs = __canonical_t>; + using __queries = __canonical_t>; + using __attrs = __canonical_t>; + + public: + using type = __function<__sigs, __queries, __attrs, _Args...>; }; } // namespace __func diff --git a/test/exec/test_function.cpp b/test/exec/test_function.cpp index 6283df70e..13160fc70 100644 --- a/test/exec/test_function.cpp +++ b/test/exec/test_function.cpp @@ -411,4 +411,145 @@ namespace STATIC_REQUIRE(std::assignable_from); } } + + struct none_such + {}; + + template + inline constexpr auto get_completion_domain = + ex::__first_callable{ex::get_completion_domain, ex::__always{none_such()}}; + + TEST_CASE("function reports a default completion domain by default", "[types][function]") + { + SECTION("throwing function reports a completion domain for all three channels") + { + exec::function fn(ex::just); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + + SECTION("no-throw function reports a completion domain for value and stop channels only") + { + exec::function fn(ex::just); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + + SECTION("infallible function reports a completion domain for value channel only") + { + exec::function> fn(ex::just); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + + SECTION("just_error function reports a completion domain for error channel only") + { + exec::function> fn( + 42, + ex::just_error); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + + SECTION("just_stopped function reports a completion domain for stop channel only") + { + exec::function> fn( + ex::just_stopped); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + } + + struct domain : ex::default_domain + {}; + + TEST_CASE("function's constructor is constrained based on the common domain", "[types][function]") + { + using queries = exec::queries; + + SECTION("the constraint applies to set_value") + { + using function = + exec::function, queries>; + + STATIC_REQUIRE(std::constructible_from); + + function fn(ex::just); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + + SECTION("the constraint applies to set_error") + { + using function = exec::function, + queries>; + + STATIC_REQUIRE(std::constructible_from); + + function fn(42, ex::just_error); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + + SECTION("the constraint applies to set_stopped") + { + using function = + exec::function, queries>; + + STATIC_REQUIRE(std::constructible_from); + + function fn(ex::just_stopped); + auto attrs = ex::get_env(fn); + auto value_domain = get_completion_domain(attrs); + auto error_domain = get_completion_domain(attrs); + auto stop_domain = get_completion_domain(attrs); + + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(std::same_as); + } + } } // namespace