diff --git a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp index 075a17539c7ef..f65e8c85c55e5 100644 --- a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp @@ -50,7 +50,7 @@ struct sub_group_mask { } reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) { - RefBit = 1 << pos % word_size; + RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0; } private: @@ -61,16 +61,17 @@ struct sub_group_mask { }; bool operator[](id<1> id) const { - return Bits & (1 << (id.get(0) % word_size)); + return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0)); } + reference operator[](id<1> id) { return {*this, id.get(0)}; } bool test(id<1> id) const { return operator[](id); } - bool all() const { return !~Bits; } - bool any() const { return Bits; } - bool none() const { return !Bits; } + bool all() const { return count() == bits_num; } + bool any() const { return count() != 0; } + bool none() const { return count() == 0; } uint32_t count() const { unsigned int count = 0; - auto word = Bits; + auto word = (Bits & valuable_bits(bits_num)); while (word) { word &= (word - 1); count++; @@ -99,9 +100,9 @@ struct sub_group_mask { insert_data <<= pos.get(0); uint32_t mask = 0; if (pos.get(0) + insert_size < size()) - mask |= (0xffffffff << (pos.get(0) + insert_size)); + mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size)); if (pos.get(0) < size() && pos.get(0)) - mask |= (0xffffffff >> (size() - pos.get(0))); + mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0))); Bits &= mask; Bits += insert_data; } @@ -125,14 +126,15 @@ struct sub_group_mask { template ::value>> void extract_bits(Type &bits, id<1> pos = 0) const { - uint32_t Res = Bits; + auto Res = Bits; + Res &= valuable_bits(bits_num); if (pos.get(0) < size()) { if (pos.get(0) > 0) { Res >>= pos.get(0); } - if (sizeof(Type) * CHAR_BIT < size()) { - Res &= (0xffffffff >> (size() - (sizeof(Type) * CHAR_BIT))); + if (sizeof(Type) * CHAR_BIT < max_bits) { + Res &= valuable_bits(sizeof(Type) * CHAR_BIT); } bits = (Type)Res; } else { @@ -154,13 +156,13 @@ struct sub_group_mask { } } - void set() { Bits = uint32_t{0xffffffff}; } + void set() { Bits = valuable_bits(bits_num); } void set(id<1> id, bool value = true) { operator[](id) = value; } void reset() { Bits = uint32_t{0}; } void reset(id<1> id) { operator[](id) = 0; } void reset_low() { reset(find_low()); } void reset_high() { reset(find_high()); } - void flip() { Bits = ~Bits; } + void flip() { Bits = (~Bits & valuable_bits(bits_num)); } void flip(id<1> id) { operator[](id).flip(); } bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; } @@ -177,11 +179,13 @@ struct sub_group_mask { sub_group_mask &operator^=(const sub_group_mask &rhs) { Bits ^= rhs.Bits; + Bits &= valuable_bits(bits_num); return *this; } sub_group_mask &operator<<=(size_t pos) { Bits <<= pos; + Bits &= valuable_bits(bits_num); return *this; } @@ -239,6 +243,9 @@ struct sub_group_mask { sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) { assert(bits_num <= max_bits); } + inline uint32_t valuable_bits(size_t bn) const { + return static_cast((1ULL << bn) - 1ULL); + } uint32_t Bits; // Number of valuable bits size_t bits_num; diff --git a/sycl/test/extensions/sub_group_mask.cpp b/sycl/test/extensions/sub_group_mask.cpp index 4e14729611235..2af2673ab5c8a 100644 --- a/sycl/test/extensions/sub_group_mask.cpp +++ b/sycl/test/extensions/sub_group_mask.cpp @@ -1,4 +1,4 @@ -// RUN: %clangxx -g -O0 -fsycl -fsycl-targets=%sycl_triple %s -o %t.out +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out // RUN: %t.out //==-------- sub_group_mask.cpp - SYCL sub-group mask test -----------------==// @@ -13,110 +13,124 @@ #include int main() { - auto g = sycl::detail::Builder::createSubGroupMask< - sycl::ext::oneapi::sub_group_mask>(0, 32); - assert(g.none() && !g.any() && !g.all()); - assert(g[10] == false); // reference::operator[](id) const; - g[10] = true; // reference::operator=(bool); - assert(g[10] == true); - g[11] = g[10]; // reference::operator=(reference) reference::operator[](id); - assert(g[10].flip() == false); // reference::flip() - assert(~g[10] == true); // refernce::operator~() - assert(g[10] == false); - assert(g[11] == true); - assert(g.test(10) == false && g.test(11) == true); - g.set(30, 1); - g.set(11, 0); - g.set(23, 1); - assert(!g.none() && g.any() && !g.all()); + for (size_t sgsize = 32; sgsize > 4; sgsize /= 2) { + std::cout << "Running test for sub-group size = " << sgsize << std::endl; + auto g = sycl::detail::Builder::createSubGroupMask< + sycl::ext::oneapi::sub_group_mask>(0, sgsize); + assert(g.none() && !g.any() && !g.all()); + assert(g[5] == false); // reference::operator[](id) const; + g[5] = true; // reference::operator=(bool); + assert(g[5] == true); + g[6] = g[5]; // reference::operator=(reference) reference::operator[](id); + assert(g[5].flip() == false); // reference::flip() + assert(~g[5 % sgsize] == true); // refernce::operator~() + assert(g[5 % sgsize] == false); + assert(g[6 % sgsize] == true); + assert(g.test(5 % sgsize) == false && g.test(6 % sgsize) == true); + g.set(3 % sgsize, 1); + g.set(6 % sgsize, 0); + g.set(2 % sgsize, 1); + assert(!g.none() && g.any() && !g.all()); - assert(g.count() == 2); - assert(g.find_low() == 23); - assert(g.find_high() == 30); - assert(g.size() == 32); + assert(g.count() == 2); + assert(g.find_low() == 2 % sgsize); + assert(g.find_high() == 3 % sgsize); + assert(g.size() == sgsize); - g.reset(); - assert(g.none() && !g.any() && !g.all()); - assert(g.find_low() == g.size() && g.find_high() == g.size()); - g.set(); - assert(!g.none() && g.any() && g.all()); - assert(g.find_low() == 0 && g.find_high() == 31); - g.flip(); - assert(g.none() && !g.any() && !g.all()); + g.reset(); + assert(g.none() && !g.any() && !g.all()); + assert(g.find_low() == g.size() && g.find_high() == g.size()); + g.set(); + assert(!g.none() && g.any() && g.all()); + assert(g.find_low() == 0 && g.find_high() == 31 % sgsize); + g.flip(); + assert(g.none() && !g.any() && !g.all()); - g.flip(13); - g.flip(23); - g.flip(29); - auto b = g; - assert(b == g && !(b != g)); - g.flip(31); - assert(g.find_high() == 31); - assert(b.find_high() == 29); - assert(b != g && !(b == g)); - b.flip(31); - assert(b == g && !(b != g)); - b = g >> 1; - assert(b[12] && b[22] && b[28] && b[30]); - b <<= 1; - assert(b == g); - g ^= ~b; - assert(!g.none() && g.any() && g.all()); - assert((g | ~g).all()); - assert((g & ~g).none()); - assert((g ^ ~g).all()); - b.reset_low(); - b.reset_high(); - assert(!b[13] && b[23] && b[29] && !b[31]); - b.insert_bits(0x01020408); - assert(b[24] && b[17] && b[10] && b[3]); - b <<= 13; - assert(!b[24] && !b[17] && !b[10] && !b[3] && b[30] && b[23] && b[16]); - b.insert_bits((char)0b01010101, 18); - assert(b[18] && b[20] && b[22] && b[24] && b[30] && !b[23] && b[16]); - b[3] = true; - b.insert_bits(sycl::marray{1, 2, 4, 8, 16, 32, 64, 128}, 5); - assert(!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[3] && - b[5] && b[14] && b[23]); - char r, rbc; - const auto b_const{b}; - b.extract_bits(r); - b_const.extract_bits(rbc); - assert(r == 0b00101000); - assert(rbc == 0b00101000); - long r2 = -1, r2bc = -1; - b.extract_bits(r2, 16); - b_const.extract_bits(r2bc, 16); - assert(r2 == 128); - assert(r2bc == 128); + g.flip(2); + g.flip(3); + g.flip(7); + auto b = g; + assert(b == g && !(b != g)); + g.flip(7); + assert(g.find_high() == 3 % sgsize); + assert(b.find_high() == 7 % sgsize); + assert(b != g && !(b == g)); + g.flip(7); + assert(b == g && !(b != g)); + b = g >> 1; + assert(b[1] && b[2] && b[6]); + b <<= 1; + assert(b == g); + g ^= ~b; + assert(!g.none() && g.any() && g.all()); + assert((g | ~g).all()); + assert((g & ~g).none()); + assert((g ^ ~g).all()); + b.reset_low(); + b.reset_high(); + assert(!b[2] && b[3] && !b[7]); + b.insert_bits(0x01020408); + assert(((b[24] && b[17]) || sgsize < 32) && (b[10] || sgsize < 16) && b[3]); + b <<= 10; + assert(((!b[24] && !b[17] && b[27] && b[20]) || sgsize < 32) && + ((!b[10] && b[13]) || sgsize < 16) && !b[3]); + b.insert_bits((char)0b01010101, 6); + assert(b[6] && ((b[8] && b[10] && b[12] && !b[13]) || sgsize < 16)); + b[3] = true; + b.insert_bits(sycl::marray{1, 2, 4, 8, 16, 32, 64, 128}, 5); + assert( + ((!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[23]) || + sgsize < 32) && + b[3] && b[5] && (b[14] || sgsize < 16)); + b.flip(14); + b.flip(23); + char r, rbc; + const auto b_const{b}; + b.extract_bits(r); + b_const.extract_bits(rbc); + assert(r == 0b00101000); + assert(rbc == 0b00101000); + long r2 = -1, r2bc = -1; + b.extract_bits(r2, 3); + b_const.extract_bits(r2bc, 3); + assert(r2 == 5); + assert(r2bc == 5); - b[31] = true; - const auto b_const2{b}; - sycl::marray r3{-1}, r3bc{-1}; - b.extract_bits(r3, 14); - b_const2.extract_bits(r3bc, 14); - assert(r3[0] == 1 && r3[1] == 2 && r3[2] == 2 && !r3[3] && !r3[4] && !r3[5]); - assert(r3bc[0] == 1 && r3bc[1] == 2 && r3bc[2] == 2 && !r3bc[3] && !r3bc[4] && - !r3bc[5]); - int ibits = 0b1010101010101010101010101010101; - b.insert_bits(ibits); - for (size_t i = 0; i < 32; i++) { - assert(b[i] != (bool)(i % 2)); + b.insert_bits((uint32_t)0x08040201); + const auto b_const2{b}; + sycl::marray r3{-1}, r3bc{-1}; + b.extract_bits(r3); + b_const2.extract_bits(r3bc); + assert(r3[0] == 1 && r3[1] == (sgsize > 8 ? 2 : 0) && + r3[2] == (sgsize > 16 ? 4 : 0) && r3[3] == (sgsize > 16 ? 8 : 0) && + !r3[4] && !r3[5]); + assert(r3bc[0] == 1 && r3bc[1] == (sgsize > 8 ? 2 : 0) && + r3bc[2] == (sgsize > 16 ? 4 : 0) && + r3bc[3] == (sgsize > 16 ? 8 : 0) && !r3bc[4] && !r3bc[5]); + int ibits = 0b1010101010101010101010101010101; + b.insert_bits(ibits); + for (size_t i = 0; i < sgsize; i++) { + assert(b[i] != (bool)(i % 2)); + } + short sbits = 0b0111011101110111; + b.insert_bits(sbits, 7); + b.extract_bits(ibits); + assert(ibits == + (0b1010101001110111011101111010101 & ((1ULL << sgsize) - 1ULL))); + sbits = 0b1100001111000011; + b.insert_bits(sbits, 23); + b.extract_bits(ibits); + if (sgsize >= 32) { + int64_t lbits = -1; + b.extract_bits(lbits, 33); + assert(lbits == 0); + lbits = -1; + b.extract_bits(lbits, 5); + assert(lbits == + (0b111000011011101110111011110 & ((1ULL << sgsize) - 1ULL))); + lbits = -1; + b.insert_bits(lbits); + assert(b.all()); + } } - short sbits = 0b0111011101110111; - b.insert_bits(sbits, 7); - b.extract_bits(ibits); - assert(ibits == 0b1010101001110111011101111010101); - sbits = 0b1100001111000011; - b.insert_bits(sbits, 23); - b.extract_bits(ibits); - assert(ibits == 0b11100001101110111011101111010101); - int64_t lbits = -1; - b.extract_bits(lbits, 33); - assert(lbits == 0); - lbits = -1; - b.extract_bits(lbits, 5); - assert(lbits == 0b111000011011101110111011110); - lbits = -1; - b.insert_bits(lbits); - assert(b.all()); }