[libc++][PSTL] Implement std::count{,_if}

Reviewed By: ldionne, #libc

Spies: libcxx-commits

Differential Revision: https://reviews.llvm.org/D150128
This commit is contained in:
Nikolas Klauser
2023-06-06 08:42:40 -07:00
committed by Nikolas Klauser
parent 30198bd788
commit 7a3b528e1b
7 changed files with 290 additions and 0 deletions

View File

@@ -85,6 +85,7 @@ set(files
__algorithm/pstl_backends/cpu_backends/transform.h
__algorithm/pstl_backends/cpu_backends/transform_reduce.h
__algorithm/pstl_copy.h
__algorithm/pstl_count.h
__algorithm/pstl_fill.h
__algorithm/pstl_find.h
__algorithm/pstl_for_each.h

View File

@@ -113,6 +113,12 @@ implemented, all the algorithms will eventually forward to the basis algorithms
temlate <class _ExecutionPolicy, class _Iterator>
__iter_value_type<_Iterator> __pstl_reduce(_Backend, _Iterator __first, _Iterator __last);
template <class _ExecuitonPolicy, class _Iterator, class _Tp>
__iter_diff_t<_Iterator> __pstl_count(_Backend, _Iterator __first, _Iterator __last, const _Tp& __value);
template <class _ExecutionPolicy, class _Iterator, class _Predicate>
__iter_diff_t<_Iterator> __pstl_count_if(_Backend, _Iterator __first, _Iterator __last, _Predicate __pred);
// TODO: Complete this list
*/

View File

@@ -0,0 +1,86 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef _LIBCPP___ALGORITHM_PSTL_COUNT_H
#define _LIBCPP___ALGORITHM_PSTL_COUNT_H
#include <__algorithm/count.h>
#include <__algorithm/for_each.h>
#include <__algorithm/pstl_backend.h>
#include <__algorithm/pstl_for_each.h>
#include <__algorithm/pstl_frontend_dispatch.h>
#include <__atomic/atomic.h>
#include <__config>
#include <__iterator/iterator_traits.h>
#include <__numeric/pstl_transform_reduce.h>
#include <__type_traits/is_execution_policy.h>
#include <__type_traits/remove_cvref.h>
#include <__utility/terminate_on_exception.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
#endif
#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
_LIBCPP_BEGIN_NAMESPACE_STD
template <class>
void __pstl_count_if(); // declaration needed for the frontend dispatch below
template <class _ExecutionPolicy,
class _ForwardIterator,
class _Predicate,
class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
count_if(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) {
using __diff_t = __iter_diff_t<_ForwardIterator>;
return std::__pstl_frontend_dispatch(
_LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count_if),
[&](_ForwardIterator __g_first, _ForwardIterator __g_last, _Predicate __g_pred) {
return std::transform_reduce(
__policy,
std::move(__g_first),
std::move(__g_last),
__diff_t(),
std::plus{},
[&](__iter_reference<_ForwardIterator> __element) -> bool { return __g_pred(__element); });
},
std::move(__first),
std::move(__last),
std::move(__pred));
}
template <class>
void __pstl_count(); // declaration needed for the frontend dispatch below
template <class _ExecutionPolicy,
class _ForwardIterator,
class _Tp,
class _RawPolicy = __remove_cvref_t<_ExecutionPolicy>,
enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
count(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, const _Tp& __value) {
return std::__pstl_frontend_dispatch(
_LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count),
[&](_ForwardIterator __g_first, _ForwardIterator __g_last, const _Tp& __g_value) {
return std::count_if(__policy, __g_first, __g_last, [&](__iter_reference<_ForwardIterator> __v) {
return __v == __g_value;
});
},
std::move(__first),
std::move(__last),
__value);
}
_LIBCPP_END_NAMESPACE_STD
#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
#endif // _LIBCPP___ALGORITHM_PSTL_COUNT_H

View File

@@ -1802,6 +1802,7 @@ template <class BidirectionalIterator, class Compare>
#include <__algorithm/prev_permutation.h>
#include <__algorithm/pstl_any_all_none_of.h>
#include <__algorithm/pstl_copy.h>
#include <__algorithm/pstl_count.h>
#include <__algorithm/pstl_fill.h>
#include <__algorithm/pstl_find.h>
#include <__algorithm/pstl_for_each.h>

View File

@@ -42,6 +42,26 @@ bool __pstl_all_of(TestBackend, ForwardIterator, ForwardIterator, Pred) {
return true;
}
bool pstl_count_called = false;
template <class, class ForwardIterator, class T>
typename std::iterator_traits<ForwardIterator>::difference_type
__pstl_count(TestBackend, ForwardIterator, ForwardIterator, const T&) {
assert(!pstl_count_called);
pstl_count_called = true;
return 0;
}
bool pstl_count_if_called = false;
template <class, class ForwardIterator, class Pred>
typename std::iterator_traits<ForwardIterator>::difference_type
__pstl_count_if(TestBackend, ForwardIterator, ForwardIterator, Pred) {
assert(!pstl_count_if_called);
pstl_count_if_called = true;
return 0;
}
bool pstl_none_of_called = false;
template <class, class ForwardIterator, class Pred>
@@ -197,6 +217,10 @@ int main(int, char**) {
assert(std::pstl_all_of_called);
(void)std::none_of(TestPolicy{}, std::begin(a), std::end(a), pred);
assert(std::pstl_none_of_called);
(void)std::count(TestPolicy{}, std::begin(a), std::end(a), 0);
assert(std::pstl_count_called);
(void)std::count_if(TestPolicy{}, std::begin(a), std::end(a), pred);
assert(std::pstl_count_if_called);
(void)std::fill(TestPolicy{}, std::begin(a), std::end(a), 0);
assert(std::pstl_fill_called);
(void)std::fill_n(TestPolicy{}, std::begin(a), std::size(a), 0);

View File

@@ -0,0 +1,86 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// UNSUPPORTED: c++03, c++11, c++14
// UNSUPPORTED: libcpp-has-no-incomplete-pstl
// <algorithm>
// template<class ExecutionPolicy, class ForwardIterator, class T>
// typename iterator_traits<ForwardIterator>::difference_type
// count(ExecutionPolicy&& exec,
// ForwardIterator first, ForwardIterator last, const T& value);
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>
#include "test_macros.h"
#include "test_execution_policies.h"
#include "test_iterators.h"
EXECUTION_POLICY_SFINAE_TEST(count);
static_assert(sfinae_test_count<int, int*, int*, bool (*)(int)>);
static_assert(!sfinae_test_count<std::execution::parallel_policy, int*, int*, int>);
template <class Iter>
struct Test {
template <class Policy>
void operator()(Policy&& policy) {
{ // simple test
int a[] = {1, 2, 3, 4, 5};
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 1);
}
{ // test that an empty range works
std::array<int, 0> a;
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 0);
}
{ // test that a single-element range works
int a[] = {1};
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 1);
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 1);
}
{ // test that a two-element range works
int a[] = {1, 3};
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 1);
}
{ // test that a three-element range works
int a[] = {3, 1, 3};
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 2);
}
{ // test that a large range works
std::vector<int> a(100, 2);
decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 2);
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 100);
}
}
};
int main(int, char**) {
types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
return 0;
}

View File

@@ -0,0 +1,86 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// UNSUPPORTED: c++03, c++11, c++14
// UNSUPPORTED: libcpp-has-no-incomplete-pstl
// <algorithm>
// template<class ExecutionPolicy, class ForwardIterator, class Predicate>
// typename iterator_traits<ForwardIterator>::difference_type
// count_if(ExecutionPolicy&& exec,
// ForwardIterator first, ForwardIterator last, Predicate pred);
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>
#include "test_macros.h"
#include "test_execution_policies.h"
#include "test_iterators.h"
EXECUTION_POLICY_SFINAE_TEST(count_if);
static_assert(sfinae_test_count_if<int, int*, int*, bool (*)(int)>);
static_assert(!sfinae_test_count_if<std::execution::parallel_policy, int*, int*, int>);
template <class Iter>
struct Test {
template <class Policy>
void operator()(Policy&& policy) {
{ // simple test
int a[] = {1, 2, 3, 4, 5};
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 2);
}
{ // test that an empty range works
std::array<int, 0> a;
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 0);
}
{ // test that a single-element range works
int a[] = {1};
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 1);
}
{ // test that a two-element range works
int a[] = {1, 3};
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 1);
}
{ // test that a three-element range works
int a[] = {2, 3, 2};
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 2);
}
{ // test that a large range works
std::vector<int> a(100, 2);
decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
assert(ret == 100);
}
}
};
int main(int, char**) {
types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
return 0;
}