Commit

create LLAMA iterators on the fly and revert hacks
bernhardmgruber committed May 10, 2022
1 parent bf6a913 commit 803fc19
Showing 4 changed files with 65 additions and 41 deletions.
77 changes: 51 additions & 26 deletions examples/thrust/thrust.cu
@@ -321,6 +321,27 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size)
return p;
};

+ template<typename View>
+ struct IndexToViewIterator
+ {
+ View view;
+ LLAMA_FN_HOST_ACC_INLINE auto operator()(std::size_t i)
+ {
+ return *(view.begin() + i);
+ }
+ };
+
+ template<typename View>
+ auto make_view_it(View view, std::size_t i)
+ {
+ auto ci = thrust::counting_iterator<std::size_t>{i}; // count from i, so 0 and N give begin and end
+ return thrust::transform_iterator<
+ IndexToViewIterator<View>,
+ decltype(ci),
+ typename View::iterator::reference,
+ typename View::iterator::value_type>{ci, IndexToViewIterator<View>{std::move(view)}};
+ }
+
template<int Mapping>
void run(std::ostream& plotFile)
{
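
A note on the adaptor added above: make_view_it wraps the LLAMA view in a thrust::transform_iterator over a thrust::counting_iterator, so Thrust's algorithms receive an ordinary random-access fancy iterator that applies the functor lazily on dereference. A minimal self-contained sketch of the same pattern (not part of this commit; Square and sumOfSquares are made up for illustration):

    #include <cstddef>
    #include <thrust/execution_policy.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/iterator/transform_iterator.h>
    #include <thrust/reduce.h>

    struct Square
    {
        __host__ __device__ auto operator()(std::size_t i) const -> std::size_t
        {
            return i * i; // computed on the fly when the iterator is dereferenced
        }
    };

    // sums 0*0 + 1*1 + ... + (n-1)*(n-1) on the device, with no intermediate buffer
    auto sumOfSquares(std::size_t n) -> std::size_t
    {
        auto first = thrust::make_transform_iterator(thrust::counting_iterator<std::size_t>{0}, Square{});
        return thrust::reduce(thrust::device, first, first + n, std::size_t{0});
    }

make_view_it does the same, except its functor dereferences the view at index i, so the "values" are LLAMA's record references.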
@@ -375,8 +396,16 @@ void run(std::ostream& plotFile)

auto view = llama::allocView(mapping, thrustDeviceAlloc);

+ auto b = make_view_it(view, 0);
+ auto e = make_view_it(view, N);
+ // auto b = view.begin();
+ // auto e = view.end();
+
+ auto r = (*b);
+ r(tag::eventId{}) = 0;
+
// touch memory once before running benchmarks
- thrust::fill(thrust::device, view.begin(), view.end(), 0);
+ thrust::fill(thrust::device, b, e, 0);
syncWithCuda();

//#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
@@ -427,7 +456,7 @@
}
else
{
- thrust::tabulate(thrust::device, view.begin(), view.end(), InitOne{});
+ thrust::tabulate(thrust::device, b, e, InitOne{});
syncWithCuda();
}
tabulateTotal += stopwatch.printAndReset("tabulate", '\t');
@@ -453,7 +482,7 @@
{
Stopwatch stopwatch;
if constexpr(usePSTL)
- std::for_each(exec, view.begin(), view.end(), NormalizeVel{});
+ std::for_each(exec, b, e, NormalizeVel{});
else
{
thrust::for_each(
@@ -471,10 +500,10 @@
thrust::device_vector<MassType> dst(N);
Stopwatch stopwatch;
if constexpr(usePSTL)
- std::transform(exec, view.begin(), view.end(), dst.begin(), GetMass{});
+ std::transform(exec, b, e, dst.begin(), GetMass{});
else
{
- thrust::transform(thrust::device, view.begin(), view.end(), dst.begin(), GetMass{});
+ thrust::transform(thrust::device, b, e, dst.begin(), GetMass{});
syncWithCuda();
}
transformTotal += stopwatch.printAndReset("transform", '\t');
@@ -489,8 +518,8 @@
if constexpr(usePSTL)
std::transform_exclusive_scan(
exec,
- view.begin(),
- view.end(),
+ b,
+ e,
scan_result.begin(),
std::uint32_t{0},
std::plus<>{},
@@ -499,8 +528,8 @@
{
thrust::transform_exclusive_scan(
thrust::device,
- view.begin(),
- view.end(),
+ b,
+ e,
scan_result.begin(),
Predicate{},
std::uint32_t{0},
@@ -516,29 +545,24 @@
{
Stopwatch stopwatch;
if constexpr(usePSTL)
- sink = std::transform_reduce(exec, view.begin(), view.end(), MassType{0}, std::plus<>{}, GetMass{});
+ sink = std::transform_reduce(exec, b, e, MassType{0}, std::plus<>{}, GetMass{});
else
{
- sink = thrust::transform_reduce(
- thrust::device,
- view.begin(),
- view.end(),
- GetMass{},
- MassType{0},
- thrust::plus<>{});
+ sink = thrust::transform_reduce(thrust::device, b, e, GetMass{}, MassType{0}, thrust::plus<>{});
syncWithCuda();
}
transformReduceTotal += stopwatch.printAndReset("transform_reduce", '\t');
}

{
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
+ auto db = make_view_it(dstView, 0);
Stopwatch stopwatch;
if constexpr(usePSTL)
- std::copy(exec, view.begin(), view.end(), dstView.begin());
+ std::copy(exec, b, e, db);
else
{
- thrust::copy(thrust::device, view.begin(), view.end(), dstView.begin());
+ thrust::copy(thrust::device, b, e, db);
syncWithCuda();
}
copyTotal += stopwatch.printAndReset("copy", '\t');
@@ -548,12 +572,13 @@

{
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
+ auto db = make_view_it(dstView, 0);
Stopwatch stopwatch;
if constexpr(usePSTL)
- std::copy_if(exec, view.begin(), view.end(), dstView.begin(), Predicate{});
+ std::copy_if(exec, b, e, db, Predicate{});
else
{
- thrust::copy_if(thrust::device, view.begin(), view.end(), dstView.begin(), Predicate{});
+ thrust::copy_if(thrust::device, b, e, db, Predicate{});
syncWithCuda();
}
copyIfTotal += stopwatch.printAndReset("copy_if", '\t');
@@ -564,10 +589,10 @@
{
Stopwatch stopwatch;
if constexpr(usePSTL)
- std::remove_if(exec, view.begin(), view.end(), Predicate{});
+ std::remove_if(exec, b, e, Predicate{});
else
{
- thrust::remove_if(thrust::device, view.begin(), view.end(), Predicate{});
+ thrust::remove_if(thrust::device, b, e, Predicate{});
syncWithCuda();
}
removeIfTotal += stopwatch.printAndReset("remove_if", '\t');
@@ -576,14 +601,14 @@
//{
// Stopwatch stopwatch;
// if constexpr(usePSTL)
- // std::sort(std::execution::par, view.begin(), view.end(), Less{});
+ // std::sort(std::execution::par, b, e, Less{});
// else
// {
- // thrust::sort(thrust::device, view.begin(), view.end(), Less{});
+ // thrust::sort(thrust::device, b, e, Less{});
// syncWithCuda();
// }
// sortTotal += stopwatch.printAndReset("sort", '\t');
- // if(!thrust::is_sorted(thrust::device, view.begin(), view.end(), Less{}))
+ // if(!thrust::is_sorted(thrust::device, b, e, Less{}))
// std::cerr << "VALIDATION FAILED\n";
//}

10 changes: 5 additions & 5 deletions include/llama/ArrayIndexRange.hpp
@@ -119,11 +119,11 @@ namespace llama

current[0] = static_cast<difference_type>(current[0]) + n;
// current is either within bounds or at the end ([last + 1, 0, 0, ..., 0])
- //assert(
- //    (current[0] < extents[0]
- //     || (current[0] == extents[0]
- //         && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
- //    && "Iterator was moved past the end");
+ assert(
+     (current[0] < extents[0]
+      || (current[0] == extents[0]
+          && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
+     && "Iterator was moved past the end");

return *this;
}
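
The re-enabled assertion above encodes the iterator's past-the-end convention: after the last element, the index must be exactly [extents[0], 0, ..., 0] and never beyond. A small self-contained sketch of that invariant for a rank-2 index (not from this commit; advance is a simplified stand-in for the iterator's operator+=):

    #include <array>
    #include <cassert>
    #include <cstddef>

    using Index2 = std::array<std::ptrdiff_t, 2>;

    auto advance(Index2 current, const Index2& extents, std::ptrdiff_t n) -> Index2
    {
        const auto flat = current[0] * extents[1] + current[1] + n; // linearize, then step
        current[0] = flat / extents[1];                             // de-linearize
        current[1] = flat % extents[1];
        // the condition the assertion enforces: within bounds, or exactly at the end state
        assert(current[0] < extents[0] || (current[0] == extents[0] && current[1] == 0));
        return current;
    }

    int main()
    {
        const Index2 extents{4, 3};
        assert((advance({0, 0}, extents, 11) == Index2{3, 2})); // last valid index
        assert((advance({3, 2}, extents, 1) == Index2{4, 0}));  // canonical end state
    }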
16 changes: 8 additions & 8 deletions include/llama/View.hpp
@@ -149,9 +149,9 @@ namespace llama

constexpr Iterator() = default;

- LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View view)
+ LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View* view)
: arrayIndex(arrayIndex)
- , view(std::move(view))
+ , view(view)
{
}

@@ -188,7 +188,7 @@
LLAMA_FN_HOST_ACC_INLINE
constexpr auto operator*() const -> reference
{
- return const_cast<View&>(view)(*arrayIndex);
+ return (*view)(*arrayIndex);
}

LLAMA_FN_HOST_ACC_INLINE
@@ -283,7 +283,7 @@
}

ArrayIndexIterator arrayIndex;
- View view;
+ View* view;
};

/// Using a mapping, maps the given array index and record coordinate to a memory reference onto the given blobs.
@@ -462,25 +462,25 @@
LLAMA_FN_HOST_ACC_INLINE
auto begin() -> iterator
{
- return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
+ return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
}

LLAMA_FN_HOST_ACC_INLINE
auto begin() const -> const_iterator
{
- return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
+ return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
}

LLAMA_FN_HOST_ACC_INLINE
auto end() -> iterator
{
- return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
+ return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
}

LLAMA_FN_HOST_ACC_INLINE
auto end() const -> const_iterator
{
- return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
+ return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
}

Array<BlobType, Mapping::blobCount> storageBlobs;
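
With the iterator now holding a View* instead of its own copy of the view, iterators are cheap to copy again, but the usual container rule applies: an iterator must not outlive the view it points into. The same lifetime rule, sketched with a standard container (not from this commit):

    #include <vector>

    auto dangling() -> std::vector<int>::iterator
    {
        std::vector<int> v{1, 2, 3};
        return v.begin(); // dangles once v is destroyed; View::iterator now behaves the
                          // same way, since operator* goes through the stored View*
    }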
3 changes: 1 addition & 2 deletions include/llama/VirtualRecord.hpp
@@ -349,8 +349,7 @@ namespace llama
using ArrayIndex = typename View::Mapping::ArrayIndex;
using RecordDim = typename View::Mapping::RecordDim;

- // std::conditional_t<OwnView, View, View&> view;
- View view;
+ std::conditional_t<OwnView, View, View&> view;

public:
/// Subtree of the record dimension of View starting at BoundRecordCoord. If BoundRecordCoord is
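
The restored std::conditional_t member lets the same record type either own its view or merely reference one, selected at compile time by OwnView. A minimal sketch of the mechanism (not from this commit; Handle is a made-up stand-in for VirtualRecord):

    #include <type_traits>

    template<typename T, bool Own>
    struct Handle
    {
        std::conditional_t<Own, T, T&> value; // Own: stores a copy; otherwise a reference
    };

    int data = 42;
    Handle<int, true> owning{data};   // copies data; later writes to data are not observed
    Handle<int, false> viewing{data}; // refers to data; writes through value are visible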
