
[ONNX] Add per channel quantization support for Onnx.QLinearConv op #3917

Open

wants to merge 4 commits into main

Conversation

vivekkhandelwal1 (Collaborator)

This commit extends the OnnxToTorch lowering for the Onnx.QLinearConv op by adding support for per-channel quantization of the weight argument.

Signed-off-by: Vivek Khandelwal [email protected]

@jinchen62 jinchen62 (Collaborator) left a comment


LGTM

@zjgarvey zjgarvey (Collaborator) left a comment


Thanks Vivek. I think you need to modify some of the output quantization handling in the per-channel case. Maybe store a bool that tracks if we are in the per-channel case so you can reuse it for the output.

It looks like this conversion automatically fuses the input and weight quantization with the convolution, so the only thing that fuse-quantized-ops is going to do is quantize the bias (which won't work currently in the per-channel case). I think it is fine, but we won't be able to check correctness e2e until we address the per-channel quantization, unfortunately.
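A rough sketch of the bool-flag idea, reusing the shape checks that appear in the diff below; the name isWeightPerChannel is only illustrative:

    // Hypothetical: record once whether the weight is per-channel quantized so
    // the same flag can drive the bias/output quantization handling later on.
    bool isWeightPerChannel =
        weightScaleShape.size() == 1 &&
        weightScaleShape[0] != Torch::kUnknownSize &&
        weightScaleShape[0] == weightShape[0];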

return failure();
auto weightShape = weightTy.getSizes();
auto weightScaleShape = weightScaleTy.getSizes();
Value weightScaleScalar = extract(weightScale);

extract won't work if the weight scale isn't a single element. I'd put this in the else block below.
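A sketch of that rearrangement, using the names from this diff (isWeightPerChannel is the hypothetical flag suggested above):

    if (isWeightPerChannel) {
      // Keep weightScale as a rank-1 tensor here; no scalar extraction.
    } else {
      // Per-tensor path: a single-element scale can safely be extracted.
      Value weightScaleScalar = extract(weightScale);
      weightZp = extract(weightZp);
      weight = makePerTensor(weight, weightScaleScalar, weightZp);
    }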


I see you use this below to handle the quantization of the output, but this must also be per-channel if the weight is per-channel.

Value weightScaleScalar = extract(weightScale);
if (weightScaleShape.size() == 1 &&
    weightScaleShape[0] != Torch::kUnknownSize &&
    weightScaleShape[0] == weightShape[0]) {

Additionally check that weightShape[0] != 1 since we don't want to lower to per-channel when there is only one channel.
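The condition with the suggested extra guard might look like:

    if (weightScaleShape.size() == 1 &&
        weightScaleShape[0] != Torch::kUnknownSize &&
        // Do not take the per-channel path for a single output channel.
        weightShape[0] != 1 &&
        weightScaleShape[0] == weightShape[0]) {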

Comment on lines +380 to +383
} else {
  weightZp = extract(weightZp);
  weight = makePerTensor(weight, weightScaleScalar, weightZp);
}

A bit of a nit, but I'd prefer an else if here with the conditions for makePerTensor, and then an else branch with an unreachable, just to be very clear about what assumptions are being made in each case.
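A sketch of that structure; the exact per-tensor condition is illustrative:

    } else if (weightScaleShape.empty() ||
               (weightScaleShape.size() == 1 && weightScaleShape[0] == 1)) {
      // Per-tensor case: a scalar scale/zero-point pair is assumed.
      Value weightScaleScalar = extract(weightScale);
      weightZp = extract(weightZp);
      weight = makePerTensor(weight, weightScaleScalar, weightZp);
    } else {
      llvm_unreachable("unexpected weight scale shape in QLinearConv lowering");
    }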


- cTy = rewriter.getType<Torch::ValueTensorType>(
+ outputTy = rewriter.getType<Torch::ValueTensorType>(

Okay, this is a bit subtle. The last optional input for this op is the int32 bias, assumed to be quantized via the product of input and weight scales. This implies that the quantization of the bias (and also the output of the convolution) is also per-channel if the weight was per-channel quantized. This part is fine, but we will need to case out the logic below.
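For reference, a sketch of the bias quantization relation being described, with c indexing output channels in the per-channel case:

    // bias_scale[c] = input_scale * weight_scale[c]   (zero point is 0)
    // bias_int32[c] = round(bias_fp32[c] / bias_scale[c])
    // So if weight_scale is a rank-1 tensor over the output channels, the bias
    // and the output requantization are per-channel as well.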


  Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
-     binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
-     bScale);
+     binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,

This will possibly need to be a float x tensor mul in the per-channel case.
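A rough sketch of what the per-channel branch might look like; Torch::AtenMulScalarOp and the result type used here are assumptions, and isWeightPerChannel is the hypothetical flag from earlier:

    Value outScale;
    if (isWeightPerChannel) {
      // float x tensor: scale the rank-1 weight scale tensor by the scalar
      // input scale (result type assumed to match weightScale's type).
      outScale = rewriter.create<Torch::AtenMulScalarOp>(
          binder.getLoc(), weightScaleTy, weightScale, inputScale);
    } else {
      // float x float, as in the existing per-tensor path.
      outScale = rewriter.create<Torch::AtenMulFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
          weightScaleScalar);
    }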
