[libc++][PSTL] Implement std::count{,_if}
Reviewed By: ldionne, #libc Spies: libcxx-commits Differential Revision: https://reviews.llvm.org/D150128
This commit is contained in:
committed by
Nikolas Klauser
parent
30198bd788
commit
7a3b528e1b
@@ -85,6 +85,7 @@ set(files
|
||||
__algorithm/pstl_backends/cpu_backends/transform.h
|
||||
__algorithm/pstl_backends/cpu_backends/transform_reduce.h
|
||||
__algorithm/pstl_copy.h
|
||||
__algorithm/pstl_count.h
|
||||
__algorithm/pstl_fill.h
|
||||
__algorithm/pstl_find.h
|
||||
__algorithm/pstl_for_each.h
|
||||
|
||||
@@ -113,6 +113,12 @@ implemented, all the algorithms will eventually forward to the basis algorithms
|
||||
temlate <class _ExecutionPolicy, class _Iterator>
|
||||
__iter_value_type<_Iterator> __pstl_reduce(_Backend, _Iterator __first, _Iterator __last);
|
||||
|
||||
template <class _ExecuitonPolicy, class _Iterator, class _Tp>
|
||||
__iter_diff_t<_Iterator> __pstl_count(_Backend, _Iterator __first, _Iterator __last, const _Tp& __value);
|
||||
|
||||
template <class _ExecutionPolicy, class _Iterator, class _Predicate>
|
||||
__iter_diff_t<_Iterator> __pstl_count_if(_Backend, _Iterator __first, _Iterator __last, _Predicate __pred);
|
||||
|
||||
// TODO: Complete this list
|
||||
|
||||
*/
|
||||
|
||||
86
libcxx/include/__algorithm/pstl_count.h
Normal file
86
libcxx/include/__algorithm/pstl_count.h
Normal file
@@ -0,0 +1,86 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef _LIBCPP___ALGORITHM_PSTL_COUNT_H
|
||||
#define _LIBCPP___ALGORITHM_PSTL_COUNT_H
|
||||
|
||||
#include <__algorithm/count.h>
|
||||
#include <__algorithm/for_each.h>
|
||||
#include <__algorithm/pstl_backend.h>
|
||||
#include <__algorithm/pstl_for_each.h>
|
||||
#include <__algorithm/pstl_frontend_dispatch.h>
|
||||
#include <__atomic/atomic.h>
|
||||
#include <__config>
|
||||
#include <__iterator/iterator_traits.h>
|
||||
#include <__numeric/pstl_transform_reduce.h>
|
||||
#include <__type_traits/is_execution_policy.h>
|
||||
#include <__type_traits/remove_cvref.h>
|
||||
#include <__utility/terminate_on_exception.h>
|
||||
|
||||
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
|
||||
# pragma GCC system_header
|
||||
#endif
|
||||
|
||||
#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
|
||||
|
||||
_LIBCPP_BEGIN_NAMESPACE_STD
|
||||
|
||||
template <class>
|
||||
void __pstl_count_if(); // declaration needed for the frontend dispatch below
|
||||
|
||||
template <class _ExecutionPolicy,
|
||||
class _ForwardIterator,
|
||||
class _Predicate,
|
||||
class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
|
||||
enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
|
||||
_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
|
||||
count_if(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) {
|
||||
using __diff_t = __iter_diff_t<_ForwardIterator>;
|
||||
return std::__pstl_frontend_dispatch(
|
||||
_LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count_if),
|
||||
[&](_ForwardIterator __g_first, _ForwardIterator __g_last, _Predicate __g_pred) {
|
||||
return std::transform_reduce(
|
||||
__policy,
|
||||
std::move(__g_first),
|
||||
std::move(__g_last),
|
||||
__diff_t(),
|
||||
std::plus{},
|
||||
[&](__iter_reference<_ForwardIterator> __element) -> bool { return __g_pred(__element); });
|
||||
},
|
||||
std::move(__first),
|
||||
std::move(__last),
|
||||
std::move(__pred));
|
||||
}
|
||||
|
||||
template <class>
|
||||
void __pstl_count(); // declaration needed for the frontend dispatch below
|
||||
|
||||
template <class _ExecutionPolicy,
|
||||
class _ForwardIterator,
|
||||
class _Tp,
|
||||
class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
|
||||
enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
|
||||
_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
|
||||
count(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, const _Tp& __value) {
|
||||
return std::__pstl_frontend_dispatch(
|
||||
_LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count),
|
||||
[&](_ForwardIterator __g_first, _ForwardIterator __g_last, const _Tp& __g_value) {
|
||||
return std::count_if(__policy, __g_first, __g_last, [&](__iter_reference<_ForwardIterator> __v) {
|
||||
return __v == __g_value;
|
||||
});
|
||||
},
|
||||
std::move(__first),
|
||||
std::move(__last),
|
||||
__value);
|
||||
}
|
||||
|
||||
_LIBCPP_END_NAMESPACE_STD
|
||||
|
||||
#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
|
||||
|
||||
#endif // _LIBCPP___ALGORITHM_PSTL_COUNT_H
|
||||
@@ -1802,6 +1802,7 @@ template <class BidirectionalIterator, class Compare>
|
||||
#include <__algorithm/prev_permutation.h>
|
||||
#include <__algorithm/pstl_any_all_none_of.h>
|
||||
#include <__algorithm/pstl_copy.h>
|
||||
#include <__algorithm/pstl_count.h>
|
||||
#include <__algorithm/pstl_fill.h>
|
||||
#include <__algorithm/pstl_find.h>
|
||||
#include <__algorithm/pstl_for_each.h>
|
||||
|
||||
@@ -42,6 +42,26 @@ bool __pstl_all_of(TestBackend, ForwardIterator, ForwardIterator, Pred) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pstl_count_called = false;
|
||||
|
||||
template <class, class ForwardIterator, class T>
|
||||
typename std::iterator_traits<ForwardIterator>::difference_type
|
||||
__pstl_count(TestBackend, ForwardIterator, ForwardIterator, const T&) {
|
||||
assert(!pstl_count_called);
|
||||
pstl_count_called = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pstl_count_if_called = false;
|
||||
|
||||
template <class, class ForwardIterator, class Pred>
|
||||
typename std::iterator_traits<ForwardIterator>::difference_type
|
||||
__pstl_count_if(TestBackend, ForwardIterator, ForwardIterator, Pred) {
|
||||
assert(!pstl_count_if_called);
|
||||
pstl_count_if_called = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pstl_none_of_called = false;
|
||||
|
||||
template <class, class ForwardIterator, class Pred>
|
||||
@@ -197,6 +217,10 @@ int main(int, char**) {
|
||||
assert(std::pstl_all_of_called);
|
||||
(void)std::none_of(TestPolicy{}, std::begin(a), std::end(a), pred);
|
||||
assert(std::pstl_none_of_called);
|
||||
(void)std::count(TestPolicy{}, std::begin(a), std::end(a), 0);
|
||||
assert(std::pstl_count_called);
|
||||
(void)std::count_if(TestPolicy{}, std::begin(a), std::end(a), pred);
|
||||
assert(std::pstl_count_if_called);
|
||||
(void)std::fill(TestPolicy{}, std::begin(a), std::end(a), 0);
|
||||
assert(std::pstl_fill_called);
|
||||
(void)std::fill_n(TestPolicy{}, std::begin(a), std::size(a), 0);
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// UNSUPPORTED: c++03, c++11, c++14
|
||||
|
||||
// UNSUPPORTED: libcpp-has-no-incomplete-pstl
|
||||
|
||||
// <algorithm>
|
||||
|
||||
// template<class ExecutionPolicy, class ForwardIterator, class T>
|
||||
// typename iterator_traits<ForwardIterator>::difference_type
|
||||
// count(ExecutionPolicy&& exec,
|
||||
// ForwardIterator first, ForwardIterator last, const T& value);
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "test_macros.h"
|
||||
#include "test_execution_policies.h"
|
||||
#include "test_iterators.h"
|
||||
|
||||
EXECUTION_POLICY_SFINAE_TEST(count);
|
||||
|
||||
static_assert(sfinae_test_count<int, int*, int*, bool (*)(int)>);
|
||||
static_assert(!sfinae_test_count<std::execution::parallel_policy, int*, int*, int>);
|
||||
|
||||
template <class Iter>
|
||||
struct Test {
|
||||
template <class Policy>
|
||||
void operator()(Policy&& policy) {
|
||||
{ // simple test
|
||||
int a[] = {1, 2, 3, 4, 5};
|
||||
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 1);
|
||||
}
|
||||
|
||||
{ // test that an empty range works
|
||||
std::array<int, 0> a;
|
||||
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 0);
|
||||
}
|
||||
|
||||
{ // test that a single-element range works
|
||||
int a[] = {1};
|
||||
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 1);
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 1);
|
||||
}
|
||||
|
||||
{ // test that a two-element range works
|
||||
int a[] = {1, 3};
|
||||
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 1);
|
||||
}
|
||||
|
||||
{ // test that a three-element range works
|
||||
int a[] = {3, 1, 3};
|
||||
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 2);
|
||||
}
|
||||
|
||||
{ // test that a large range works
|
||||
std::vector<int> a(100, 2);
|
||||
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 2);
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 100);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int main(int, char**) {
|
||||
types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// UNSUPPORTED: c++03, c++11, c++14
|
||||
|
||||
// UNSUPPORTED: libcpp-has-no-incomplete-pstl
|
||||
|
||||
// <algorithm>
|
||||
|
||||
// template<class ExecutionPolicy, class ForwardIterator, class Predicate>
|
||||
// typename iterator_traits<ForwardIterator>::difference_type
|
||||
// count_if(ExecutionPolicy&& exec,
|
||||
// ForwardIterator first, ForwardIterator last, Predicate pred);
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "test_macros.h"
|
||||
#include "test_execution_policies.h"
|
||||
#include "test_iterators.h"
|
||||
|
||||
EXECUTION_POLICY_SFINAE_TEST(count_if);
|
||||
|
||||
static_assert(sfinae_test_count_if<int, int*, int*, bool (*)(int)>);
|
||||
static_assert(!sfinae_test_count_if<std::execution::parallel_policy, int*, int*, int>);
|
||||
|
||||
template <class Iter>
|
||||
struct Test {
|
||||
template <class Policy>
|
||||
void operator()(Policy&& policy) {
|
||||
{ // simple test
|
||||
int a[] = {1, 2, 3, 4, 5};
|
||||
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 2);
|
||||
}
|
||||
|
||||
{ // test that an empty range works
|
||||
std::array<int, 0> a;
|
||||
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 0);
|
||||
}
|
||||
|
||||
{ // test that a single-element range works
|
||||
int a[] = {1};
|
||||
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 1);
|
||||
}
|
||||
|
||||
{ // test that a two-element range works
|
||||
int a[] = {1, 3};
|
||||
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 1);
|
||||
}
|
||||
|
||||
{ // test that a three-element range works
|
||||
int a[] = {2, 3, 2};
|
||||
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 2);
|
||||
}
|
||||
|
||||
{ // test that a large range works
|
||||
std::vector<int> a(100, 2);
|
||||
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
|
||||
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
|
||||
assert(ret == 100);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int main(int, char**) {
|
||||
types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user