diff --git a/examples/BuddyMatmul/batchmatmul-vectorization.mlir b/examples/BuddyMatmul/batchmatmul-vectorization.mlir
new file mode 100644
index 000000000..8d3b30c44
--- /dev/null
+++ b/examples/BuddyMatmul/batchmatmul-vectorization.mlir
@@ -0,0 +1,135 @@
+// RUN: buddy-opt %s \
+// RUN:   -convert-linalg-to-affine-loops \
+// RUN:   -lower-affine \
+// RUN:   -convert-vector-to-scf \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -convert-math-to-libm \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -expand-strided-metadata \
+// RUN:   -finalize-memref-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+module {
+  func.func private @printMemrefF32(memref<*xf32>)
+  func.func private @rtclock() -> f64
+
+  // CMK * CKN -> CMN
+  func.func @batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %vl_step = arith.constant 32 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = vector.splat %cst : vector<32xf32>
+    %dim = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+    %dim_1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+    %dim_2 = memref.dim %arg1, %c1 : memref<?x?x?xf32>
+    %dim_3 = memref.dim %arg1, %c2 : memref<?x?x?xf32>
+
+    // Calculate the upper bound for vectorized processing:
+    // - Subtract `vl_step` so the main loop never reads past the end of the
+    //   workload at the vectorization tail.
+    // - Add 1 so the final iteration still runs when the workload length is
+    //   exactly divisible by the vector size.
+    %dim_3_upbound_tmp = arith.subi %dim_3, %vl_step : index
+    %dim_3_upbound = arith.addi %dim_3_upbound_tmp, %c1 : index
+
+    %t_start = call @rtclock() : () -> f64
+    affine.for %arg3 = %c0 to %dim { // C
+      affine.prefetch %arg0[%arg3, %dim_1, %dim_2], read, locality<3>, data : memref<?x?x?xf32>
+      affine.for %arg4 = %c0 to %dim_1 { // M
+        // Perform the vectorization body.
+        %iter_idx = scf.for %arg5 = %c0 to %dim_3_upbound
+            step %vl_step iter_args(%iter_init = %c0) -> (index) { // N
+          %1 = vector.load %arg2[%arg3, %arg4, %arg5] : memref<?x?x?xf32>, vector<32xf32>
+          %iter_vec = scf.for %arg6 = %c0 to %dim_2 step %c1
+              iter_args(%iter_vec0 = %1) -> (vector<32xf32>) { // K
+            %5 = memref.load %arg0[%arg3, %arg4, %arg6] : memref<?x?x?xf32>
+            %6 = vector.broadcast %5 : f32 to vector<32xf32>
+            %4 = vector.load %arg1[%arg3, %arg6, %arg5] : memref<?x?x?xf32>, vector<32xf32>
+            %8 = vector.fma %6, %4, %iter_vec0 : vector<32xf32>
+            scf.yield %8 : vector<32xf32>
+          }
+          vector.store %iter_vec, %arg2[%arg3, %arg4, %arg5] : memref<?x?x?xf32>, vector<32xf32>
+          %arg5_next = arith.addi %arg5, %vl_step : index
+          scf.yield %arg5_next : index
+        }
+        // Compute the tail size and process the remaining elements
+        // using masked vector operations.
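+        // Illustrative note (not part of the original example): if %dim_3 = 70
+        // and %vl_step = 32, the upper bound above is 70 - 32 + 1 = 39, so the
+        // vector loop covers columns [0, 32) and [32, 64), %iter_idx ends at
+        // 64, and the masked tail below handles the remaining 70 - 64 = 6
+        // columns.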
+        %tail_size = arith.subi %dim_3, %iter_idx : index
+        %3 = arith.cmpi sgt, %tail_size, %c0 : index
+        scf.if %3 {
+          %mask = vector.create_mask %tail_size : vector<32xi1>
+          %1 = vector.maskedload %arg2[%arg3, %arg4, %iter_idx], %mask, %0 : memref<?x?x?xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+          %iter_vec = scf.for %arg6 = %c0 to %dim_2 step %c1
+              iter_args(%iter_vec0 = %1) -> (vector<32xf32>) { // K
+            %5 = vector.maskedload %arg1[%arg3, %arg6, %iter_idx], %mask, %0 : memref<?x?x?xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+            %6 = memref.load %arg0[%arg3, %arg4, %arg6] : memref<?x?x?xf32>
+            %7 = vector.broadcast %6 : f32 to vector<32xf32>
+            %9 = vector.fma %7, %5, %iter_vec0 : vector<32xf32>
+            scf.yield %9 : vector<32xf32>
+          }
+          vector.maskedstore %arg2[%arg3, %arg4, %iter_idx], %mask, %iter_vec : memref<?x?x?xf32>, vector<32xi1>, vector<32xf32>
+        }
+      }
+    }
+    %t_end = call @rtclock() : () -> f64
+    %time = arith.subf %t_end, %t_start : f64
+    %printed_output = memref.cast %arg2 : memref<?x?x?xf32> to memref<*xf32>
+    call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()
+    vector.print %time : f64
+    return
+  }
+
+  func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg4: f32) -> memref<?x?x?xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %0 = memref.alloc(%arg0, %arg1, %arg2) : memref<?x?x?xf32>
+    scf.for %idx0 = %c0 to %arg0 step %c1 {
+      scf.for %idx1 = %c0 to %arg1 step %c1 {
+        scf.for %idx2 = %c0 to %arg2 step %c1 {
+          memref.store %arg4, %0[%idx0, %idx1, %idx2] : memref<?x?x?xf32>
+        }
+      }
+    }
+    return %0 : memref<?x?x?xf32>
+  }
+
+  func.func @main(){
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c576 = arith.constant 576 : index
+    %c1024 = arith.constant 1024 : index
+    %c1000 = arith.constant 1000 : index
+    %f0 = arith.constant 0.0 : f32
+    %f2 = arith.constant 2.0 : f32
+    %f3 = arith.constant 3.0 : f32
+
+    %m0 = call @alloc_f32(%c1, %c1, %c576, %f2) : (index, index, index, f32) -> memref<?x?x?xf32>
+    %m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
+    %m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>
+
+    // Every output element is 2.0 * 3.0 accumulated over K = 576, i.e. 3456.
+    // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data =
+    // CHECK-NEXT: [
+    // CHECK: [
+    // CHECK: [3456{{(, 3456)*}}]
+    call @batch_matmul(%m0, %m1, %m2) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+
+    %m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref<?x?x?xf32>
+    %m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
+    %m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>
+
+    // Here K = 1024, so every output element is 2.0 * 3.0 * 1024 = 6144.
+    // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data =
+    // CHECK-NEXT: [
+    // CHECK: [
+    // CHECK: [6144{{(, 6144)*}}]
+    call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+
+    return
+  }
+}
diff --git a/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
index ca132cb86..aadebbd18 100644
--- a/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
+++ b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
@@ -35,7 +35,7 @@ func.func @batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg
   // Print timings.
vector.print %time : f64 - + return } diff --git a/examples/BuddyMatmul/makefile b/examples/BuddyMatmul/makefile index 1242f2246..ffc1ab24b 100644 --- a/examples/BuddyMatmul/makefile +++ b/examples/BuddyMatmul/makefile @@ -20,6 +20,28 @@ MTRIPLE := x86_64-apple-darwin endif linalg-batchmatmul-f32-run: + @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -expand-strided-metadata \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-batchmatmul-vectorization-lower: + @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ + -batchmatmul-optimize \ + -o ./log.mlir + +linalg-batchmatmul-f32-vectorization-run: @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ -batchmatmul-optimize \ -convert-linalg-to-affine-loops \ @@ -100,3 +122,20 @@ linalg-matmul-transpose-b-f32-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +batchmatmul-vectorization-run: + @${BUDDY_OPT} ./batchmatmul-vectorization.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -expand-strided-metadata \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index 6cedaa165..318fd5752 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -18,6 +18,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -25,6 +26,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/ValueRange.h" #include "llvm/ADT/ArrayRef.h" #include @@ -32,6 +34,10 @@ #include #include #include +#include +#include +#include +#include #include #include #include @@ -40,7 +46,6 @@ using namespace mlir; using namespace vector; -using namespace affine; //===----------------------------------------------------------------------===// // Rewrite Pattern @@ -51,292 +56,165 @@ namespace { class BatchMatMulOptimizePattern : public ConversionPattern { public: explicit BatchMatMulOptimizePattern(MLIRContext *context, - int64_t affineVectorSizeParam) + int64_t vecSizeParam) : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, context) { - affineVectorSize = affineVectorSizeParam; + vecSize = vecSizeParam; } LogicalResult matchAndRewrite(Operation *op, ArrayRef /*operands*/, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto ctx = op->getContext(); // 
Retrieve input tensors A, B, and C. Value A = op->getOperand(0); Value B = op->getOperand(1); Value C = op->getOperand(2); + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + VectorType vectorMaskTy = mlir::VectorType::get({vecSize}, i1); // Acquire the element type of input tensors. Type elementType = A.getType().cast().getElementType(); + VectorType vectorTy = mlir::VectorType::get({vecSize}, elementType); - // Define constants. - const Value zeroIndex = - rewriter.create(loc, rewriter.getIndexAttr(0)); const AffineExpr d0 = rewriter.getAffineDimExpr(0); const AffineExpr d1 = rewriter.getAffineDimExpr(1); const AffineExpr d2 = rewriter.getAffineDimExpr(2); - const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); - const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); - const Value zeroElementType = rewriter.create( + // Define constants. + const Value c0 = rewriter.create(loc, 0); + const Value c1 = rewriter.create(loc, 1); + const Value c2 = rewriter.create(loc, 2); + const Value vl_step = rewriter.create(loc, vecSize); + const Value zero = rewriter.create( loc, rewriter.getZeroAttr(elementType)); - const Value zeroElementTypeVec = rewriter.create( - loc, VectorType::get({affineVectorSize}, elementType), zeroElementType); - - // Get dimensions of input tensors. - Value batch = rewriter.create(loc, A, 0); - Value aRow = rewriter.create(loc, A, 1); - Value bCol = rewriter.create(loc, B, 2); - Value bRow = rewriter.create(loc, B, 1); - - // Calculate the length of the tail, which might not fit in a vector. - Value tailLength = rewriter.create( - loc, AffineMap::get(1, 0, d0 % affineVectorSize), ValueRange{bCol}); - // Generate a mask vector based on the tail length. - Value maskVector = rewriter.create( - loc, VectorType::get({affineVectorSize}, rewriter.getI1Type()), - ValueRange{tailLength}); + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy, zero); - SmallVector reducedValues = llvm::to_vector<4>( - llvm::map_range(ArrayRef{}, - [](const LoopReduction &red) { return red.value; })); - - // Apply the column of matrix B. - Value appliedColOfB = rewriter.create( - loc, AffineMap::get(1, 0, d0.ceilDiv(affineVectorSize)), - ValueRange{bCol}); - - // Create the primary parallel batch level loop. - AffineParallelOp parallelBatchLoop = - rewriter.create( - loc, ValueRange(reducedValues).getTypes(), ValueRange{batch}, - ArrayRef{ - rewriter.getNamedAttr("lowerBoundsGroups", - rewriter.getI32TensorAttr({1})), - rewriter.getNamedAttr("upperBoundsGroups", - rewriter.getI32TensorAttr({1})), - rewriter.getNamedAttr( - "lowerBoundsMap", - AffineMapAttr::get(AffineMap::get(0, 0, {zeroAffine}, - rewriter.getContext()))), - rewriter.getNamedAttr("upperBoundsMap", - AffineMapAttr::get(AffineMap::get( - 1, 0, {d0}, rewriter.getContext()))), - rewriter.getNamedAttr("reductions", rewriter.getArrayAttr({})), - rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({1}))}); - - // Create the loop body for the parallel loop. - Block *loopBody = new Block(); - rewriter.setInsertionPointToStart(loopBody); - loopBody->addArgument(rewriter.getIndexType(), loc); - Value loopVarBatchIdx = loopBody->getArguments()[0]; - - // Prefetching data from tensor 'A' for better cache utilization. - rewriter.create( - loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()), - ArrayRef{loopVarBatchIdx, aRow, bRow}, false, 3, true); + // Get dimensions of input tensors. 
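+    // For linalg.batch_matmul, A has shape (batch, M, K) and B has shape
+    // (batch, K, N), so `aRow` is M, `bRow` is the reduction dimension K, and
+    // `bCol` is the vectorized dimension N.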
+ Value batch = rewriter.create(loc, A, c0); + Value aRow = rewriter.create(loc, A, c1); + Value bCol = rewriter.create(loc, B, c2); + Value bRow = rewriter.create(loc, B, c1); + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + Value upperBound_tmp = rewriter.create(loc, bCol, vl_step); + Value upperBound = rewriter.create(loc, upperBound_tmp, c1); affine::buildAffineLoopNest( - rewriter, loc, {zeroIndex}, {appliedColOfB}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarColOfB = ivRange.front(); - - // Compile time branch detection. - if (C.getType().cast().isDynamicDim(2) or - C.getType().cast().getDimSize(2) % affineVectorSize != - 0) { - - // Depending on the position, use either full vectors or tail - // vectors. - affine::AffineIfOp branchingOp = builder.create( - loc, - IntegerSet::get( - 1, 1, {d0 * -affineVectorSize + s0 - affineVectorSize}, - {false}), - ValueRange{loopVarColOfB, bCol}, true); - - // Branch handling full vector operations. - OpBuilder trueBranchBuilder = branchingOp.getThenBodyBuilder(); - affine::buildAffineLoopNest( - trueBranchBuilder, loc, {zeroIndex}, {bRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - AffineMap::get(3, 0, {d0, d1, d2 * affineVectorSize}, - rewriter.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfB, - loopVarColOfB}); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, - ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarRowOfB}); - Value aVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), - aElement); - Value cVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - Value computedVec; - - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - if (isa(elementType)) { - Value mulVec = - builder.create(loc, aVec, bVec); - computedVec = - builder.create(loc, mulVec, cVec); - } else { - computedVec = builder.create( - loc, aVec, bVec, cVec); - } - builder.create( - loc, computedVec, C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - }); - }); - - // Branch handling operations on the tail. 
- OpBuilder falseBranchBuilder = branchingOp.getElseBodyBuilder(); - affine::buildAffineLoopNest( - falseBranchBuilder, loc, {zeroIndex}, {bRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value tailIdxColOfB = builder.create( - loc, AffineMap::get(1, 0, d0 * affineVectorSize), - ValueRange{loopVarColOfB}); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - ValueRange{loopVarBatchIdx, loopVarRowOfB, tailIdxColOfB}, - maskVector, zeroElementTypeVec); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, - ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarRowOfB}); - Value aVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), - aElement); - Value cVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), C, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - tailIdxColOfB}, - maskVector, zeroElementTypeVec); - Value computedVec; - - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - if (isa(elementType)) { - Value mulVec = - builder.create(loc, aVec, bVec); - computedVec = - builder.create(loc, mulVec, cVec); - } else { - computedVec = builder.create( - loc, aVec, bVec, cVec); - } - builder.create( - loc, C, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - tailIdxColOfB}, - maskVector, computedVec); - }); - }); - } else { - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {bRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - AffineMap::get(3, 0, {d0, d1, d2 * affineVectorSize}, - rewriter.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfB, - loopVarColOfB}); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, - ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarRowOfB}); - Value aVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), - aElement); - Value cVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - Value computedVec; - - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - if (isa(elementType)) { - Value mulVec = - builder.create(loc, aVec, bVec); - computedVec = - builder.create(loc, mulVec, cVec); - } else { - computedVec = builder.create( - loc, aVec, bVec, cVec); - } - builder.create( - loc, computedVec, C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - }); - }); - } + rewriter, loc, {c0}, {batch}, /*Step=*/1, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Prefetching data from tensor 'A' for better cache utilization. 
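+          // The trailing (false, 3, true) arguments request a read prefetch
+          // (isWrite = false) with locality hint 3 (keep in cache) into the
+          // data cache, matching `read, locality<3>, data` in the printed IR.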
+ builder.create( + loc, A, AffineMap::get(3, 0, {d0, d1, d2}, ctx), + ArrayRef{ivs[0], aRow, bRow}, false, 3, true); + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{aRow}, builder.getDimIdentityMap(), + /*Step=*/1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs0) { + auto iter_idx = builder.create( + loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv2, + ValueRange itrArgs0) { + Value cVec = builder.create( + loc, vectorTy, C, ValueRange{ivs[0], iv1, iv2}); + auto iter_vec = nestedBuilder.create( + nestedLoc, c0, bRow, /*Step=*/c1, ValueRange{cVec}, + [&](OpBuilder &builder, Location loc, Value iv3, + ValueRange itrArgs1) { + Value aValue = builder.create( + loc, elementType, A, + ValueRange{ivs[0], iv1, iv3}); + Value aVec = builder.create( + loc, vectorTy, aValue); + Value bVec = builder.create( + loc, vectorTy, B, ValueRange{ivs[0], iv3, iv2}); + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. + Value computedVec; + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, itrArgs1[0]); + } else { + computedVec = builder.create( + loc, aVec, bVec, itrArgs1[0]); + } + builder.create(loc, computedVec); + }); + nestedBuilder.create( + nestedLoc, iter_vec.getResult(0), C, + ValueRange{ivs[0], iv1, iv2}); + Value idx = nestedBuilder.create( + nestedLoc, iv2, vl_step); + nestedBuilder.create(nestedLoc, idx); + }); + // Compute the tail size and Process the remaining elements + // using masked vector operations. + Value idx = iter_idx.getResult(0); + Value tailSize = builder.create(loc, bCol, idx); + Value tailCond = rewriter.create( + loc, arith::CmpIPredicate::sgt, tailSize, c0); + // If the current column does not reach the tail. + builder.create( + loc, tailCond, [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value tailMask = builder.create( + loc, vectorMaskTy, tailSize); + Value maskedCVec = builder.create( + loc, vectorTy, C, ValueRange{ivs[0], iv1, idx}, + tailMask, passThroughVec); + auto iter_vec = builder.create( + loc, c0, bRow, /*Step=*/c1, ValueRange{maskedCVec}, + [&](OpBuilder &builder, Location loc, Value iv3, + ValueRange itrArgs1) { + Value aValue = builder.create( + loc, A, ValueRange{ivs[0], iv1, iv3}); + Value aVec = builder.create( + loc, vectorTy, aValue); + Value maskedBVec = builder.create( + loc, vectorTy, B, ValueRange{ivs[0], iv3, idx}, + tailMask, passThroughVec); + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. + Value computedVec; + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, maskedBVec); + computedVec = builder.create( + loc, mulVec, itrArgs1[0]); + } else { + computedVec = builder.create( + loc, aVec, maskedBVec, itrArgs1[0]); + } + builder.create(loc, computedVec); + }); + builder.create( + loc, C, ValueRange{ivs[0], iv1, idx}, tailMask, + iter_vec.getResult(0)); + builder.create(loc); + }); + builder.create(loc); + }); }); - - rewriter.create(loc); - - // Finalize the loop and erase the original operation. 
- parallelBatchLoop.getRegion().push_back(loopBody); - rewriter.setInsertionPointAfter(parallelBatchLoop); - rewriter.eraseOp(op); return success(); } private: - int64_t affineVectorSize; + int64_t vecSize; }; } // end anonymous namespace @@ -355,8 +233,8 @@ class BatchMatMulOptimizePass StringRef getDescription() const final { return "BatchMatMul Optimization."; } BatchMatMulOptimizePass() = default; BatchMatMulOptimizePass(const BatchMatMulOptimizePass &) {} - explicit BatchMatMulOptimizePass(int64_t affineVectorSizeParam) { - affineVectorSize = affineVectorSizeParam; + explicit BatchMatMulOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; } void runOnOperation() override; @@ -366,9 +244,9 @@ class BatchMatMulOptimizePass affine::AffineDialect, VectorDialect>(); } - Option affineVectorSize{*this, "vector-size", - llvm::cl::desc("Affine Vector size."), - llvm::cl::init(64)}; + Option vecSize{*this, "vector-size", + llvm::cl::desc("Affine Vector size."), + llvm::cl::init(32)}; }; } // end anonymous namespace. @@ -384,7 +262,7 @@ void BatchMatMulOptimizePass::runOnOperation() { target.addLegalOp(); RewritePatternSet patterns(context); - patterns.add(context, affineVectorSize); + patterns.add(context, vecSize); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/tests/Conversion/batchmatmul-vectorization.mlir b/tests/Conversion/batchmatmul-vectorization.mlir new file mode 100644 index 000000000..d53a27b10 --- /dev/null +++ b/tests/Conversion/batchmatmul-vectorization.mlir @@ -0,0 +1,43 @@ +// RUN: buddy-opt -batchmatmul-optimize %s | FileCheck %s + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim) { +// CHECK-NEXT: affine.prefetch %arg0[%arg3, %dim_0, %dim_2], read, locality<3>, data : memref +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_0) { +// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %2 step %c32 iter_args(%arg6 = %c0) -> (index) { +// CHECK-NEXT: %6 = vector.load %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %7 = scf.for %arg7 = %c0 to %dim_2 step %c1 iter_args(%arg8 = %6) -> (vector<32xf32>) { +// CHECK-NEXT: %9 = memref.load %arg0[%arg3, %arg4, %arg7] : memref +// CHECK-NEXT: %10 = vector.broadcast %9 : f32 to vector<32xf32> +// CHECK-NEXT: %11 = vector.load %arg1[%arg3, %arg7, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %12 = vector.fma %10, %11, %arg8 : vector<32xf32> +// CHECK-NEXT: scf.yield %12 : vector<32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: vector.store %7, %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %8 = arith.addi %arg5, %c32 : index +// CHECK-NEXT: scf.yield %8 : index +// CHECK-NEXT: } +// CHECK-NEXT: %4 = arith.subi %dim_1, %3 : index +// CHECK-NEXT: %5 = arith.cmpi sgt, %4, %c0 : index +// CHECK-NEXT: scf.if %5 { +// CHECK-NEXT: %6 = vector.create_mask %4 : vector<32xi1> +// CHECK-NEXT: %7 = vector.maskedload %arg2[%arg3, %arg4, %3], %6, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK-NEXT: %8 = scf.for %arg5 = %c0 to %dim_2 step %c1 iter_args(%arg6 = %7) -> (vector<32xf32>) { +// CHECK-NEXT: %9 = memref.load %arg0[%arg3, %arg4, %arg5] : memref +// CHECK-NEXT: %10 = vector.broadcast %9 : f32 to vector<32xf32> +// CHECK-NEXT: %11 = vector.maskedload %arg1[%arg3, %arg5, %3], %6, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK-NEXT: %12 = vector.fma %10, %11, %arg6 : vector<32xf32> +// CHECK-NEXT: scf.yield %12 : vector<32xf32> +// 
CHECK-NEXT: }
+// CHECK-NEXT: vector.maskedstore %arg2[%arg3, %arg4, %3], %6, %8 : memref<?x?x?xf32>, vector<32xi1>, vector<32xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+func.func @batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  linalg.batch_matmul
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2 : memref<?x?x?xf32>)
+  return
+}
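+
+// Note added for readers of the CHECK lines above: for each batch c and
+// indices (m, n), linalg.batch_matmul computes
+//   C[c, m, n] += sum over k of A[c, m, k] * B[c, k, n],
+// and the -batchmatmul-optimize pattern vectorizes the n dimension 32 lanes
+// at a time (the pass's default vector-size) and finishes the remainder with
+// vector.create_mask / maskedload / maskedstore.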