[ONNX] Add support for asymmetric padding for Onnx.AveragePool op #3923

Open · wants to merge 1 commit into base: main

Changes from all commits
9 changes: 9 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -119,6 +119,15 @@ LogicalResult createTorchPermuteOp(OpBinder binder,
SmallVector<int64_t> permuteDims,
Value &permuted);

// Checks the validity of the pooling parameters and stores them in their
// respective vectors.
LogicalResult checkAndGetOnnxPoolingOpParameters(
OpBinder binder, ConversionPatternRewriter &rewriter, Type resultDtype,
std::string autoPad, int64_t spatialRank, Value &input,
SmallVectorImpl<int64_t> &kernelSizeInts,
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
SmallVectorImpl<int64_t> &dilationInts);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
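Since the new helper centralizes the attribute validation that AveragePool and MaxPool previously duplicated, a minimal standalone sketch of the contract it presumably enforces may help. This is an illustration under the usual ONNX pooling conventions, not the actual torch-mlir implementation:

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the size checks the helper performs: kernel,
// strides, and dilations carry one value per spatial axis, while pads
// carries a (begin, end) pair per spatial axis.
static bool poolingAttrSizesValid(int64_t spatialRank,
                                  const std::vector<int64_t> &kernel,
                                  const std::vector<int64_t> &strides,
                                  const std::vector<int64_t> &pads,
                                  const std::vector<int64_t> &dilations) {
  return kernel.size() == size_t(spatialRank) &&
         strides.size() == size_t(spatialRank) &&
         pads.size() == size_t(2 * spatialRank) &&
         dilations.size() == size_t(spatialRank);
}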
112 changes: 26 additions & 86 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -456,107 +456,47 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
patterns.onOp(
"AveragePool", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
SmallVector<int64_t> dilations;
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return failure();
if (autoPad != "NOTSET") {
// TODO: Add support for `auto_pad` != "NOTSET"
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");
}

Torch::ValueTensorType resultType;
Value operand;
bool ceilMode, countIncludePad;
int64_t ceilMode, countIncludePad;
Comment from @tuukkjs (Dec 20, 2024):
Why change ceilMode and countIncludePad from bool to int64_t?

std::string autoPad;
if (binder.tensorOperand(operand) ||
binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
binder.s64BoolAttr(countIncludePad, "count_include_pad", false) ||
binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
binder.s64IntegerAttr(countIncludePad, "count_include_pad", 0) ||
binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
binder.tensorResultType(resultType))
return failure();
return rewriter.notifyMatchFailure(
binder.op, "operand/ceil_mode/count_include_pad/auto_pad/"
"resultType bind failure");

// Determine the rank of input tensor.
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
unsigned rank = *maybeRank;

SmallVector<int64_t> kernel, padding, strides;
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) {
return failure();
}
if (kernel.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op, "kernel list size does not match the number of axes");
}
SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
return failure();
}
if (padding.size() != 2 * (rank - 2)) {
return rewriter.notifyMatchFailure(
binder.op,
"padding list size does not match twice the number of axes");
}
if (binder.s64IntegerArrayAttr(
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
return failure();
}
if (strides.size() != 1 && strides.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");
}

SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
for (int64_t i : kernel) {
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
Comment from a Collaborator:
Add e2e tests in shark-testsuite if the change works. I don't think the torch-to-linalg lowering supports this pattern.

// axes x.
int64_t paddingSizeHalf = padding.size() / 2;
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
// Check if onnx padding attribute is symmetric.
if (padding[i] != padding[i + paddingSizeHalf])
return rewriter.notifyMatchFailure(
binder.op, "onnx padding attribute is not symmetric");
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
}
for (int64_t i : strides) {
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
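For reference, the symmetric-pads mapping implemented by the removed loop above can be sketched as standalone C++ (an illustration, not part of the patch): ONNX stores pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], while the torch op takes one value per axis, so each begin must equal the matching end.

#include <cstdint>
#include <optional>
#include <vector>

// Returns the per-axis torch padding, or nullopt when the ONNX pads are
// asymmetric (begin != end on some axis) and cannot be mapped directly.
static std::optional<std::vector<int64_t>>
symmetricOnnxPadsToTorch(const std::vector<int64_t> &onnxPads) {
  size_t half = onnxPads.size() / 2;
  std::vector<int64_t> torchPads;
  for (size_t i = 0; i < half; ++i) {
    if (onnxPads[i] != onnxPads[i + half])
      return std::nullopt;
    torchPads.push_back(onnxPads[i]);
  }
  return torchPads;
}
// e.g. {1, 2, 1, 2} -> {1, 2}; {0, 1, 2, 1} -> nullopt.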
SmallVector<int64_t> kernel, padding, strides, dilations,
stridesDilations;
if (failed(checkAndGetOnnxPoolingOpParameters(
binder, rewriter, resultType.getDtype(), autoPad,
/*spatialRank=*/rank - 2,
/*input=*/operand, kernel, strides, padding, dilations)))
return rewriter.notifyMatchFailure(binder.op,
"invalid pooling parameters");

// No dilations attribute in pytorch avgpool op, so use this trick to
// encode dilation into strides. Then in the following torchtolinalg
// lowering, decode strides into strides + dilation.
// Since the PyTorch AvgPool op does not have a `dilation` arg, we encode
// the dilation into the strides. During the subsequent torch->linalg
// lowering of the `AvgPool` op, the `strides` arg is decoded back into
// stride values followed by dilation values:
// [strideDim1, strideDim2, ..., dilationDim1, dilationDim2, ...]
if (binder.s64IntegerArrayAttr(
dilations, "dilations",
llvm::SmallVector<int64_t>(rank - 2, 1))) {
return failure();
}
for (auto dilation : dilations) {
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
}
stridesDilations = strides;
stridesDilations.append(dilations);
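As an illustration of the encoding above (following the convention stated in the comment): with strides = {2, 2} and dilations = {1, 1}, the combined list passed to the torch op is {2, 2, 1, 1}, and the torch-to-linalg lowering is expected to split it back into its two halves. A standalone sketch:

#include <cstdint>
#include <vector>

// Append the dilation values after the stride values:
// [strideDim1, ..., strideDimN, dilationDim1, ..., dilationDimN]
static std::vector<int64_t>
encodeStridesDilations(std::vector<int64_t> strides,
                       const std::vector<int64_t> &dilations) {
  strides.insert(strides.end(), dilations.begin(), dilations.end());
  return strides;
}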

Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstKernel);
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstPadding);
Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
Value paddingList = createConstantIntList(binder, rewriter, padding);
Value stridesDilationsList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstStridesDilations);
createConstantIntList(binder, rewriter, stridesDilations);
Value cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
128 changes: 14 additions & 114 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1124,138 +1124,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
});
patterns.onOp(
"MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return rewriter.notifyMatchFailure(binder.op,
"auto_pad bind failure");

Torch::ValueTensorType resultTypeOut;
Value operand;
int64_t ceilMode, storageOrder;
// TODO: Add support for indices output and storage_order
std::string autoPad;
if (binder.tensorOperand(operand) ||
binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
binder.tensorResultTypeAtIndex(resultTypeOut, 0))
return rewriter.notifyMatchFailure(
binder.op,
"operand/ceil_mode/storage_order/resultType bind failure");
binder.op, "operand/ceil_mode/storage_order/auto_pad/resultType "
"bind failure");
// TODO: Add support for storage_order
if (storageOrder != 0)
return rewriter.notifyMatchFailure(
binder.op, "storage_order setting is not supported.");

// Determine the rank of input tensor.
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
int64_t rank = *maybeRank;
int64_t spatial = rank - 2;
unsigned rank = *maybeRank;

SmallVector<int64_t> kernel, padding, strides, dilations;
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
SmallVector<int64_t> kernel, padding, strides, dilations,
stridesDilations;
if (failed(checkAndGetOnnxPoolingOpParameters(
Comment from @tuukkjs (Dec 20, 2024):
How about countIncludePad = false? If we pad using AtenConstantPadNdOp and then do AtenAvgPoolOp, don't we lose the ability to support countIncludePad = false?

binder, rewriter, resultTypeOut.getDtype(), autoPad,
/*spatialRank=*/rank - 2,
/*input=*/operand, kernel, strides, padding, dilations)))
return rewriter.notifyMatchFailure(binder.op,
"kernel_shape bind failure");
if (kernel.size() != static_cast<size_t>(spatial))
return rewriter.notifyMatchFailure(
binder.op, "kernel list size does not match the number of axes");
if (binder.s64IntegerArrayAttr(padding, "pads", {}))
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
if (!padding.empty() &&
padding.size() != static_cast<size_t>(2 * spatial))
return rewriter.notifyMatchFailure(
binder.op, "padding list must contain (begin,end) pair for each "
"spatial axis");
if (binder.s64IntegerArrayAttr(strides, "strides", {}))
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");
if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
return rewriter.notifyMatchFailure(binder.op,
"dilations bind failure");

// set default padding
if (padding.empty())
padding.resize(spatial, 0);
if (strides.empty())
strides.resize(spatial, 1);
if (dilations.empty())
dilations.resize(spatial, 1);

auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());

// Padding for the beginning and ending along each spatial axis, it can
// take any value greater than or equal to 0. The value represent the
// number of pixels added to the beginning and end part of the
// corresponding axis. pads format should be as follow [x1_begin,
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
// at the beginning of axis i and xi_end, the number of pixels added at
// the end of axis i.
if (autoPad != "NOTSET" && autoPad != "VALID") {
const bool isSameLower = autoPad == "SAME_LOWER";
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
padding.resize_for_overwrite(2 * spatial);
for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) {
const int64_t dilatedKernelSize =
dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
strides[dimIdx] -
1) *
strides[dimIdx] +
dilatedKernelSize - inputShape[dimIdx + 2];
totalPad = totalPad >= 0 ? totalPad : 0;
padding[dimIdx] =
isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
padding[spatial + dimIdx] = totalPad - padding[dimIdx];
}
}
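The removed SAME_UPPER/SAME_LOWER computation above (now presumably folded into the shared helper) picks the total padding so that the output size is ceil(inputSize / stride). A standalone sketch with a worked example, written as plain C++ for illustration:

#include <algorithm>
#include <cstdint>

// Total padding needed on one spatial axis for a SAME output size.
static int64_t samePadTotal(int64_t inputSize, int64_t stride,
                            int64_t kernel, int64_t dilation) {
  int64_t dilatedKernel = dilation * (kernel - 1) + 1;
  int64_t outSize = (inputSize + stride - 1) / stride; // ceil(in / stride)
  int64_t totalPad = (outSize - 1) * stride + dilatedKernel - inputSize;
  return std::max<int64_t>(totalPad, 0);
}
// e.g. inputSize=5, stride=2, kernel=2, dilation=1 gives totalPad=1:
// SAME_UPPER pads (0 before, 1 after); SAME_LOWER pads (1 before, 0 after),
// matching the (totalPad + 1) / 2 rounding in the removed code.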

// If the padding is symmetric we can push the padding operation to the
// torch operator.
if (padding.size() == static_cast<size_t>(2 * spatial)) {
bool equal = true;
for (int i = 0; i < spatial; ++i) {
equal = equal && (padding[i] == padding[i + spatial]);
}
if (equal)
padding.resize(spatial);
}

// Torch pool operators require equal padding on each side of each
// dimension, so we materialize the padding behavior explicitly and set
// the padding to 0.
if (padding.size() == static_cast<size_t>(2 * spatial)) {
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
for (int i = 0; i < spatial; ++i) {
paddedShape[i + 2] += padding[i] + padding[i + spatial];
shuffledPadding[2 * i] = padding[spatial - i - 1];
shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
}

Value shuffledPaddingList =
createConstantIntList(binder, rewriter, shuffledPadding);
Value zero;
if (isa<FloatType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(
std::numeric_limits<double>::lowest()));
} else if (isa<IntegerType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
std::numeric_limits<int64_t>::lowest()));
}

auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
paddedShape, operandTy.getDtype());
operand = rewriter.create<Torch::AtenConstantPadNdOp>(
binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
zero);
padding.clear();
padding.resize(spatial, 0);
}
"invalid pooling parameters");

Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
Value paddingList = createConstantIntList(binder, rewriter, padding);
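When the pads are asymmetric, the removed MaxPool block above materialized them with AtenConstantPadNdOp, filling with the dtype's lowest value so padded cells never win the max, and reordering the pad list into torch's constant_pad_nd convention: (begin, end) pairs starting from the last spatial axis. That reordering can be sketched standalone (illustration only; the equivalent logic now presumably lives in the shared helper):

#include <cstdint>
#include <vector>

// Reorder ONNX pads [x1_begin, x2_begin, ..., x1_end, x2_end, ...] into the
// torch constant_pad_nd order, which interleaves (begin, end) pairs and
// starts from the last spatial axis.
static std::vector<int64_t>
shuffleForConstantPadNd(const std::vector<int64_t> &pads, int64_t spatial) {
  std::vector<int64_t> shuffled(2 * spatial);
  for (int64_t i = 0; i < spatial; ++i) {
    shuffled[2 * i] = pads[spatial - i - 1];
    shuffled[2 * i + 1] = pads[2 * spatial - i - 1];
  }
  return shuffled;
}
// e.g. pads {1, 2, 3, 4} = (H_begin, W_begin, H_end, W_end) with spatial=2
// yields {2, 4, 1, 3} = (W_begin, W_end, H_begin, H_end).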