diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 74c2aedcd5e5..c6c677993e16 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -119,6 +119,15 @@ LogicalResult createTorchPermuteOp(OpBinder binder, SmallVector permuteDims, Value &permuted); +// Checks the validity of pooling parameters and stores them in the respective +// vector. +LogicalResult checkAndGetOnnxPoolingOpParameters( + OpBinder binder, ConversionPatternRewriter &rewriter, Type resultDtype, + std::string autoPad, int64_t spatialRank, Value &input, + SmallVectorImpl &kernelSizeInts, + SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts); + } // namespace mlir::torch::onnx_c #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index d8517fbd156d..41dbbbac0137 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -456,24 +456,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( patterns.onOp( "AveragePool", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - std::string autoPad; - SmallVector dilations; - if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) - return failure(); - if (autoPad != "NOTSET") { - // TODO: Add support for `auto_pad` != "NOTSET" - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); - } - Torch::ValueTensorType resultType; Value operand; - bool ceilMode, countIncludePad; + int64_t ceilMode, countIncludePad; + std::string autoPad; if (binder.tensorOperand(operand) || - binder.s64BoolAttr(ceilMode, "ceil_mode", false) || - binder.s64BoolAttr(countIncludePad, "count_include_pad", false) || + 
binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) || + binder.s64IntegerAttr(countIncludePad, "count_include_pad", 0) || + binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") || binder.tensorResultType(resultType)) - return failure(); + return rewriter.notifyMatchFailure( + binder.op, "operand/ceil_mode/count_include_pad/auto_pad/" + "resultType bind failure"); + // Determine the rank of input tensor. std::optional maybeRank = Torch::getTensorRank(operand); if (!maybeRank) @@ -481,82 +476,27 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; - SmallVector kernel, padding, strides; - if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) { - return failure(); - } - if (kernel.size() != rank - 2) { - return rewriter.notifyMatchFailure( - binder.op, "kernel list size does not match the number of axes"); - } - SmallVector defaultPadding(2 * (rank - 2), 0); - if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { - return failure(); - } - if (padding.size() != 2 * (rank - 2)) { - return rewriter.notifyMatchFailure( - binder.op, - "padding list size does not match twice the number of axes"); - } - if (binder.s64IntegerArrayAttr( - strides, "strides", llvm::SmallVector(rank - 2, 1))) { - return failure(); - } - if (strides.size() != 1 && strides.size() != rank - 2) { - return rewriter.notifyMatchFailure( - binder.op, "strides list size does not match the number of axes"); - } - - SmallVector cstKernel, cstPadding, cstStridesDilations; - for (int64_t i : kernel) { - cstKernel.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] - // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all - // axes x. - int64_t paddingSizeHalf = padding.size() / 2; - for (int64_t i = 0; i < paddingSizeHalf; ++i) { - // Check if onnx padding attribute is symmetric. 
- if (padding[i] != padding[i + paddingSizeHalf]) - return rewriter.notifyMatchFailure( - binder.op, "onnx padding attribute is not symmetric"); - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); - } - for (int64_t i : strides) { - cstStridesDilations.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } + SmallVector kernel, padding, strides, dilations, + stridesDilations; + if (failed(checkAndGetOnnxPoolingOpParameters( + binder, rewriter, resultType.getDtype(), autoPad, + /*spatialRank=*/rank - 2, + /*input=*/operand, kernel, strides, padding, dilations))) + return rewriter.notifyMatchFailure(binder.op, + "invalid pooling parameters"); - // No dilations attribute in pytorch avgpool op, so use this trick to - // encode dilation into strides. Then in the following torchtolinalg - // lowering, decode strides into strides + dilation. + // Since the PyTorch AvgPool op does not contain the `dilation` arg, + // we use the trick of encoding dilation into strides. Then, + // during the torch->linalg lowering of the `AvgPool` op we decode the + // `strides` arg into strides values followed by dilation like: // [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...] 
- if (binder.s64IntegerArrayAttr( - dilations, "dilations", - llvm::SmallVector(rank - 2, 1))) { - return failure(); - } - for (auto dilation : dilations) { - cstStridesDilations.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dilation))); - } + stridesDilations = strides; + stridesDilations.append(dilations); - Value kernelSizeList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstKernel); - Value paddingList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstPadding); + Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); + Value paddingList = createConstantIntList(binder, rewriter, padding); Value stridesDilationsList = - rewriter.create( - binder.getLoc(), - Torch::ListType::get( - Torch::IntType::get(binder.op->getContext())), - cstStridesDilations); + createConstantIntList(binder, rewriter, stridesDilations); Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode); Value cstCountIncludePad = rewriter.create( diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 13f555c146b4..258bb3dc719c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1124,138 +1124,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( }); patterns.onOp( "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - std::string autoPad; - if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) - return rewriter.notifyMatchFailure(binder.op, - "auto_pad bind failure"); - Torch::ValueTensorType resultTypeOut; Value operand; int64_t ceilMode, storageOrder; - // TODO: Add support for indices output and storage_order + std::string autoPad; if (binder.tensorOperand(operand) || binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) || 
binder.s64IntegerAttr(storageOrder, "storage_order", 0) || + binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") || binder.tensorResultTypeAtIndex(resultTypeOut, 0)) return rewriter.notifyMatchFailure( - binder.op, - "operand/ceil_mode/storage_order/resultType bind failure"); + binder.op, "operand/ceil_mode/storage_order/auto_pad/resultType " + "bind failure"); + // TODO: Add support for storage_order if (storageOrder != 0) return rewriter.notifyMatchFailure( binder.op, "storage_order setting is not supported."); + // Determine the rank of input tensor. std::optional maybeRank = Torch::getTensorRank(operand); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); - int64_t rank = *maybeRank; - int64_t spatial = rank - 2; + unsigned rank = *maybeRank; - SmallVector kernel, padding, strides, dilations; - if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) + SmallVector kernel, padding, strides, dilations, + stridesDilations; + if (failed(checkAndGetOnnxPoolingOpParameters( + binder, rewriter, resultTypeOut.getDtype(), autoPad, + /*spatialRank=*/rank - 2, + /*input=*/operand, kernel, strides, padding, dilations))) return rewriter.notifyMatchFailure(binder.op, - "kernel_shape bind failure"); - if (kernel.size() != static_cast(spatial)) - return rewriter.notifyMatchFailure( - binder.op, "kernel list size does not match the number of axes"); - if (binder.s64IntegerArrayAttr(padding, "pads", {})) - return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); - if (!padding.empty() && - padding.size() != static_cast(2 * spatial)) - return rewriter.notifyMatchFailure( - binder.op, "padding list must contain (begin,end) pair for each " - "spatial axis"); - if (binder.s64IntegerArrayAttr(strides, "strides", {})) - return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); - if (!strides.empty() && strides.size() != static_cast(spatial)) - return rewriter.notifyMatchFailure( - binder.op, "strides 
list size does not match the number of axes"); - if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) - return rewriter.notifyMatchFailure(binder.op, - "dilations bind failure"); - - // set default padding - if (padding.empty()) - padding.resize(spatial, 0); - if (strides.empty()) - strides.resize(spatial, 1); - if (dilations.empty()) - dilations.resize(spatial, 1); - - auto inputTensorType = cast(operand.getType()); - - // Padding for the beginning and ending along each spatial axis, it can - // take any value greater than or equal to 0. The value represent the - // number of pixels added to the beginning and end part of the - // corresponding axis. pads format should be as follow [x1_begin, - // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added - // at the beginning of axis i and xi_end, the number of pixels added at - // the end of axis i. - if (autoPad != "NOTSET" && autoPad != "VALID") { - const bool isSameLower = autoPad == "SAME_LOWER"; - ArrayRef inputShape = inputTensorType.getSizes(); - padding.resize_for_overwrite(2 * spatial); - for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) { - const int64_t dilatedKernelSize = - dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; - int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / - strides[dimIdx] - - 1) * - strides[dimIdx] + - dilatedKernelSize - inputShape[dimIdx + 2]; - totalPad = totalPad >= 0 ? totalPad : 0; - padding[dimIdx] = - isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); - padding[spatial + dimIdx] = totalPad - padding[dimIdx]; - } - } - - // If the padding is symmetric we can push the padding operation to the - // torch operator. 
- if (padding.size() == static_cast(2 * spatial)) { - bool equal = true; - for (int i = 0; i < spatial; ++i) { - equal = equal && (padding[i] == padding[i + spatial]); - } - if (equal) - padding.resize(spatial); - } - - // Torch pool operators require equal padding on each size of each - // dimension so we materialize the padding behavior explicitly and set - // the padding to 0. - if (padding.size() == static_cast(2 * spatial)) { - auto operandTy = cast(operand.getType()); - llvm::SmallVector shuffledPadding(spatial * 2); - llvm::SmallVector paddedShape(operandTy.getSizes()); - for (int i = 0; i < spatial; ++i) { - paddedShape[i + 2] += padding[i] + padding[i + spatial]; - shuffledPadding[2 * i] = padding[spatial - i - 1]; - shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1]; - } - - Value shuffledPaddingList = - createConstantIntList(binder, rewriter, shuffledPadding); - Value zero; - if (isa(resultTypeOut.getDtype())) { - zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getF64FloatAttr( - std::numeric_limits::lowest())); - } else if (isa(resultTypeOut.getDtype())) { - zero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr( - std::numeric_limits::lowest())); - } - - auto paddedInputTy = rewriter.getType( - paddedShape, operandTy.getDtype()); - operand = rewriter.create( - binder.getLoc(), paddedInputTy, operand, shuffledPaddingList, - zero); - padding.clear(); - padding.resize(spatial, 0); - } + "invalid pooling parameters"); Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); Value paddingList = createConstantIntList(binder, rewriter, padding); diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index 5361089d69d1..e873a8137687 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -142,3 +142,117 @@ Value mlir::torch::onnx_c::createActivationByName(ImplicitLocOpBuilder &b, return 
b.create(input.getType(), input); llvm_unreachable("Unsupported activation function"); } + +// Checks the validity of pooling parameters and stores them in the respective +// vector. +LogicalResult mlir::torch::onnx_c::checkAndGetOnnxPoolingOpParameters( + OpBinder binder, ConversionPatternRewriter &rewriter, Type resultDtype, + std::string autoPad, int64_t spatialRank, Value &input, + SmallVectorImpl &kernelSizeInts, + SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts) { + SmallVector kernel, padding, strides, dilations; + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) + return rewriter.notifyMatchFailure(binder.op, "kernel_shape bind failure"); + if (kernel.size() != static_cast(spatialRank)) + return rewriter.notifyMatchFailure( + binder.op, "kernel list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(padding, "pads", {})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (!padding.empty() && + padding.size() != static_cast(2 * spatialRank)) + return rewriter.notifyMatchFailure( + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); + if (binder.s64IntegerArrayAttr(strides, "strides", {})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (!strides.empty() && strides.size() != static_cast(spatialRank)) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) + return rewriter.notifyMatchFailure(binder.op, "dilations bind failure"); + + // set default values for padding, strides, and dilations. + if (padding.empty()) + padding.resize(spatialRank, 0); + if (strides.empty()) + strides.resize(spatialRank, 1); + if (dilations.empty()) + dilations.resize(spatialRank, 1); + + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. 
The value represents the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follows [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin is the number of pixels added + // at the beginning of axis i and xi_end is the number of pixels added at + // the end of axis i. + auto inputTensorType = cast(input.getType()); + if (autoPad != "NOTSET" && autoPad != "VALID") { + const bool isSameLower = autoPad == "SAME_LOWER"; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatialRank); + for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) { + const int64_t dilatedKernelSize = + dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; + int64_t totalPad = + ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatialRank + dimIdx] = totalPad - padding[dimIdx]; + } + } + + // If the padding is symmetric we can push the padding operation to the + // torch operator. + if (padding.size() == static_cast(2 * spatialRank)) { + bool equal = true; + for (int i = 0; i < spatialRank; ++i) { + equal = equal && (padding[i] == padding[i + spatialRank]); + } + if (equal) + padding.resize(spatialRank); + } + + // Torch pool operators require equal padding on each side of each + // dimension so we materialize the padding behavior explicitly and set + // the padding to 0. 
+ if (padding.size() == static_cast(2 * spatialRank)) { + llvm::SmallVector shuffledPadding(spatialRank * 2); + llvm::SmallVector paddedShape(inputTensorType.getSizes()); + for (int i = 0; i < spatialRank; ++i) { + paddedShape[i + 2] += padding[i] + padding[i + spatialRank]; + shuffledPadding[2 * i] = padding[spatialRank - i - 1]; + shuffledPadding[2 * i + 1] = padding[2 * spatialRank - i - 1]; + } + + Value shuffledPaddingList = + createConstantIntList(binder, rewriter, shuffledPadding); + Value zero; + if (isa(resultDtype)) { + zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(std::numeric_limits::lowest())); + } else if (isa(resultDtype)) { + zero = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(std::numeric_limits::lowest())); + } + + auto paddedInputTy = rewriter.getType( + paddedShape, inputTensorType.getDtype()); + input = rewriter.create( + binder.getLoc(), paddedInputTy, input, shuffledPaddingList, zero); + padding.clear(); + padding.resize(spatialRank, 0); + } + + kernelSizeInts = kernel; + strideInts = strides; + paddingInts = padding; + dilationInts = dilations; + return success(); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 5e62efa00cf7..90640ccf0ebc 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -984,6 +984,18 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32> // ----- +// CHECK-LABEL: @test_averagepool_with_asymmetric_padding +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,1024,6,6],f32> +// CHECK: %[[PADDED_INPUT:.*]] = torch.aten.constant_pad_nd %[[ARG]], {{.*}} : !torch.vtensor<[1,1024,6,6],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,1024,7,7],f32> +// CHECK: torch.aten.avg_pool2d %[[PADDED_INPUT]], {{.*}} : !torch.vtensor<[1,1024,7,7],f32>, !torch.list, !torch.list, 
!torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1024,1,1],f32> + +func.func @test_averagepool_with_asymmetric_padding(%arg1: !torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.contrib = 1000 : si64, ai.onnx.ml = 3 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.nchwc = 1 : si64, com.ms.internal.nhwc = 1 : si64, org.pytorch.aten = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} { + %1 = torch.operator "onnx.AveragePool"(%arg1) {torch.onnx.auto_pad = "NOTSET", torch.onnx.ceil_mode = 0 : si64, torch.onnx.count_include_pad = 0 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32> + return %1 : !torch.vtensor<[1,1024,1,1],f32> +} + +// ----- + // CHECK-LABEL: @test_conv_with_strides_no_padding func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0