diff --git a/libclc/ptx-nvidiacl/libspirv/SOURCES b/libclc/ptx-nvidiacl/libspirv/SOURCES index 7d30776ac0906..94e9045522851 100644 --- a/libclc/ptx-nvidiacl/libspirv/SOURCES +++ b/libclc/ptx-nvidiacl/libspirv/SOURCES @@ -85,6 +85,7 @@ images/image_helpers.ll images/image.cl group/collectives_helpers.ll group/collectives.cl +group/group_ballot.cl atomic/atomic_add.cl atomic/atomic_and.cl atomic/atomic_cmpxchg.cl diff --git a/libclc/ptx-nvidiacl/libspirv/group/collectives.cl b/libclc/ptx-nvidiacl/libspirv/group/collectives.cl index d31adee1005b4..e859f262a1ed7 100644 --- a/libclc/ptx-nvidiacl/libspirv/group/collectives.cl +++ b/libclc/ptx-nvidiacl/libspirv/group/collectives.cl @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include "membermask.h" + #include #include diff --git a/libclc/ptx-nvidiacl/libspirv/group/group_ballot.cl b/libclc/ptx-nvidiacl/libspirv/group/group_ballot.cl new file mode 100644 index 0000000000000..33285028b7b39 --- /dev/null +++ b/libclc/ptx-nvidiacl/libspirv/group/group_ballot.cl @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "membermask.h" + +#include +#include + +_CLC_DEF _CLC_CONVERGENT __clc_vec4_uint32_t +_Z29__spirv_GroupNonUniformBallotjb(unsigned flag, bool predicate) { + // only support subgroup for now + if (flag != Subgroup) { + __builtin_trap(); + __builtin_unreachable(); + } + + // prepare result, we only support the ballot operation on 32 threads maximum + // so we only need the first element to represent the final mask + __clc_vec4_uint32_t res; + res[1] = 0; + res[2] = 0; + res[3] = 0; + + // compute thread mask + unsigned threads = __clc__membermask(); + + // run the ballot operation + res[0] = __nvvm_vote_ballot_sync(threads, predicate); + + return res; +} diff --git a/libclc/ptx-nvidiacl/libspirv/group/membermask.h b/libclc/ptx-nvidiacl/libspirv/group/membermask.h new file mode 100644 index 0000000000000..a083a3b5d75d6 --- /dev/null +++ b/libclc/ptx-nvidiacl/libspirv/group/membermask.h @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PTX_NVIDIACL_MEMBERMASK_H +#define PTX_NVIDIACL_MEMBERMASK_H + +#include +#include + +_CLC_DEF _CLC_CONVERGENT uint __clc__membermask(); + +#endif