Segmentation Fault Running Int8 Quantized Model on GPU #1437

Open
wendywangwwt opened this issue Dec 18, 2024 · 1 comment

@wendywangwwt

Hi! We hit a segmentation fault when running inference with an int8 quantized model on GPU. Below is a minimal example based on the tutorial (link):

import torch
import time

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
input_fp32 = torch.randn(4, 1, 1024, 1024)

time_s = time.time()
with torch.no_grad():
    out = model_fp32(input_fp32)
time_e = time.time()

model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

model_fp32_prepared(input_fp32)

model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

model_int8 = model_int8.to('cuda:0')
input_fp32 = input_fp32.to('cuda:0')

with torch.no_grad():
    out = model_int8(input_fp32)

Output:

Segmentation fault (core dumped)

Inference on CPU is fine for the int8 model. Could someone please advise on the potential reason? Thank you!

@supriyar
Contributor

Hi! It looks like you're using the eager mode quantization flow from PyTorch core with the fbgemm backend, which currently only runs on x86 server CPUs, so the converted model is not expected to work after being moved to a CUDA device.

If you're interested in running on GPU, see the torchao usage instructions at https://github.com/pytorch/ao/tree/main/torchao/quantization#a8w8-int8-dynamic-quantization. Note that torchao doesn't support quantizing conv layers yet (only Linear is supported).
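
Since the fbgemm eager mode model only runs on CPU, one possible workaround is to leave the converted model on CPU and move only the tensors across devices. The snippet below is a minimal sketch of that idea, not an official API: the helper name `run_quantized_on_cpu` is hypothetical, and it assumes the model was never moved off the CPU.

```python
import torch


def run_quantized_on_cpu(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run a CPU-only model on `x` and return the result on `x`'s original device.

    Hypothetical helper: assumes `model` (e.g. an eager mode int8 model
    produced with the fbgemm qconfig) still lives on the CPU.
    """
    with torch.no_grad():
        # Copy the input to the CPU, where the quantized kernels are available.
        out = model(x.to("cpu"))
    # Hand the result back on whatever device the caller was using.
    return out.to(x.device)
```

With the example from the issue, this would be called as `out = run_quantized_on_cpu(model_int8, input_fp32)` without first calling `model_int8.to('cuda:0')`. The extra host/device copies cost time, so this only makes sense when the rest of the pipeline has to stay on the GPU.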
