Skip to content

Commit

Permalink
Merge pull request #177 from stanleytsang-amd/xnack_on_hmm
Browse files Browse the repository at this point in the history
Cherry-picking HMM unit test support for ROCm 4.3
  • Loading branch information
stanleytsang-amd authored Jun 7, 2021
2 parents 84d8dcd + 489c11a commit bb4d0b7
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions test/test_device_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,19 @@

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <cstdlib>

#include "test_header.hpp"

#define HIP_CHECK_HMM(condition) \
{ \
hipError_t error = condition; \
if(error != hipSuccess){ \
std::cout << "HIP error: " << error << " line: " << __LINE__ << std::endl; \
exit(error); \
} \
}

TESTS_DEFINE(DevicePtrTests, FullTestsParams);
TESTS_DEFINE(DevicePtrPrimitiveTests, NumericalTestsParams);

Expand All @@ -33,6 +43,31 @@ struct mark_processed_functor
}
};

bool supports_hmm()
{
hipDeviceProp_t device_prop;
int device_id;
HIP_CHECK_HMM(hipGetDevice(&device_id));
HIP_CHECK_HMM(hipGetDeviceProperties(&device_prop, device_id));
if (device_prop.managedMemory == 1) return true;

return false;
}

bool use_hmm()
{
if (getenv("ROCTHRUST_USE_HMM") == nullptr)
{
return false;
}

if (strcmp(getenv("ROCTHRUST_USE_HMM"), "1") == 0)
{
return true;
}
return false;
}

TEST(DevicePtrTests, TestDevicePointerManipulation)
{
SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
Expand Down Expand Up @@ -80,6 +115,58 @@ TEST(DevicePtrTests, TestDevicePointerManipulation)
ASSERT_EQ(end - begin, 5);
}

TEST(DevicePtrTests, TestDevicePointerManipulationHmm)
{
SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());

if (!(use_hmm())) return;

int* data;
HIP_CHECK(hipMallocManaged((void**)&data, 5 * sizeof(int)));

thrust::device_ptr<int> begin(&data[0]);
thrust::device_ptr<int> end(&data[0] + 5);

ASSERT_EQ(end - begin, 5);

begin++;
begin--;

ASSERT_EQ(end - begin, 5);

begin += 1;
begin -= 1;

ASSERT_EQ(end - begin, 5);

begin = begin + (int)1;
begin = begin - (int)1;

ASSERT_EQ(end - begin, 5);

begin = begin + (unsigned int)1;
begin = begin - (unsigned int)1;

ASSERT_EQ(end - begin, 5);

begin = begin + (size_t)1;
begin = begin - (size_t)1;

ASSERT_EQ(end - begin, 5);

begin = begin + (ptrdiff_t)1;
begin = begin - (ptrdiff_t)1;

ASSERT_EQ(end - begin, 5);

begin = begin + (thrust::device_ptr<int>::difference_type)1;
begin = begin - (thrust::device_ptr<int>::difference_type)1;

ASSERT_EQ(end - begin, 5);

HIP_CHECK(hipFree(data));
}

TYPED_TEST(DevicePtrPrimitiveTests, MakeDevicePointer)
{
using T = typename TestFixture::input_type;
Expand Down

0 comments on commit bb4d0b7

Please sign in to comment.