Feat: Added requantize function and tests #171

Merged · 5 commits · Apr 18, 2024
Changes from 3 commits
20 changes: 19 additions & 1 deletion quanto/quantize.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from torch import device as torch_device

from .nn import QModuleMixin, quantize_module


__all__ = ["quantize", "freeze"]
__all__ = ["quantize", "freeze", "requantize"]


def set_module_by_name(parent_module, name, child_module):
@@ -48,6 +50,22 @@ def quantize(model, modules=None, **kwargs):
del param


def requantize(model, state_dict):
    # Models dispatched with accelerate must not be moved between devices
    if hasattr(model, "hf_device_map"):
        raise ValueError(
            "Model is distributed with accelerate, cannot requantize. Please use an un-distributed model."
        )

    # Empty the model params by moving them to the meta device, then quantize
    model.to(torch_device("meta"))
    quantize(model)

    # Move the quantized (but empty) model to CPU, then load the state_dict
    model.to_empty(device=torch_device("cpu"))
    model.load_state_dict(state_dict)


def freeze(model):
    for name, m in model.named_modules():
        if isinstance(m, QModuleMixin):
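For context, here is a minimal usage sketch of the new `requantize` entry point. The model class, file name, and `qint8` weight setting are illustrative assumptions; the tests below exercise the same round trip through `weights_only`, `pickle`, and `safetensors` serialization.

```python
import torch

from quanto import freeze, qint8, quantize, requantize


# Hypothetical model; any plain nn.Module works the same way.
class TwoLayerNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(32, 64)
        self.fc2 = torch.nn.Linear(64, 32)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))


# First session: quantize, freeze, and save the quantized state_dict.
model = TwoLayerNet()
quantize(model, weights=qint8)
freeze(model)
torch.save(model.state_dict(), "quantized.pt")  # hypothetical path

# Later session: rebuild the float architecture and restore the quantized
# weights without ever materializing the original float parameters.
reloaded = TwoLayerNet()
requantize(reloaded, torch.load("quantized.pt"))
reloaded.to("cuda")  # optional: move to the target device afterwards
```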
86 changes: 86 additions & 0 deletions test/model/test_requantize_mlp.py
@@ -0,0 +1,86 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from helpers import get_device_memory, random_tensor
from test_quantize_mlp import MLP, save_and_reload_state_dict

from quanto import (
    Calibration,
    freeze,
    qint8,
    quantize,
    requantize,
)
from quanto.nn import QModuleMixin


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_serialize_requantized_mlp(weights, dtype, serialization, device):
    if dtype == torch.float16 and device.type == "cpu":
        pytest.skip("Matrix multiplication is not supported for float16 on CPU")
    input_features = 32
    hidden_features = 10
    output_features = 128
    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
    quantize(model, weights=weights)
    inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device)
    with Calibration():
        model(inputs)
    freeze(model)
    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
    model_reloaded = MLP(input_features, hidden_features, output_features)
    requantize(model_reloaded, state_dict)
    model_reloaded.to(device)
    for name, module in model.named_modules():
        if isinstance(module, QModuleMixin):
            module_reloaded = getattr(model_reloaded, name)
            assert module_reloaded.weight.qtype == module.weight.qtype
            assert module_reloaded.weight_qtype == module.weight_qtype
            assert module_reloaded.activation_qtype == module.activation_qtype
            assert torch.equal(module_reloaded.weight._data, module.weight._data)
            assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
            assert torch.equal(module_reloaded.input_scale, module.input_scale)
            assert torch.equal(module_reloaded.output_scale, module.output_scale)


@pytest.mark.skip_device("cpu")
@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantized_mlp_device_memory(weights, dtype, weights_only, device, serialization):
    # We might not start from a clean state
    input_features = 1024
    hidden_features = 2048
    output_features = 1024
    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
    full_precision_memory = get_device_memory(device)
    quantize(model, weights=weights)
    freeze(model)
    quantized_memory = get_device_memory(device)
    assert quantized_memory < full_precision_memory
    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
    # Free device memory
    del model
    reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)
    requantize(reloaded_model, state_dict)
    # Free device memory
    del state_dict
    reloaded_model.to(device)
    requantized_memory = get_device_memory(device)
    assert requantized_memory <= quantized_memory
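The `get_device_memory` helper used above comes from the test suite's shared `helpers` module and is not part of this diff. A rough sketch of what such a probe might look like (an assumption, not the actual implementation):

```python
import torch


def get_device_memory(device):
    # Hypothetical per-backend allocated-memory probe; the real helper
    # lives in the test suite's helpers module and may differ.
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated(device)
    if device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    # No reliable per-process accounting for plain CPU tensors,
    # which is consistent with the test skipping the cpu device.
    return None
```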