diff --git a/test/AllReduceTests.cpp b/test/AllReduceTests.cpp index ac393dfe8..fe9579ec5 100644 --- a/test/AllReduceTests.cpp +++ b/test/AllReduceTests.cpp @@ -13,7 +13,7 @@ namespace RcclUnitTesting // Configuration std::vector const funcTypes = {ncclCollAllReduce}; - std::vector const dataTypes = {ncclFloat32}; + std::vector const dataTypes = {ncclFloat32, ncclFp8E4M3, ncclFp8E5M2}; std::vector const redOps = {ncclSum}; std::vector const roots = {0}; std::vector const numElements = {393216, 384}; diff --git a/test/common/PtrUnion.cpp b/test/common/PtrUnion.cpp index 7ed1558f1..facf60b34 100644 --- a/test/common/PtrUnion.cpp +++ b/test/common/PtrUnion.cpp @@ -148,7 +148,9 @@ namespace RcclUnitTesting for (int i = 0; i < numElements; i++) { - int valueI = (globalRank + i) % 256; + // Due to floating-point math not being commutative, the ordering in which ranks are added will matter. + // For lower-precision data types, we initialize all ranks to the same value to avoid this + int valueI = (dataType == ncclFp8E4M3 || dataType == ncclFp8E5M2)? (i % 16) :(globalRank + i) % 256; double valueF = 1.0L/((double)valueI+1.0L); temp.Set(dataType, i, valueI, valueF); }