
Weights Still in FP32 after Quantization #347

Open

ClaraLovesFunk opened this issue Nov 6, 2024 · 6 comments
ClaraLovesFunk commented Nov 6, 2024

Dear quanto folks,

I implemented quantization as suggested in your coding example quantize_sst2_model.py. When printing the data types of the parameters, I found that all the weights remained in float32 after quantization. Do you have any explanation for this?

Also, do you have any explanation for why I can't use bigger batch sizes when quantizing both weights and activations? I used PubMedBERT from Hugging Face, fine-tuned it myself, and applied static quantization (see code below).

And do you know why inference slows down significantly when I use the reloaded statically quantized model (code below) as opposed to the directly statically quantized model? I again followed the instructions of the coding example.

Any help is greatly appreciated, since I'm just wrapping up my soon-due master's thesis on this <3
Clara

Direct static quantization:

# Assuming imports from optimum-quanto and transformers (not shown in the original snippet):
import torch
from transformers import AutoModelForSequenceClassification
from optimum.quanto import Calibration, freeze, qint8, quantize

weights = qint8
activations = qint8

model_quantized_static = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(label_mapping)).to(device)
quantize(model_quantized_static, weights=weights, activations=activations)
if activations is not None:
    # Static quantization: record activation ranges on calibration data
    print("Calibrating ...")
    with Calibration():
        evaluate_model(model_quantized_static, dataset_val, device, batch_size=64)

freeze(model_quantized_static)
# Check the data type of model parameters
for name, param in model_quantized_static.named_parameters():
    print(f"Parameter: {name}, Data Type: {param.dtype}")

Reloading the statically quantized model:

model_quantized_reloaded = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(label_mapping)).to(device)
quantize(model_quantized_reloaded, weights=weights, activations=activations)
# Note: torch.load defaults to weights_only=False here, hence the FutureWarning in the log below
state_dict = torch.load(model_quantized_path)
model_quantized_reloaded.load_state_dict(state_dict)
freeze(model_quantized_reloaded)
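
For reference, a minimal sketch to time the two variants side by side (assuming the evaluate_model helper, dataset_val, and both models defined above; this is just a measurement aid, not part of the quanto example):

import time

for label, m in [("direct", model_quantized_static), ("reloaded", model_quantized_reloaded)]:
    start = time.perf_counter()
    evaluate_model(m, dataset_val, device, batch_size=64)
    print(f"{label}: {time.perf_counter() - start:.2f} s")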
@ClaraLovesFunk
Author

I just tested your example file quantize_sst2_model.py and printed the parameters of the reloaded model; there, too, all the parameters are still in float32.

for name, param in model_reloaded.named_parameters():
    print(f"Parameter: {name}, Data Type: {param.dtype}")

Float model
872 sentences evaluated in 2.08 s. accuracy = 0.9105504587155964
Calibrating ...
872 sentences evaluated in 3.12 s. accuracy = 0.893348623853211
Quantized model (w: quanto.qint8, a: quanto.qint8)
872 sentences evaluated in 1.85 s. accuracy = 0.8979357798165137
:68: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(b)
Serialized quantized model
872 sentences evaluated in 1.98 s. accuracy = 0.8864678899082569
Parameter: distilbert.embeddings.word_embeddings.weight, Data Type: torch.float32
Parameter: distilbert.embeddings.position_embeddings.weight, Data Type: torch.float32
Parameter: distilbert.embeddings.LayerNorm.weight, Data Type: torch.float32
Parameter: distilbert.embeddings.LayerNorm.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.q_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.q_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.k_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.k_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.v_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.v_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.out_lin.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.attention.out_lin.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.sa_layer_norm.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.sa_layer_norm.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin1.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin1.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin2.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.ffn.lin2.bias, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.output_layer_norm.weight, Data Type: torch.float32
Parameter: distilbert.transformer.layer.0.output_layer_norm.bias, Data Type: torch.float32
... (layers 1 through 5 print identically: every weight and bias reports torch.float32) ...
Parameter: pre_classifier.weight, Data Type: torch.float32
Parameter: pre_classifier.bias, Data Type: torch.float32
Parameter: classifier.weight, Data Type: torch.float32
Parameter: classifier.bias, Data Type: torch.float32

@ClaraLovesFunk ClaraLovesFunk changed the title Static Quantization Not Implemented Static Quantization - Weights Still in FP32 Nov 6, 2024
@ClaraLovesFunk ClaraLovesFunk changed the title Static Quantization - Weights Still in FP32 Weights Still in FP32 after Quantization Nov 8, 2024
@dacorvo
Collaborator

dacorvo commented Nov 10, 2024

@ClaraLovesFunk thank you for your feedback. The parameters' dtype is still float32, but if you check their type, you will see that they are now QTensor subtypes instead of Tensor. QTensor subtypes preserve the external dtype, but their internal data is quantized. You can check the qtype property to verify that it is correct.
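
A minimal inspection sketch based on this comment (assuming the quantized model from the code above; qtype is the property named here, exposed on quanto's QTensor parameters, so the getattr fallback covers non-quantized ones like biases):

for name, param in model_quantized_static.named_parameters():
    # dtype stays float32; the class and qtype reveal the quantization
    qtype = getattr(param, "qtype", None)
    print(f"{name}: type={type(param).__name__}, dtype={param.dtype}, qtype={qtype}")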

@ClaraLovesFunk
Author

Thank you so much for the explanation, David! Will do.

@ClaraLovesFunk
Author

Do you maybe also have an explanation for why I can't use bigger batch sizes after applying quantization, even though I verified that my model shrank from 413.44 MB to 169.11 MB?

@LianShuaiLong
The parameters' dtype is still float32, but if you check their type, you will see that they are now QTensor subtypes instead of Tensor. QTensor subtypes preserve the external dtype, but their internal data is quantized. You can check the qtype property to verify that it is correct.

How can I get a param's dtype or qtype? param.qtype?

@ClaraLovesFunk
Author

ClaraLovesFunk commented Dec 9, 2024

Heyy Lian,

you can check the data type of model weights with:

for name, param in model.named_parameters():
    print(f"{name}: {param.dtype}")

I actually did not check the qtype myself, but per David's comment above, quanto's quantized parameters expose a qtype property, so inside the same loop (with a fallback for non-quantized params):

print(f"{name}: qtype = {getattr(param, 'qtype', None)}")

Cheers <3
Clara
