Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Dec 2, 2021
1 parent 431b044 commit 3653ff6
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions examples/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,29 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size)
return p;
};

template <typename View>
struct ViewIteratorAt
{
View view;

LLAMA_FN_HOST_ACC_INLINE auto operator()(std::size_t i)
{
return *(view.begin() + i);
}
};

template<typename View>
auto viewIteratorAt(View& view, std::size_t index)
{
ViewIteratorAt<View> t{view};
using ViewTransformIterator = thrust::transform_iterator<
decltype(t),
thrust::counting_iterator<std::size_t>,
typename View::iterator::reference,
typename View::iterator::value_type>;
return ViewTransformIterator{thrust::counting_iterator<std::size_t>{index}, t};
}

template<int Mapping>
void run(std::ostream& plotFile)
{
Expand Down Expand Up @@ -374,15 +397,8 @@ void run(std::ostream& plotFile)
std::cout << mappingName << '\n';

auto view = llama::allocView(mapping, thrustDeviceAlloc);

auto makeViewIteratorFromIndexCreator = [](decltype(view) view)
{ return [view] __host__ __device__(std::size_t i) mutable { return *(view.begin() + i); }; };
auto b = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{0},
makeViewIteratorFromIndexCreator(view));
auto e = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{N},
makeViewIteratorFromIndexCreator(view));
auto b = viewIteratorAt(view, 0);
auto e = viewIteratorAt(view, N);
// auto b = view.begin();
// auto e = view.end();

Expand Down Expand Up @@ -541,9 +557,7 @@ void run(std::ostream& plotFile)

{
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
auto db = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{0},
makeViewIteratorFromIndexCreator(dstView));
auto db = viewIteratorAt(dstView, 0);
Stopwatch stopwatch;
if constexpr(usePSTL)
std::copy(exec, b, e, db);
Expand All @@ -559,9 +573,7 @@ void run(std::ostream& plotFile)

{
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
auto db = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{0},
makeViewIteratorFromIndexCreator(dstView));
auto db = viewIteratorAt(dstView, 0);
Stopwatch stopwatch;
if constexpr(usePSTL)
std::copy_if(exec, b, e, db, Predicate{});
Expand Down

0 comments on commit 3653ff6

Please sign in to comment.