From 1463db4afcdb2ba9a4692faafa71f6ca8eaa888f Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Mon, 2 Oct 2023 10:30:50 -0400 Subject: [PATCH] [LLVMCPU] Add support for dynamic quantization + reassociation of grouped qmm MegaPR [LLVMCPU] Allow parallel tiling in LLVMCPUSplitReduction, tile reduction by 2 This commit enables tiling of parallel dimensions in LLVMCPUSplitReduction, as well as changing the tile size of the resulting reduction to 2. The latter change is an x86 specific optimization that allows targeting specific instructions through VectorContractCustomKernels. [LLVMCPU] Add support for vecmat cases in VectorContractCustomKernel This commit introduces some new functionality to VectorContractCustomKernels: 1. Matching for vecmat kernels that have 1D vector shapes 2. Support for `vector.contract` ops with split reduction dimensions 3. Ability to allow promoting smaller bitwidth inputs with `arith.extui` or `arith.extsi` before passing into the `llvm.inline_asm` op 4. Ability to specify explicit constraint strings per register input in a VectorContractCustomKernel 5. Support for `i4` and `i8` input types 6. New x86 AVX512VNNI i16xi16->i32 vecmat kernel with split reduction This commit also adds `vector.transfer_read` flattening patterns and VectorContractCustomKernel lowering patterns to LLVMCPUVectorLowering. [LLVMCPU] Add pass to breakdown subbyte `arith.extui` This pass breaks down `arith.extui` ops that have `i4` inputs into a sequence of `vector.shuffle->arith.andi->arith.shrui`. This avoids bad lowering of subbyte extends in x86 backend. This pass is somewhat specific to some work on vecmat VectorContractCustomKernels right now, and has some unique matchings. The pass also attempts to make use of AVX512 registers, so the vector size for the resulting IR is hardcoded as 512 bits. This needs to change before landing. This pass in general needs some refactoring before landing. [LLVMCPU] Add pass to fold away unit dimensions on `vector.contract` ops This pass folds away unit dimensions on `vector.contract` ops to get these ops into a form that is recognizable by the VectorContractCustomKernels patterns. This pass also hoists `vector.shape_cast` ops out of containing `scf.for` ops if possible when the shape cast operates on the accumulator of a `vector.contract` op. This pattern may be better off somewhere else, but for now it is here because the unit dim folding pattern can produce a hoistable `vector.shape_cast` op in cases with split reduction. [LLVMCPU] Add flag to restrict reassociated quantized matmul optimizations [LLVMCPU] Add additional Memref alias foldings [LLVMCPU] Simplify VectorContractCustomKernels x86 constraint codes, add new AVX512 kernel --- .../iree/compiler/Codegen/LLVMCPU/BUILD.bazel | 3 + .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 3 + .../LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp | 387 ++++++++++++++++++ .../LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp | 283 +++++++++++++ .../LLVMCPUFoldVectorContractUnitDims.cpp | 347 ++++++++++++++++ .../Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp | 56 ++- .../Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp | 40 ++ .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 33 +- .../iree/compiler/Codegen/LLVMCPU/Passes.h | 20 +- .../iree/compiler/Codegen/LLVMCPU/Passes.td | 25 ++ .../LLVMCPU/VectorContractCustomKernels.cpp | 383 ++++++++++++++--- .../Dialect/Flow/Transforms/Passes.td | 6 +- 12 files changed, 1507 insertions(+), 79 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 900d220d7720..c82ce6892ba2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -53,8 +53,11 @@ iree_compiler_cc_library( "KernelDispatch.cpp", "LLVMCPUAssignConstantOrdinals.cpp", "LLVMCPUAssignImportOrdinals.cpp", + "LLVMCPUBreakDownSubbyteExtend.cpp", "LLVMCPUCheckIRBeforeLLVMConversion.cpp", "LLVMCPUEmitVectorizationRemarks.cpp", + "LLVMCPUFoldMemRefAliasOps.cpp", + "LLVMCPUFoldVectorContractUnitDims.cpp", "LLVMCPULinkExecutables.cpp", "LLVMCPULowerExecutableTarget.cpp", "LLVMCPULowerToUKernels.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 8d5feafc38db..bd0378980295 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -54,8 +54,11 @@ iree_cc_library( "KernelDispatch.cpp" "LLVMCPUAssignConstantOrdinals.cpp" "LLVMCPUAssignImportOrdinals.cpp" + "LLVMCPUBreakDownSubbyteExtend.cpp" "LLVMCPUCheckIRBeforeLLVMConversion.cpp" "LLVMCPUEmitVectorizationRemarks.cpp" + "LLVMCPUFoldMemRefAliasOps.cpp" + "LLVMCPUFoldVectorContractUnitDims.cpp" "LLVMCPULinkExecutables.cpp" "LLVMCPULowerExecutableTarget.cpp" "LLVMCPULowerToUKernels.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp new file mode 100644 index 000000000000..6d9b90384145 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp @@ -0,0 +1,387 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-breakdown-subbyte-extend" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { +namespace { + +template +static Value shuffleMaskShift(PatternRewriter &rewriter, Location loc, + SmallVector shuffleInputs, + int64_t srcBitWidth, int64_t vectorSize) { + auto shuffleInType = llvm::cast(shuffleInputs[0].getType()); + auto shuffleResultType = + VectorType::get({vectorSize}, shuffleInType.getElementType()); + int64_t dstBitWidth = shuffleInType.getElementTypeBitWidth(); + T maskBase = (1u << srcBitWidth) - 1; + + SmallVector maskArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + maskArray[elemNum] = maskBase << (elemNum * srcBitWidth % dstBitWidth); + } + auto maskVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, maskArray)); + LDBG("maskVals: " << maskVals); + SmallVector shruiArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + shruiArray[elemNum] = elemNum * srcBitWidth % dstBitWidth; + } + auto shruiVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + + int64_t dstSize = vectorSize * shuffleInputs.size(); + auto newVectorType = + VectorType::get({dstSize}, shuffleResultType.getElementType()); + Value newVector = rewriter.create( + loc, newVectorType, rewriter.getZeroAttr(newVectorType)); + + for (auto shuffleIn : llvm::enumerate(shuffleInputs)) { + SmallVector shuffleArray(vectorSize); + for (int64_t elemNum = 0; elemNum < vectorSize; elemNum++) { + shuffleArray[elemNum] = + elemNum / (vectorSize / shuffleInType.getNumElements()); + } + Value shuffleResult = rewriter.create( + loc, shuffleIn.value(), shuffleIn.value(), shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value andResult = + rewriter.create(loc, shuffleResult, maskVals); + LDBG("andResult: " << andResult); + + Value shruiResult = + rewriter.create(loc, andResult, shruiVals); + LDBG("shruiResult: " << shruiResult); + + int64_t offset = shuffleIn.index() * vectorSize; + newVector = rewriter.create( + loc, shruiResult, newVector, offset, 1); + } + return newVector; +} + +static std::optional> +getLoadsForExtend(arith::ExtUIOp extOp) { + Value extSource = extOp.getIn(); + auto shapeCastOp = extSource.getDefiningOp(); + if (!shapeCastOp) { + return std::nullopt; + } + Value shapeCastSource = shapeCastOp.getSource(); + auto insertOp = shapeCastSource.getDefiningOp(); + if (!insertOp) { + return std::nullopt; + } + SmallVector loads; + while (insertOp) { + Value insert = insertOp.getSource(); + auto insertShapeCastOp = insert.getDefiningOp(); + if (!insertShapeCastOp) { + return std::nullopt; + } + auto loadOp = insertShapeCastOp.getSource().getDefiningOp(); + if (!loadOp) { + return std::nullopt; + } + loads.push_back(loadOp.getResult()); + insertOp = insertOp.getDest().getDefiningOp(); + } + return loads; +} + +struct BreakDownSubbyteExtend final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + + SmallVector sources{extOp.getIn()}; + if (auto loads = getLoadsForExtend(extOp)) { + sources = *loads; + } + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(dstElemBitwidth) || srcElemBitwidth != 4) + return failure(); + + if (dstElemBitwidth != 32 && dstElemBitwidth != 16) { + return failure(); + } + + int64_t vectorSizeBits = 512; + int64_t vectorSize = vectorSizeBits / dstElemBitwidth; + int64_t shuffleInputSizeBits = vectorSize * srcElemBitwidth; + int64_t shuffleInputSize = shuffleInputSizeBits / dstElemBitwidth; + auto shuffleInputType = + VectorType::get({shuffleInputSize}, extuiDstType.getElementType()); + Value shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + SmallVector shuffleInputs; + + for (int sourceIdx = 0; sourceIdx < sources.size(); sourceIdx++) { + Value source = sources[sourceIdx]; + VectorType sourceType = llvm::cast(source.getType()); + SmallVector sourceShape(sourceType.getShape()); + int64_t innerSize = sourceShape.back(); + if (!llvm::isPowerOf2_64(innerSize)) { + return failure(); + } + for (int64_t i = 0; i < sourceType.getNumElements() / innerSize; i++) { + SmallVector indices; + int64_t numElems = i; + SmallVector sourceOuterShape(sourceShape.begin(), + sourceShape.end() - 1); + for (int64_t size : llvm::reverse(sourceOuterShape)) { + indices.push_back(numElems % size); + numElems /= size; + } + std::reverse(indices.begin(), indices.end()); + + Value innerSlice; + if (indices.size()) { + innerSlice = rewriter.create(extOp.getLoc(), + source, indices); + } else { + innerSlice = source; + } + VectorType innerSliceType = + llvm::cast(innerSlice.getType()); + int64_t numExtractedBits = + innerSliceType.getNumElements() * srcElemBitwidth; + if (numExtractedBits / dstElemBitwidth < 1) { + LDBG("extract not big enough: " << numExtractedBits / + dstElemBitwidth); + return failure(); + } + auto bitCastType = VectorType::get({numExtractedBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, innerSlice); + LDBG("innerSlice: " << innerSlice); + // LDBG("bitCastResult: " << bitCastResult); + + if (numExtractedBits >= shuffleInputSizeBits) { + for (int64_t extractOffset = 0; + extractOffset < numExtractedBits / dstElemBitwidth; + extractOffset += shuffleInputSize) { + Value extractedSlice = + rewriter.create( + extOp.getLoc(), bitCastResult, extractOffset, + shuffleInputSize, 1); + shuffleInputs.push_back(extractedSlice); + LDBG("extractedSlice: " << extractedSlice); + // vector = + // rewriter.create(extOp.getLoc(), + // extractedSlice, vector, SmallVector{offset}, + // SmallVector{1}); + } + } else { + int64_t offset = + i * numExtractedBits / dstElemBitwidth % shuffleInputSize; + shuffleInput = rewriter.create( + extOp.getLoc(), bitCastResult, shuffleInput, + SmallVector{offset}, SmallVector{1}); + if (offset + numExtractedBits / dstElemBitwidth == shuffleInputSize) { + shuffleInputs.push_back(shuffleInput); + shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + } + } + } + } + + Value newVector; + if (dstElemBitwidth == 32) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } else if (dstElemBitwidth == 16) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } + rewriter.replaceOpWithNewOp(extOp, extuiDstType, + newVector); + + return success(); + } +}; + +struct BreakDownSubbyteExtendFlatten final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + LDBG("extuiSrcType: " << extuiSrcType); + LDBG("extuiDstType: " << extuiDstType); + + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(extuiSrcType.getNumElements())) + return failure(); + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + LDBG("srcElemBitwidth: " << srcElemBitwidth); + LDBG("dstElemBitwidth: " << dstElemBitwidth); + + int64_t numBits = srcElemBitwidth * extuiSrcType.getNumElements(); + if (numBits / dstElemBitwidth < 1) { + return failure(); + } + + VectorType flattenedType = VectorType::get({extuiSrcType.getNumElements()}, + extuiSrcType.getElementType()); + Value shapeCastFlatten = rewriter.create( + extOp.getLoc(), flattenedType, extOp.getIn()); + + auto bitCastType = VectorType::get({numBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, shapeCastFlatten); + LDBG("bitCastResult: " << bitCastResult); + + SmallVector shuffleArray(extuiDstType.getNumElements()); + for (int64_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shuffleArray[elemNum] = elemNum / (extuiDstType.getNumElements() / + bitCastType.getNumElements()); + } + + Value shuffleResult = rewriter.create( + extOp.getLoc(), bitCastResult, bitCastResult, shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value shapeCastUnflatten = rewriter.create( + extOp.getLoc(), extuiDstType, shuffleResult); + Value maskVals, shruiVals; + if (dstElemBitwidth == 32) { + int32_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector maskArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else if (dstElemBitwidth == 16) { + int16_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector maskArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else { + return failure(); + } + + Value andResult = rewriter.create( + extOp.getLoc(), shapeCastUnflatten, maskVals); + LDBG("andResult: " << andResult); + + rewriter.replaceOpWithNewOp(extOp, andResult, shruiVals); + + return success(); + } +}; + +struct LLVMCPUBreakDownSubbyteExtendPass final + : public LLVMCPUBreakDownSubbyteExtendBase< + LLVMCPUBreakDownSubbyteExtendPass> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + { + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + + // For the case when the innermost dimension of the src type is too small to + // fill a single element of the dst type. + // { + // RewritePatternSet patterns(context); + // patterns.add(context); + // vector::populateVectorShapeCastLoweringPatterns(patterns); + // if (failed(applyPatternsAndFoldGreedily(getOperation(), + // std::move(patterns)))) { + // return signalPassFailure(); + // } + // } + } +}; + +} // namespace + +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass() { + return std::make_unique(); +} + +void populateLLVMCPUBreakDownSubbyteExtendPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp new file mode 100644 index 000000000000..fc8c40dfb2e2 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp @@ -0,0 +1,283 @@ +//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass folds loading/storing from/to subview ops into +// loading/storing from/to the original memref. +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-memref-alias-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir { +namespace iree_compiler { + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Merges expand_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfExpandShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges collapse_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfCollapseShapeOpFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +static SmallVector +calculateExpandedAccessIndices(AffineMap affineMap, + const SmallVector &indices, Location loc, + PatternRewriter &rewriter) { + SmallVector indicesOfr(llvm::to_vector( + llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }))); + SmallVector expandedIndices; + for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, affineMap.getSubMap({i}), indicesOfr); + expandedIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return expandedIndices; +} + +static LogicalResult +resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, + memref::ExpandShapeOp expandShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + // The below implementation uses computeSuffixProduct method, which only + // allows int64_t values (i.e., static shape). Bail out if it has dynamic + // shapes. + if (!expandShapeOp.getResultType().hasStaticShape()) + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + for (ArrayRef groups : expandShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + int64_t groupSize = groups.size(); + + // Construct the expression for the index value w.r.t to expand shape op + // source corresponding the indices wrt to expand shape op result. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + SmallVector dims(groupSize); + bindDimsList(ctx, MutableArrayRef{dims}); + AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct); + + /// Apply permutation and create AffineApplyOp. + SmallVector dynamicIndices(groupSize); + for (int64_t i = 0; i < groupSize; i++) + dynamicIndices[i] = indices[groups[i]]; + + // Creating maximally folded and composd affine.apply composes better with + // other transformations without interleaving canonicalization passes. + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/groupSize, + /*numSymbols=*/0, srcIndexExpr), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return success(); +} + +static LogicalResult +resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, + memref::CollapseShapeOp collapseShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + int64_t cnt = 0; + SmallVector tmp(indices.size()); + SmallVector dynamicIndices; + for (ArrayRef groups : collapseShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + dynamicIndices.push_back(indices[cnt++]); + int64_t groupSize = groups.size(); + + // Calculate suffix product for all collapse op source dimension sizes. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + + // Derive the index values along all dimensions of the source corresponding + // to the index wrt to collapsed shape op output. + auto d0 = rewriter.getAffineDimExpr(0); + SmallVector delinearizingExprs = delinearize(d0, suffixProduct); + + // Construct the AffineApplyOp for each delinearizingExpr. + for (int64_t i = 0; i < groupSize; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, + delinearizingExprs[i]), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + dynamicIndices.clear(); + } + if (collapseShapeOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + int64_t srcRank = + cast(collapseShapeOp.getViewSource().getType()).getRank(); + for (int64_t i = 0; i < srcRank; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, zeroAffineMap, dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + } + return success(); +} + +/// Helpers to access the memref operand for each op. +template +static Value getMemRefOperand(LoadOrStoreOpTy op) { + return op.getMemref(); +} + +static Value getMemRefOperand(vector::TransferReadOp op) { + return op.getSource(); +} + +static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } + +template +LogicalResult LLVMCPULoadOpOfExpandShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto expandShapeOp = + getMemRefOperand(loadOp).template getDefiningOp(); + + if (!expandShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesExpandShape( + loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), expandShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult LLVMCPULoadOpOfCollapseShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto collapseShapeOp = getMemRefOperand(loadOp) + .template getDefiningOp(); + + if (!collapseShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesCollapseShape( + loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), collapseShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +void populateLLVMCPUFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { + patterns.add, + LLVMCPULoadOpOfCollapseShapeOpFolder>( + patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { + +struct LLVMCPUFoldMemRefAliasOpsPass final + : public LLVMCPUFoldMemRefAliasOpsBase { + void runOnOperation() override; +}; + +} // namespace + +void LLVMCPUFoldMemRefAliasOpsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateFoldMemRefAliasOpPatterns(patterns); + populateLLVMCPUFoldMemRefAliasOpPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir \ No newline at end of file diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp new file mode 100644 index 000000000000..ac01d81ce392 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp @@ -0,0 +1,347 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- LLVMCPUFoldVectorContractUnitDims.cpp - Pass to fold unit dims of +// vector.contract ops -===// +// +// Patterns to fold away unit dimensions on `vector.contract` ops +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-unit-reduction-dims" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { + +// Given a `vector.contract` op and a set of indices to fold, this op rewrites +// the `vector.contract` op with surrounding `vector.shape_cast` ops to fold +// away the indicated indices. +static FailureOr +dropFoldableUnitIndices(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector foldIndices) { + SmallVector contractShape = *contractOp.getShapeForUnroll(); + SmallVector iteratorTypes = + contractOp.getIteratorTypesArray(); + auto indexingMaps = contractOp.getIndexingMapsArray(); + SmallVector> dstShapes; + SmallVector> dstExprs; + SmallVector inputs( + {contractOp.getLhs(), contractOp.getRhs(), contractOp.getAcc()}); + llvm::SetVector foldableDims; + for (int64_t dim : foldIndices) + foldableDims.insert(dim); + + for (AffineMap map : indexingMaps) { + SmallVector dstShape; + SmallVector dstExpr; + for (const auto &expr : enumerate(map.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (!foldableDims.contains(dimExpr.getPosition())) { + dstShape.push_back(contractShape[dimExpr.getPosition()]); + unsigned numSkipped = 0; + for (int64_t ind : foldIndices) { + if (dimExpr.getPosition() > ind) { + numSkipped++; + } + } + dstExpr.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition() - numSkipped)); + } + } else { + return failure(); + } + } + dstShapes.push_back(dstShape); + dstExprs.push_back(dstExpr); + } + + SmallVector newInputs; + SmallVector newIndexingMaps; + SmallVector newIteratorTypes; + for (auto iter : enumerate(iteratorTypes)) { + if (!foldableDims.contains(iter.index())) { + newIteratorTypes.push_back(iter.value()); + } + } + + for (int i = 0; i < 3; i++) { + // Shape unchanged + if (dstShapes[i].size() == indexingMaps[i].getResults().size()) { + newInputs.push_back(inputs[i]); + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newIndexingMaps.push_back(newIndexingMap); + continue; + } + if (dstShapes[i].size() == 0) { + return failure(); + } + VectorType inputVecType = llvm::cast(inputs[i].getType()); + VectorType dstType = + VectorType::get(dstShapes[i], inputVecType.getElementType()); + + Value result; + auto extsiop = inputs[i].getDefiningOp(); + auto extuiop = inputs[i].getDefiningOp(); + if (!extsiop && !extuiop) { + result = rewriter.create(contractOp.getLoc(), + dstType, inputs[i]); + } else { + Value extIn = extsiop ? extsiop.getIn() : extuiop.getIn(); + VectorType extInType = llvm::dyn_cast(extIn.getType()); + VectorType shapeCastOutType = + VectorType::get(dstType.getShape(), extInType.getElementType()); + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), shapeCastOutType, extIn); + result = extsiop ? rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult() + : rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult(); + } + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newInputs.push_back(result); + newIndexingMaps.push_back(newIndexingMap); + } + auto newContract = + rewriter + .create( + contractOp.getLoc(), newInputs[0], newInputs[1], newInputs[2], + rewriter.getAffineMapArrayAttr(newIndexingMaps), + rewriter.getArrayAttr(llvm::to_vector(llvm::map_range( + newIteratorTypes, + [&](vector::IteratorType t) -> mlir::Attribute { + return vector::IteratorTypeAttr::get(rewriter.getContext(), + t); + })))) + .getResult(); + return newContract; +} + +// This pattern matches on a `vector.contract` op with unit size dimensions, and +// folds these dimensions away +class DropVectorContractUnitDims final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + LDBG("vector.contract op:\n" << contractOp); + VectorType outputType = + llvm::dyn_cast(contractOp.getAcc().getType()); + if (!outputType) { + return failure(); + } + + auto iteratorTypes = contractOp.getIteratorTypesArray(); + SmallVector contractDims = *contractOp.getShapeForUnroll(); + unsigned numParallel = 0; + unsigned numReduction = 0; + SmallVector unitParallelDims; + SmallVector unitReductionDims; + SmallVector foldableDims; + for (auto size : enumerate(contractDims)) { + if (iteratorTypes[size.index()] == vector::IteratorType::parallel) { + numParallel++; + if (size.value() == 1) { + unitParallelDims.push_back(size.index()); + } + } else { + numReduction++; + if (size.value() == 1) { + unitReductionDims.push_back(size.index()); + } + } + } + if (numReduction && numReduction == unitReductionDims.size()) { + foldableDims.append(unitReductionDims.begin(), + unitReductionDims.end() - 1); + } else { + foldableDims.append(unitReductionDims.begin(), unitReductionDims.end()); + } + if (numParallel && numParallel == unitParallelDims.size()) { + foldableDims.append(unitParallelDims.begin() + 1, unitParallelDims.end()); + } else { + foldableDims.append(unitParallelDims.begin(), unitParallelDims.end()); + } + if (!foldableDims.size()) { + return failure(); + } + + FailureOr maybeNewContract = + dropFoldableUnitIndices(rewriter, contractOp, foldableDims); + if (failed(maybeNewContract)) { + return failure(); + } + Value newContract = maybeNewContract.value(); + LDBG("Replaced vector.contract:\n" << newContract); + + VectorType newOutputType = + llvm::dyn_cast(newContract.getType()); + if (outputType != newOutputType) { + // Reshape output of new vector.contract if needed + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), outputType, newContract); + rewriter.replaceOp(contractOp, shapeCastResult); + } else { + rewriter.replaceOp(contractOp, newContract); + } + + return success(); + } +}; + +// This pattern matches on a sequence of +// `vector.shape_cast->vector.contract->vector.shape_cast` within an `scf.for` +// op, where the shape cast ops are casting an argument of the `scf.for` op and +// the yielded result of the `scf.for` op. Once matched, the `vector.shape_cast` +// ops are hoisted out of the `scf.for` op. +class HoistShapeCastOutOfSCFFor final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + LDBG("forOp:\n" << forOp); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + std::optional> + hoistableShapeCast = std::nullopt; + int initArgIdx; + for (Value result : yieldOp.getOperation()->getOperands()) { + auto outputShapeCastOp = result.getDefiningOp(); + if (!outputShapeCastOp) { + continue; + } + LDBG("outputShapeCastOp:\n" << outputShapeCastOp); + auto contractOp = + outputShapeCastOp.getSource().getDefiningOp(); + if (!contractOp) { + continue; + } + LDBG("contractOp:\n" << contractOp); + Value acc = contractOp.getAcc(); + auto inputShapeCastOp = acc.getDefiningOp(); + if (!inputShapeCastOp) { + continue; + } + LDBG("inputShapeCastOp:\n" << inputShapeCastOp); + Value input = inputShapeCastOp.getSource(); + auto blockArg = dyn_cast(input); + if (!blockArg) { + continue; + } + LDBG("blockArg:\n" << blockArg); + hoistableShapeCast = std::make_pair(inputShapeCastOp, outputShapeCastOp); + initArgIdx = blockArg.getArgNumber() - 1; + } + + if (!hoistableShapeCast) { + return failure(); + } + vector::ShapeCastOp inSC = hoistableShapeCast->first; + vector::ShapeCastOp outSC = hoistableShapeCast->second; + SmallVector forOpInitArgs = forOp.getInitArgs(); + Value source = forOpInitArgs[initArgIdx]; + Value sourceSC = + rewriter + .create(forOp.getLoc(), inSC.getType(), source) + .getResult(); + forOpInitArgs[initArgIdx] = sourceSC; + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), forOpInitArgs); + LDBG("newForOp:\n" << newForOp); + rewriter.mergeBlocks(forOp.getBody(), newForOp.getBody(), + newForOp.getBody()->getArguments()); + auto newYieldOp = cast(newForOp.getBody()->getTerminator()); + LDBG("newYieldOp:\n" << newYieldOp); + SmallVector newForOpResults = + newYieldOp.getOperation()->getOperands(); + int contractResultIndex; + for (auto result : llvm::enumerate(newForOpResults)) { + if (result.value() == outSC.getResult()) { + newForOpResults[result.index()] = outSC.getSource(); + contractResultIndex = result.index(); + } + } + rewriter.updateRootInPlace(newYieldOp, [&]() { + newYieldOp.getOperation()->setOperands(newForOpResults); + }); + LDBG("newForOp with body:\n" << newForOp); + SmallVector newResults = newForOp.getResults(); + Value hoistedOutputShapeCast = + rewriter + .create(forOp.getLoc(), outSC.getType(), + newResults[contractResultIndex]) + .getResult(); + LDBG("hoistedOutputShapeCast:\n" << hoistedOutputShapeCast); + newResults[contractResultIndex] = hoistedOutputShapeCast; + rewriter.replaceOp(forOp, newResults); + + return success(); + } +}; + +namespace { +struct LLVMCPUFoldVectorContractUnitDimsPass + : public LLVMCPUFoldVectorContractUnitDimsBase< + LLVMCPUFoldVectorContractUnitDimsPass> { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; +} // namespace + +void LLVMCPUFoldVectorContractUnitDimsPass::runOnOperation() { + Operation *funcOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet foldUnitDimsPatterns(context); + foldUnitDimsPatterns + .add(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(foldUnitDimsPatterns)))) { + return signalPassFailure(); + } +} + +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass() { + return std::make_unique(); +} + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp index 722db723c594..6ece8dcdf89f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp @@ -38,8 +38,9 @@ namespace { /// TODO: support named ops, numInputs > 1, and modify lastDim check below /// accordingly. If fpReductionReordering is not enabled by default, it must /// be an integer or index type to proceed to allow associative reordering. -LogicalResult splitReductionPrecondition(Operation *op, - bool fpReductionReordering) { +LogicalResult +splitReductionPrecondition(Operation *op, bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { linalg::LinalgOp linalgOp = cast(op); if (!linalgOp.hasTensorSemantics()) { @@ -63,7 +64,11 @@ LogicalResult splitReductionPrecondition(Operation *op, LLVM_DEBUG(llvm::dbgs() << "is not a generic op\n"); return failure(); } - if (linalgOp.getNumDpsInputs() != 1) { + if (enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() > 2) { + LLVM_DEBUG(llvm::dbgs() << "doesn't have at most 2 inputs\n"); + return failure(); + } + if (!enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() != 1) { LLVM_DEBUG(llvm::dbgs() << "doesn't have exactly 1 input\n"); return failure(); } @@ -102,8 +107,10 @@ LogicalResult splitReductionPrecondition(Operation *op, /// Converts an inner-reduction into outer reduction + inner-parallel dimension, /// followed by simple inner reduction. -LogicalResult splitReductionImpl(Operation *op, int64_t size, +LogicalResult splitReductionImpl(Operation *op, SmallVector tileSizes, + bool enableQuantizedMatmulReassociation, RewriterBase &rewriter) { + int64_t size = tileSizes.back(); IRRewriter::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(op); linalg::LinalgOp linalgOp = cast(op); @@ -119,8 +126,19 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, auto numLoops = linalgOp.getNumLoops(); // 1) Tile to extract a single vector-length array. - SmallVector tileSizesSVFirst(numLoops, - rewriter.getIndexAttr(1)); + SmallVector tileSizesSVFirst; + if (enableQuantizedMatmulReassociation) { + for (auto &s : tileSizes) { + if (!s) { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(1)); + } else { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(s)); + } + } + } else { + tileSizesSVFirst = + SmallVector(numLoops, rewriter.getIndexAttr(1)); + } tileSizesSVFirst[numLoops - 1] = rewriter.getIndexAttr(0); auto options = scf::SCFTilingOptions().setTileSizes(tileSizesSVFirst); FailureOr tileResFirst = scf::tileUsingSCFForOp( @@ -147,7 +165,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, rewriter.getIndexAttr(0)); // The reduction happens only in the penultimate dimension, which we now // tile. - tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + if (enableQuantizedMatmulReassociation) { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(2); + } else { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + } options = scf::SCFTilingOptions().setTileSizes(tileSizesSV); FailureOr tileRes = scf::tileUsingSCFForOp( rewriter, cast(splitRes->splitLinalgOp.getOperation()), @@ -164,8 +186,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, class LLVMCPUSplitReductionPass : public LLVMCPUSplitReductionBase { public: - LLVMCPUSplitReductionPass(bool fpReductionReordering) { + LLVMCPUSplitReductionPass(bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { this->enableFpReductionReordering = fpReductionReordering; + this->enableQuantizedMatmulReassociation = + enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -183,8 +208,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { funcOp.walk([&](linalg::GenericOp op) { candidates.push_back(op); }); for (auto genericOp : candidates) { LLVM_DEBUG(llvm::dbgs() << "candidate: " << genericOp << "\n"); - if (failed(splitReductionPrecondition(genericOp, - enableFpReductionReordering))) { + if (failed( + splitReductionPrecondition(genericOp, enableFpReductionReordering, + enableQuantizedMatmulReassociation))) { continue; } @@ -208,8 +234,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { "skip SplitReduction"); continue; } - int64_t size = reductionSizes.back(); - if (failed(splitReductionImpl(genericOp, size, rewriter))) { + if (failed(splitReductionImpl(genericOp, reductionSizes, + enableQuantizedMatmulReassociation, + rewriter))) { return signalPassFailure(); } } @@ -217,9 +244,10 @@ void LLVMCPUSplitReductionPass::runOnOperation() { } // namespace std::unique_ptr> -createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering) { +createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering, + const bool enableQuantizedMatmulReassociation) { return std::make_unique( - enableFpReductionReordering); + enableFpReductionReordering, enableQuantizedMatmulReassociation); } } // namespace iree_compiler } // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp index 42b46756c25d..ff4af02a3e26 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp @@ -46,6 +46,8 @@ class LLVMCPUVectorLoweringPass LLVMCPUVectorLoweringPass(const LLVMCPUVectorLoweringPassOptions &options) { this->splitVectorTransfersTo = options.splitVectorTransfersTo; this->lowerVectorTransposeToAVX2 = options.lowerVectorTransposeToAVX2; + this->enableQuantizedMatmulReassociation = + options.enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -77,6 +79,27 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { .setVectorTransformsOptions(vectorContractLowering) .setVectorMultiReductionLowering(vectorMultiReductionLowering) .setVectorTransferSplit(vectorTransferSplit); + + { + if (enableQuantizedMatmulReassociation) { + // Special-case vector.contract codegen paths. This needs to happen + // just before the generic vector ops lowerings. + RewritePatternSet patterns(ctx); + auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); + populateVectorContractCustomKernelsPatterns(target, patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After custom kernel lowering for " + "vector.contract ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + // Lower high level vector operations like contract or multidim reduce ops // to lower level vector ops. { @@ -158,6 +181,23 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { llvm::dbgs() << "\n\n"; }); + // Break down subbyte `arith.extui` ops + { + if (enableQuantizedMatmulReassociation) { + RewritePatternSet patterns(&getContext()); + populateLLVMCPUBreakDownSubbyteExtendPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After breaking down subbyte extend ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + // 'vector.shape_cast' are very expensive operations that are even generated // by some of the lowerings above (e.g., transpose lowering). There are // chances to cancel them out if they are not lowered too early so we lower diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index b563dbac0cd4..09467bc80e8a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -86,6 +86,13 @@ static llvm::cl::opt clInstrumentMemoryAccesses{ "instrumentation is enabled."), llvm::cl::init(false)}; +static llvm::cl::opt clEnableQuantizedMatmulReassociation( + "iree-llvmcpu-enable-quantized-matmul-reassociation", + llvm::cl::desc( + "Enables LLVMCPU codegen optimizations specific to reassociated " + "quantized matmuls (experimental)."), + llvm::cl::init(false)); + // MLIR file containing a top-level module that specifies the transformations to // apply to form dispatch regions. // Defined externally in KernelDispatch.cpp to control the codegen pass @@ -226,6 +233,19 @@ LogicalResult verifyDoubleTilingExpertPassPipelineConfig( << index << "-th tile size set"; } } + // if (!clEnableQuantizedMatmulReassociation) { + // SmallVector thirdLevelTileSizes; + // std::tie(thirdLevelTileSizes, std::ignore) = + // tilingConfig.getVectorReductionSizes(); + // for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { + // if (tileSize != 0 && pLoopsSet.contains(index)) { + // return op->emitOpError("expected only reduction dims to be set in " + // "the third tiling " + // "level, got ") + // << index << "-th tile size set"; + // } + // } + // } } // Verify interchange @@ -471,7 +491,9 @@ void addMultiTilingExpertPassPipeline( // Run SplitReductionPass before the final reduction Fuse pass, because // SplitReductionPass takes care of banked-tiling. nestedModulePM.addNestedPass( - createLLVMCPUSplitReductionPass(clEnableReassociateFpReductions)); + createLLVMCPUSplitReductionPass( + clEnableReassociateFpReductions, + clEnableQuantizedMatmulReassociation)); nestedModulePM.addNestedPass(createLLVMCPUTilePass(i)); continue; } @@ -508,11 +530,17 @@ void addMultiTilingExpertPassPipeline( // Run IREE specific passes before vector lowering expert. nestedModulePM.addNestedPass( createRemoveSingleIterationLoopPass()); + if (clEnableQuantizedMatmulReassociation) { + nestedModulePM.addNestedPass( + createLLVMCPUFoldVectorContractUnitDimsPass()); + } { LLVMCPUVectorLoweringPassOptions options; options.lowerVectorTransposeToAVX2 = lowerToAVX2; options.splitVectorTransfersTo = "linalg-copy"; + options.enableQuantizedMatmulReassociation = + clEnableQuantizedMatmulReassociation; nestedModulePM.addNestedPass( createLLVMCPUVectorLoweringPass(options)); } @@ -743,6 +771,9 @@ static void addLowerToLLVMPasses(OpPassManager &passManager) { passManager.addNestedPass(arith::createArithExpandOpsPass()); passManager.addNestedPass(memref::createExpandOpsPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); + if (clEnableQuantizedMatmulReassociation) { + passManager.addPass(createLLVMCPUFoldMemRefAliasOpsPass()); + } passManager.addPass(createEmulateNarrowTypePass()); passManager.addPass(createCanonicalizerPass()); passManager.addPass(createCSEPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 47dad29749e1..4527ec5a07a8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -20,6 +20,10 @@ namespace iree_compiler { class TilingConfig; +// Pass to breakdown subbyte extui +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass(); + /// Performs the final conversion to LLVM dialect. std::unique_ptr> createConvertToLLVMPass(bool reassociateFpReordering = false); @@ -55,8 +59,9 @@ createLLVMCPUMmt4dVectorLoweringPass(); std::unique_ptr> createLLVMCPUPeelPass(); /// Pass to perform SplitReduction transformations of `LinalgOp`s. -std::unique_ptr> -createLLVMCPUSplitReductionPass(bool enableReassociateFpReductions = false); +std::unique_ptr> createLLVMCPUSplitReductionPass( + bool enableReassociateFpReductions = false, + bool enableQuantizedMatmulReassociation = false); /// Synchronizes LLVM linkage with MLIR symbol visibility. std::unique_ptr> @@ -82,6 +87,7 @@ std::unique_ptr> createLLVMCPUUnfuseFMAOpsPass(); struct LLVMCPUVectorLoweringPassOptions { std::string splitVectorTransfersTo = ""; bool lowerVectorTransposeToAVX2 = false; + bool enableQuantizedMatmulReassociation = false; }; std::unique_ptr> createLLVMCPUVectorLoweringPass(); std::unique_ptr> createLLVMCPUVectorLoweringPass( @@ -96,6 +102,11 @@ createVectorContractCustomKernelsPass(); std::unique_ptr> createVerifyLinalgTransformLegalityPass(); +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass(); + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass(); + //------------------------------------------------------------------------------ // LLVMCPU Codegen specific patterns. //------------------------------------------------------------------------------ @@ -108,6 +119,11 @@ void populateUnfusedFMAOpsPassPatterns(MLIRContext *context, void populateVectorContractCustomKernelsPatterns( IREE::HAL::ExecutableTargetAttr target, RewritePatternSet &patterns); +void populateLLVMCPUBreakDownSubbyteExtendPatterns(RewritePatternSet &patterns); + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context); + //----------------------------------------------------------------------------// // LLVMCPU backend Pass Pipelines. //----------------------------------------------------------------------------// diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index 69fc43ffc07d..0d4c95f70c1a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -45,6 +45,11 @@ def LLVMCPUAssignImportOrdinals : let constructor = "mlir::iree_compiler::createLLVMCPUAssignImportOrdinalsPass()"; } +def LLVMCPUBreakDownSubbyteExtend : Pass<"iree-llvmcpu-breakdown-subbyte-extend", "func::FuncOp"> { + let summary = "Pass to break down subbyte extui ops."; + let constructor = "mlir::iree_compiler::createLLVMCPUBreakDownSubbyteExtendPass()"; +} + def LLVMCPUCheckIRBeforeLLVMConversion : Pass<"iree-llvmcpu-check-ir-before-llvm-conversion", "ModuleOp"> { let summary = "Checks CPU backend specific IR constraints (like no allocas)"; @@ -58,6 +63,20 @@ def LLVMCPUEmitVectorizationRemarks : "mlir::iree_compiler::createLLVMCPUEmitVectorizationRemarksPass()"; } +def LLVMCPUFoldVectorContractUnitDims : + Pass<"iree-llvmcpu-fold-vector-contract-unit-dims", "func::FuncOp"> { + let summary = "Fold unit dims on vector.contract ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldVectorContractUnitDimsPass()"; +} + +def LLVMCPUFoldMemRefAliasOps : + Pass<"iree-llvmcpu-fold-memref-alias-ops", ""> { + let summary = "Fold combinations of memref ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldMemRefAliasOpsPass()"; +} + def LLVMCPULinkExecutables : Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> { let summary = "Links LLVMCPU HAL executables within the top-level program module."; @@ -107,6 +126,9 @@ def LLVMCPUSplitReduction : Pass<"iree-llvmcpu-split-reduction", "func::FuncOp"> Option<"enableFpReductionReordering", "enable-fp-reduction-reordering", "bool", /*default=*/"false", "Flag to enable reduction reordering on floating points.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", + "bool", /*default=*/"false", + "Flag to enable optimizations for reassociated quantized matmuls.">, ]; } @@ -166,6 +188,9 @@ def LLVMCPUVectorLowering : Option<"lowerVectorTransposeToAVX2", "lower-vector-transpose-to-avx2", "bool", /*default=*/"false", "Add specific transpose to avx2 lowering patterns.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool", + /*default=*/"false", + "Add specific patterns for optimizing reassociated quantized matmuls.">, ]; let constructor = "mlir::iree_compiler::createLLVMCPUVectorLoweringPass()"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index a4730ee1483a..fe6fd081e5ef 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -25,6 +25,10 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "iree-vector-contract-custom-kernels" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir { namespace iree_compiler { @@ -86,6 +90,73 @@ static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) { return true; } +static bool isVectorTimesMatrixTransposed(vector::ContractionOp contractionOp, + int64_t splitSize) { + // Check that the reduction is additive. + if (contractionOp.getKind() != vector::CombiningKind::ADD) { + return false; + } + // Check that there are 1 parallel and 1 reduction iterators. + unsigned numIters = splitSize ? 3 : 2; + auto iteratorTypes = contractionOp.getIteratorTypes().getValue(); + if (iteratorTypes.size() != numIters) { + return false; + } + SmallVector parallelIterators; + SmallVector reductionIterators; + for (int i = 0; i < numIters; i++) { + if (vector::isParallelIterator(iteratorTypes[i])) { + parallelIterators.push_back(i); + } else if (vector::isReductionIterator(iteratorTypes[i])) { + reductionIterators.push_back(i); + } else { + return false; + } + } + if (parallelIterators.size() != numIters - 1 || + reductionIterators.size() != 1) { + return false; + } + // Give the found iterators some idiomatic names. + const int NIter = parallelIterators[0]; + const int KIter = reductionIterators[0]; + const int SplitIter = splitSize ? parallelIterators[1] : 0; + // Check that there are 3 indexing maps. + auto indexingMaps = contractionOp.getIndexingMapsArray(); + if (indexingMaps.size() != 3) { + return false; + } + // Check that the indexing maps have the expected form. + SmallVector> expectedMapResults; + if (splitSize) { + SmallVector> res = { + {KIter, SplitIter}, {NIter, KIter, SplitIter}, {NIter, SplitIter}}; + expectedMapResults = res; + numIters = 3; + } else { + SmallVector> res = {{KIter}, {NIter, KIter}, {NIter}}; + expectedMapResults = res; + numIters = 2; + } + for (int m = 0; m < 3; ++m) { + auto map = indexingMaps[m]; + auto expectedResults = expectedMapResults[m]; + if (map.getNumDims() != numIters || + map.getNumResults() != expectedResults.size()) { + return false; + } + for (int r = 0; r < expectedResults.size(); ++r) { + int actualMapResult = + map.getResults()[r].cast().getPosition(); + if (actualMapResult != expectedMapResults[m][r]) { + return false; + } + } + } + LDBG("passed isVectorTimesMatrixTransposed"); + return true; +} + // Returns true if `contractionOp` is of the form // matrix * transposed_matrix // where matrix is a vector<{mSize}x{kSize}xType>, and @@ -132,6 +203,31 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, return false; } +static bool matchVMT(vector::ContractionOp contractionOp, int64_t mSize, + int64_t kSize, int64_t nSize, int splitSize, + bool *transpose = nullptr) { + if (mSize != 1) { + return false; + } + if (!isVectorTimesMatrixTransposed(contractionOp, splitSize)) { + return false; + } + VectorType lhsType = llvm::cast(contractionOp.getLhs().getType()); + VectorType rhsType = llvm::cast(contractionOp.getRhs().getType()); + auto lhsShape = lhsType.getShape(); + auto rhsShape = rhsType.getShape(); + if (splitSize && (lhsShape[1] != splitSize || rhsShape[2] != splitSize)) { + return false; + } + if (lhsShape[0] != kSize || rhsShape[1] != kSize) { + return false; + } + if (rhsShape[0] == nSize) { + return true; + } + return false; +} + // `promotedResult` is required to be a Vector. // If its VectorType does not have `promotedType` as its element type, or // the operand to the type-promotion op is not `unpromotedType` returns a null @@ -143,8 +239,9 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, // Note that this only looks at the immediately defining operation, so we likely // want to have earlier passes that sink widening operations as far down as // possible, which is probably just good regardless. -static Value getUnpromotedInput(Type unpromotedType, Type promotedType, - Value promotedResult) { +static Value getUnpromotedInput(PatternRewriter &rewriter, Type unpromotedType, + Type promotedType, Value promotedResult, + bool promoteSmallTypes = false) { VectorType promotedResultVectorType = llvm::cast(promotedResult.getType()); if (promotedResultVectorType.getElementType() != promotedType) { @@ -156,13 +253,29 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // TODO: handle promotion of floating point types. Not doing it for now as // it wouldn't be exercised. auto extSIOp = promotedResult.getDefiningOp(); - if (!extSIOp) { + auto extUIOp = promotedResult.getDefiningOp(); + if (!extSIOp && !extUIOp) { return nullptr; } - Value extInput = extSIOp.getIn(); + Value extInput = extSIOp ? extSIOp.getIn() : extUIOp.getIn(); if (llvm::cast(extInput.getType()).getElementType() != unpromotedType) { - return nullptr; + if (promoteSmallTypes) { + VectorType unpromotedVectorType = + VectorType::get(llvm::cast(extInput.getType()).getShape(), + unpromotedType); + return extSIOp + ? rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult() + : rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult(); + } else { + return nullptr; + } } return extInput; } @@ -170,12 +283,28 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // Helper to create a 1D, contiguous slice of a 1D vector. static Value extract1DSlice(PatternRewriter &rewriter, Location loc, VectorType dstVecType, Value input, int position) { - assert(input.getType().cast().getRank() == 1); assert(dstVecType.getRank() == 1); - std::array offsets{position}; - std::array strides{1}; - return rewriter.create( - loc, input, offsets, dstVecType.getShape(), strides); + if (input.getType().cast().getRank() == 1) { + SmallVector offsets({position}); + SmallVector strides({1}); + SmallVector sizes(dstVecType.getShape()); + return rewriter.create(loc, input, offsets, + sizes, strides); + } else { + SmallVector inputShape( + llvm::cast(input.getType()).getShape()); + assert(inputShape.back() == dstVecType.getNumElements()); + std::reverse(inputShape.begin(), inputShape.end()); + int currentPos = position; + SmallVector indices; + for (auto size : inputShape) { + indices.push_back(currentPos % size); + currentPos = currentPos / size; + } + std::reverse(indices.begin(), indices.end()); + return rewriter.create( + loc, input, SmallVector(indices.begin(), indices.end() - 1)); + } } // Helper to extract an element of a 1D vector. @@ -189,8 +318,12 @@ static Value extract(PatternRewriter &rewriter, Location loc, Value input, } // Helper to flatten a N-dimensional vector to a 1D vector. -static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { +static Value flattenImperfectSize(PatternRewriter &rewriter, Location loc, + Value vector, VectorType regVectorType) { VectorType inputVecType = llvm::cast(vector.getType()); + if (regVectorType.getNumElements() == inputVecType.getShape().back()) { + return vector; + } VectorType dstType = VectorType::get(inputVecType.getNumElements(), inputVecType.getElementType()); return rewriter.create(loc, dstType, vector); @@ -207,20 +340,31 @@ static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { // (2) Be explicit about the size of the vectors involved in the kernel's // "calling convention". struct MMTKernel { - enum class ScalarType : int8_t { None, I8, I32, F32 }; + enum class ScalarType : int8_t { None, I4, I8, I16, I32, F32 }; // Element type of the LHS vectors. ScalarType lhsType = ScalarType::None; // Element type of the RHS vectors. ScalarType rhsType = ScalarType::None; // Element type of the Accumulator and output vectors. ScalarType accType = ScalarType::None; + // Optional user defined constrained codes for input and output registers. + // This is useful when the constraint code is not the same for all operands. + std::optional> lhsCode = std::nullopt; + std::optional> rhsCode = std::nullopt; + std::optional> accCode = std::nullopt; + // This flag indicates whether or not to promote inputs that have a smaller + // bitwidth than lhsType, rhsType, or accType, to the appropriate bitwidth + bool promoteSmallTypes = false; // Number of rows of the LHS and Accumulator tile. - int8_t m0 = 0; + int16_t m0 = 0; // Reduction dimension, i.e. number of columns of the LHS. - int8_t k0 = 0; + int16_t k0 = 0; // Number of rows of the RHS (note that the operation being targeted, MMT, // is matrix multiplication with a *transposed* RHS) - int8_t n0 = 0; + int16_t n0 = 0; + // Size of the added parallel dimension when the vector.contract op has been + // split with splitReduction + int16_t split0 = 0; // Number of LHS elements in the type of register to be used for the LHS. // This is > 1 if SIMD registers are to be used. // Note: LHS/RHS/Accumulator may use registers of different sizes. @@ -236,6 +380,8 @@ struct MMTKernel { int8_t rhsRegs = 0; // Number of registers needed to hold the Accumulator. int8_t accRegs = 0; + // Indicates whether to use Intel or AT&T syntax + bool useIntel = false; // If not null, points to the inline asm code template for this kernel. // Register operands for the LHS, RHS and Accumulator are to be referenced as // $(lhs:), $(rhs:), $(acc:) respectively, where i is a decimal @@ -250,9 +396,15 @@ struct MMTKernel { const char *asmClobbers = nullptr; void validate() const { - assert(m0 * k0 == lhsRegSize * lhsRegs); // number of elements of LHS - assert(n0 * k0 == rhsRegSize * rhsRegs); // number of elements of RHS - assert(m0 * n0 == accRegSize * accRegs); // number of elements of Accum + assert(m0 * k0 == lhsRegSize * lhsRegs || + m0 * k0 * split0 == + lhsRegSize * lhsRegs); // number of elements of LHS + assert(n0 * k0 == rhsRegSize * rhsRegs || + n0 * k0 * split0 == + rhsRegSize * rhsRegs); // number of elements of RHS + assert(m0 * n0 == accRegSize * accRegs || + m0 * n0 * split0 == + accRegSize * accRegs); // number of elements of Accum assert(lhsType != ScalarType::None); assert(rhsType != ScalarType::None); assert(accType != ScalarType::None); @@ -674,13 +826,75 @@ MMTKernel MMTKernel_8x1x1_f32f32f32_Aarch64_Baseline_InlineAsm() { return kernel; } +MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm() { + MMTKernel kernel; + kernel.lhsType = MMTKernel::ScalarType::I16; + kernel.rhsType = MMTKernel::ScalarType::I16; + kernel.accType = MMTKernel::ScalarType::I32; + kernel.promoteSmallTypes = true; + kernel.useIntel = true; + kernel.m0 = 1; + kernel.k0 = 2; + kernel.n0 = 4; + kernel.split0 = 16; + kernel.lhsRegSize = 32; + kernel.rhsRegSize = 32; + kernel.accRegSize = 16; + kernel.lhsRegs = 1; + kernel.rhsRegs = 4; + kernel.accRegs = 4; + kernel.asmImpl = R"ASM( + vpdpwssd $(acc:0), $(rhs:0), $(lhs:0) + vpdpwssd $(acc:1), $(rhs:1), $(lhs:0) + vpdpwssd $(acc:2), $(rhs:2), $(lhs:0) + vpdpwssd $(acc:3), $(rhs:3), $(lhs:0) + )ASM"; + kernel.asmClobbers = ""; + return kernel; +} + +MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm() { + MMTKernel kernel; + kernel.lhsType = MMTKernel::ScalarType::I16; + kernel.rhsType = MMTKernel::ScalarType::I16; + kernel.accType = MMTKernel::ScalarType::I32; + kernel.promoteSmallTypes = true; + kernel.useIntel = true; + kernel.m0 = 1; + kernel.k0 = 2; + kernel.n0 = 4; + kernel.split0 = 16; + kernel.lhsRegSize = 32; + kernel.rhsRegSize = 32; + kernel.accRegSize = 16; + kernel.lhsRegs = 1; + kernel.rhsRegs = 4; + kernel.accRegs = 4; + kernel.asmImpl = R"ASM( + vpmaddwd zmm17, $(rhs:0), $(lhs:0) + vpmaddwd zmm18, $(rhs:1), $(lhs:0) + vpmaddwd zmm19, $(rhs:2), $(lhs:0) + vpmaddwd zmm20, $(rhs:3), $(lhs:0) + vpaddw $(acc:0), $(acc:0), zmm17 + vpaddw $(acc:1), $(acc:1), zmm18 + vpaddw $(acc:2), $(acc:2), zmm19 + vpaddw $(acc:3), $(acc:3), zmm20 + )ASM"; + kernel.asmClobbers = "zmm17,zmm18,zmm19,zmm20"; + return kernel; +} + // Constructs the mlir::Type corresponding to a scalar type. Type mlirType(MLIRContext *context, MMTKernel::ScalarType t) { switch (t) { case MMTKernel::ScalarType::None: break; + case MMTKernel::ScalarType::I4: + return IntegerType::get(context, 4, IntegerType::Signless); case MMTKernel::ScalarType::I8: return IntegerType::get(context, 8, IntegerType::Signless); + case MMTKernel::ScalarType::I16: + return IntegerType::get(context, 16, IntegerType::Signless); case MMTKernel::ScalarType::I32: return IntegerType::get(context, 32, IntegerType::Signless); case MMTKernel::ScalarType::F32: @@ -705,7 +919,7 @@ class MMTKernelGenerator { ArrayRef acc) { validateOperands(lhs, rhs, acc); if (kernel.asmImpl) { - return generateAsm(rewriter, loc, lhs, rhs, acc); + return generateAsm(rewriter, loc, lhs, rhs, acc, kernel.useIntel); } // In the future we may have alternate generator paths, e.g. 1D intrinsics // or other asm paths with a different interface, e.g. handling also @@ -755,10 +969,17 @@ class MMTKernelGenerator { validate(acc, kernel.accRegs, getAccRegVectorType()); } // Helper for generateAsmCodeAndConstraints - std::string getConstraintCode() const { + std::string + getConstraintCode(std::optional kernelConstraintCode) const { + if (kernelConstraintCode) { + return std::string(*kernelConstraintCode); + } if (isAArch64(target)) { return "w"; } + if (isX86(target)) { + return "v"; + } assert(false && "what constraint code to use on this arch?"); return {}; } @@ -820,31 +1041,39 @@ class MMTKernelGenerator { // processedIdx is the index of a register in the processed asm. // Example: $5 => processedIdx == 5 int processedIdx = 0; - auto processOperands = [&](Constraints::Kind constraintKind, - const char *name, int count) { - const std::string &constraintCode = getConstraintCode(); - // unprocessedIdx is the index of a register in the unprocessed asm. - // Example: $(lhs:1) => unprocessedIdx == 1 - for (int unprocessedIdx = 0; unprocessedIdx < count; - ++unprocessedIdx, ++processedIdx) { - constraints.add(constraintKind, constraintCode); - // Perform the code replacement for the operand. - // Example: $(lhs:1) => $5 - replaceAllSubstrsInPlace( - code, llvm::formatv("$({0}:{1})", name, unprocessedIdx), - llvm::formatv("${0}", processedIdx)); - } - }; - processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs); - processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs); - processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs); + auto processOperands = + [&](Constraints::Kind constraintKind, const char *name, int count, + std::optional> kernelCodes) { + const std::string &constraintCode = getConstraintCode(std::nullopt); + // unprocessedIdx is the index of a register in the unprocessed asm. + // Example: $(lhs:1) => unprocessedIdx == 1 + for (int unprocessedIdx = 0; unprocessedIdx < count; + ++unprocessedIdx, ++processedIdx) { + if (kernelCodes) { + constraints.add(constraintKind, (*kernelCodes)[unprocessedIdx]); + } else { + constraints.add(constraintKind, constraintCode); + } + // Perform the code replacement for the operand. + // Example: $(lhs:1) => $5 + replaceAllSubstrsInPlace( + code, llvm::formatv("$({0}:{1})", name, unprocessedIdx), + llvm::formatv("${0}", processedIdx)); + } + }; + processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs, + kernel.accCode); + processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs, + kernel.lhsCode); + processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs, + kernel.rhsCode); constraints.setClobbers(kernel.asmClobbers); constraintsString = constraints.toString(); } // Helper for generate(). Implements the asm path. SmallVector generateAsm(PatternRewriter &rewriter, Location loc, ArrayRef lhs, ArrayRef rhs, - ArrayRef acc) { + ArrayRef acc, bool useIntel) { SmallVector inputs; // First the input operands. Then the input-output operands, which, as far // as input constraints are concerned, are *tied* inputs, i.e. refer to @@ -864,9 +1093,13 @@ class MMTKernelGenerator { SmallVector outputOperandTypes( llvm::map_range(acc, [](Value v) { return v.getType(); })); auto returnType = - LLVM::LLVMStructType::getLiteral(context, outputOperandTypes); + outputOperandTypes.size() == 1 + ? outputOperandTypes[0] + : LLVM::LLVMStructType::getLiteral(context, outputOperandTypes); auto dialectAttr = - LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT); + useIntel + ? LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_Intel) + : LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT); std::string code; std::string constraints; generateAsmCodeAndConstraints(code, constraints); @@ -876,10 +1109,14 @@ class MMTKernelGenerator { /*operand_attrs=*/ArrayAttr()); // Extract result vectors from the asm op. SmallVector resVec; - for (int i = 0; i < kernel.accRegs; ++i) { - SmallVector position = {i}; - resVec.push_back( - rewriter.create(loc, asmOp.getRes(), position)); + if (outputOperandTypes.size() == 1) { + resVec.push_back(asmOp.getRes()); + } else { + for (int i = 0; i < kernel.accRegs; ++i) { + SmallVector position = {i}; + resVec.push_back(rewriter.create( + loc, asmOp.getRes(), position)); + } } return resVec; } @@ -914,7 +1151,9 @@ class MMTCustomKernelPattern : public OpRewritePattern { // Check if `contractionOp` matches, and obtain the (un-promoted) input // LHS and RHS vectors. bool transposeKernel = false; - if (!matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, + if (!matchVMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, kernel.split0, + &transposeKernel) && + !matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, &transposeKernel)) { return failure(); } @@ -929,9 +1168,11 @@ class MMTCustomKernelPattern : public OpRewritePattern { return failure(); } Value unpromotedLhs = - getUnpromotedInput(lhsElemType, accElemType, contractionOp.getLhs()); + getUnpromotedInput(rewriter, lhsElemType, accElemType, + contractionOp.getLhs(), kernel.promoteSmallTypes); Value unpromotedRhs = - getUnpromotedInput(rhsElemType, accElemType, contractionOp.getRhs()); + getUnpromotedInput(rewriter, rhsElemType, accElemType, + contractionOp.getRhs(), kernel.promoteSmallTypes); if (!unpromotedLhs || !unpromotedRhs) { return failure(); } @@ -953,9 +1194,23 @@ class MMTCustomKernelPattern : public OpRewritePattern { // `contractionOp` matches, start rewriting it. Location loc = contractionOp.getLoc(); // Flatten the inputs to 1D vectors. - Value flatLhs = flatten(rewriter, loc, unpromotedLhs); - Value flatRhs = flatten(rewriter, loc, unpromotedRhs); - Value flatAcc = flatten(rewriter, loc, contractionOp.getAcc()); + VectorType lhsRegVectorType = generator.getLhsRegVectorType(); + VectorType rhsRegVectorType = generator.getRhsRegVectorType(); + VectorType accRegVectorType = generator.getAccRegVectorType(); + Value lhs, rhs; + if (transposeKernel) { + lhs = + flattenImperfectSize(rewriter, loc, unpromotedLhs, rhsRegVectorType); + rhs = + flattenImperfectSize(rewriter, loc, unpromotedRhs, lhsRegVectorType); + } else { + lhs = + flattenImperfectSize(rewriter, loc, unpromotedLhs, lhsRegVectorType); + rhs = + flattenImperfectSize(rewriter, loc, unpromotedRhs, rhsRegVectorType); + } + Value acc = flattenImperfectSize(rewriter, loc, contractionOp.getAcc(), + accRegVectorType); // Slice into SIMD-register-sized 1D input vectors ready to feed to the // target SIMD instructions. auto sliceIntoRegVectors = [&](int regsCount, VectorType regVectorType, @@ -968,17 +1223,14 @@ class MMTCustomKernelPattern : public OpRewritePattern { } return regVectors; }; - VectorType lhsRegVectorType = generator.getLhsRegVectorType(); - VectorType rhsRegVectorType = generator.getRhsRegVectorType(); - VectorType accRegVectorType = generator.getAccRegVectorType(); - Value flatLhsForKernel = transposeKernel ? flatRhs : flatLhs; - Value flatRhsForKernel = transposeKernel ? flatLhs : flatRhs; + Value lhsForKernel = transposeKernel ? rhs : lhs; + Value rhsForKernel = transposeKernel ? lhs : rhs; SmallVector lhsRegVectors = - sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, flatLhsForKernel); + sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, lhsForKernel); SmallVector rhsRegVectors = - sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, flatRhsForKernel); + sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, rhsForKernel); SmallVector accRegVectors = - sliceIntoRegVectors(kernel.accRegs, accRegVectorType, flatAcc); + sliceIntoRegVectors(kernel.accRegs, accRegVectorType, acc); // Generate the kernel! SmallVector resRegVectors = generator.generate( rewriter, loc, lhsRegVectors, rhsRegVectors, accRegVectors); @@ -1037,8 +1289,8 @@ struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics return failure(); } - Value inLhs = getUnpromotedInput(I8Type, I32Type, lhs); - Value inRhs = getUnpromotedInput(I8Type, I32Type, rhs); + Value inLhs = getUnpromotedInput(rewriter, I8Type, I32Type, lhs); + Value inRhs = getUnpromotedInput(rewriter, I8Type, I32Type, rhs); if (!inLhs || !inRhs) return failure(); @@ -1171,6 +1423,15 @@ void populateVectorContractCustomKernelsPatterns( patterns.add( context, MMTKernel_8x8x8_i8i8i32_Aarch64I8mm_InlineAsm()); } + } else if (isX86(target)) { + if (hasFeature(target, "+avx512vnni")) { + patterns.add( + context, + MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm()); + } else if (hasFeature(target, "+avx512bw")) { + patterns.add( + context, MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm()); + } } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td index 5ee617cddac1..50eacb872ecb 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td @@ -102,9 +102,13 @@ def FormScalarDispatches : } def FuseDequantizationMatmul: - Pass<"iree-flow-fuse-dequantization-matmul", ""> { + InterfacePass<"iree-flow-fuse-dequantization-matmul", "mlir::FunctionOpInterface"> { let summary = "Fuse dequantization and matmul linalg.generic ops"; let constructor = "mlir::iree_compiler::IREE::Flow::createFuseDequantizationMatmulPass()"; + let options = [ + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool", + /*default=*/"false", "Allow reassociation of quantized matmuls (experimental)">, + ]; } def CollapseDimensions :