
Commit

created file for requant tests
calmitchell617 committed Apr 12, 2024
1 parent 453a6a7 commit e82a862
Showing 2 changed files with 82 additions and 101 deletions.
101 changes: 0 additions & 101 deletions test/model/test_quantize_mlp.py
@@ -31,7 +31,6 @@
qint4,
qint8,
quantize,
requantize,
safe_load,
safe_save,
)
@@ -205,103 +204,3 @@ def test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device
def test_quantize_mlp_wrong_optimizer(weights, optimizer, device):
with pytest.raises(ValueError):
_test_quantize_mlp(weights, None, optimizer, False, device)


def _test_requantize_mlp(weights, activations, optimizer, frozen, device, serialization):
model = MLP(32, 10, 128).to(device)
output = get_outputs(model, 1, 32, device)
quantize(model, weights=weights, activations=activations, optimizer=optimizer)
state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
requantize(model, state_dict)
model.to(device)

with Calibration():
qoutput = get_outputs(model, 1, 32, device)
if activations is not None:
assert isinstance(qoutput, QTensor)
# Don't expect more than a 0.99 similarity
assert_similar(output, qoutput, atol=1e-2)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantize_mlp_weights_only(weights, frozen, device, serialization):
_test_requantize_mlp(weights, None, None, frozen, device, serialization)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
@pytest.mark.skip_device("mps")
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantize_mlp_int8_activations(weights, frozen, device, serialization):
_test_requantize_mlp(weights, qint8, None, frozen, device, serialization)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize(
"activations",
[None, qint8, qfloat8_e5m2, qfloat8_e4m3fn],
ids=["a-float", "a-qint8", "a-qfloat8-e5m2", "a-qfloat8-e4m3"],
)
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
@pytest.mark.skip_device("mps")
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantize_mlp_float8_activations(weights, activations, frozen, device, serialization):
_test_requantize_mlp(weights, activations, None, frozen, device, serialization)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_serialize_requantized_mlp(weights, dtype, serialization, device):
if dtype == torch.float16 and device.type == "cpu":
pytest.skip("Matrix multiplication is not supported for float16 on CPU")
input_features = 32
hidden_features = 10
output_features = 128
model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
quantize(model, weights=weights)
inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device)
with Calibration():
model(inputs)
freeze(model)
state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
model_reloaded = MLP(input_features, hidden_features, output_features)
requantize(model_reloaded, state_dict)
model_reloaded.to(device)
for name, module in model.named_modules():
if isinstance(module, QModuleMixin):
module_reloaded = getattr(model_reloaded, name)
assert torch.equal(module_reloaded.weight._data, module.weight._data)
assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
assert torch.equal(module_reloaded.input_scale, module.input_scale)
assert torch.equal(module_reloaded.output_scale, module.output_scale)


@pytest.mark.skip_device("cpu")
@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantized_mlp_device_memory(weights, dtype, weights_only, device, serialization):
# We might not start from a clean state
input_features = 1024
hidden_features = 2048
output_features = 1024
model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
full_precision_memory = get_device_memory(device)
quantize(model, weights=weights)
freeze(model)
quantized_memory = get_device_memory(device)
assert quantized_memory < full_precision_memory
state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
# Free device memory
del model
reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)
requantize(reloaded_model, state_dict)
# Free device memory
del state_dict
reloaded_model.to(device)
requantized_memory = get_device_memory(device)
assert requantized_memory <= quantized_memory
82 changes: 82 additions & 0 deletions test/model/test_requantize_mlp.py
@@ -0,0 +1,82 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from helpers import get_device_memory, random_tensor

from quanto import (
Calibration,
freeze,
qint8,
quantize,
requantize,
)
from quanto.nn import QModuleMixin
from test_quantize_mlp import MLP, save_and_reload_state_dict

@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_serialize_requantized_mlp(weights, dtype, serialization, device):
if dtype == torch.float16 and device.type == "cpu":
pytest.skip("Matrix multiplication is not supported for float16 on CPU")
input_features = 32
hidden_features = 10
output_features = 128
model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
quantize(model, weights=weights)
inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device)
with Calibration():
model(inputs)
freeze(model)
state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
model_reloaded = MLP(input_features, hidden_features, output_features)
requantize(model_reloaded, state_dict)
model_reloaded.to(device)
for name, module in model.named_modules():
if isinstance(module, QModuleMixin):
module_reloaded = getattr(model_reloaded, name)
assert torch.equal(module_reloaded.weight._data, module.weight._data)
assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
assert torch.equal(module_reloaded.input_scale, module.input_scale)
assert torch.equal(module_reloaded.output_scale, module.output_scale)


@pytest.mark.skip_device("cpu")
@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantized_mlp_device_memory(weights, dtype, weights_only, device, serialization):
# We might not start from a clean state
input_features = 1024
hidden_features = 2048
output_features = 1024
model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
full_precision_memory = get_device_memory(device)
quantize(model, weights=weights)
freeze(model)
quantized_memory = get_device_memory(device)
assert quantized_memory < full_precision_memory
state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
# Free device memory
del model
reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)
requantize(reloaded_model, state_dict)
# Free device memory
del state_dict
reloaded_model.to(device)
requantized_memory = get_device_memory(device)
assert requantized_memory <= quantized_memory

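For context, here is a minimal sketch of the quantize → calibrate → freeze → serialize → requantize round-trip that the new test file exercises. It is illustrative only and not part of this commit: the toy Sequential model, its layer sizes, and the make_model helper are placeholder assumptions, and the state dict is round-tripped through a plain in-memory torch.save/torch.load instead of the save_and_reload_state_dict test helper used above.

import io

import torch

from quanto import Calibration, freeze, qint8, quantize, requantize


def make_model():
    # Toy stand-in for the tests' MLP; the layer sizes are arbitrary.
    return torch.nn.Sequential(torch.nn.Linear(32, 10), torch.nn.Linear(10, 128))


model = make_model()
quantize(model, weights=qint8)  # swap supported layers for quantized modules
with Calibration():  # record activation scales on a sample batch
    model(torch.randn(1, 32))
freeze(model)  # materialize the quantized weights

# Round-trip the state dict through an in-memory buffer (pickle-style serialization).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
state_dict = torch.load(buffer)

# Rebuild a fresh float model and restore the quantized state into it.
reloaded = make_model()
requantize(reloaded, state_dict)

With the repository's test configuration available, the new file can be run on its own, e.g. pytest test/model/test_requantize_mlp.py.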