diff --git a/test/model/test_quantize_mlp.py b/test/model/test_quantize_mlp.py
index 2276eb2f..cb51de96 100644
--- a/test/model/test_quantize_mlp.py
+++ b/test/model/test_quantize_mlp.py
@@ -31,7 +31,6 @@
     qint4,
     qint8,
     quantize,
-    requantize,
     safe_load,
     safe_save,
 )
@@ -205,103 +204,3 @@ def test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device
 def test_quantize_mlp_wrong_optimizer(weights, optimizer, device):
     with pytest.raises(ValueError):
         _test_quantize_mlp(weights, None, optimizer, False, device)
-
-
-def _test_requantize_mlp(weights, activations, optimizer, frozen, device, serialization):
-    model = MLP(32, 10, 128).to(device)
-    output = get_outputs(model, 1, 32, device)
-    quantize(model, weights=weights, activations=activations, optimizer=optimizer)
-    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
-    requantize(model, state_dict)
-    model.to(device)
-
-    with Calibration():
-        qoutput = get_outputs(model, 1, 32, device)
-    if activations is not None:
-        assert isinstance(qoutput, QTensor)
-    # Don't expect more than a 0.99 similarity
-    assert_similar(output, qoutput, atol=1e-2)
-
-
-@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
-@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
-@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
-def test_requantize_mlp_weights_only(weights, frozen, device, serialization):
-    _test_requantize_mlp(weights, None, None, frozen, device, serialization)
-
-
-@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
-@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
-@pytest.mark.skip_device("mps")
-@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
-def test_requantize_mlp_int8_activations(weights, frozen, device, serialization):
-    _test_requantize_mlp(weights, qint8, None, frozen, device, serialization)
-
-
-@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
-@pytest.mark.parametrize(
-    "activations",
-    [None, qint8, qfloat8_e5m2, qfloat8_e4m3fn],
-    ids=["a-float", "a-qint8", "a-qfloat8-e5m2", "a-qfloat8-e4m3"],
-)
-@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
-@pytest.mark.skip_device("mps")
-@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
-def test_requantize_mlp_float8_activations(weights, activations, frozen, device, serialization):
-    _test_requantize_mlp(weights, activations, None, frozen, device, serialization)
-
-
-@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
-@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
-def test_serialize_requantized_mlp(weights, dtype, serialization, device):
-    if dtype == torch.float16 and device.type == "cpu":
-        pytest.skip("Matrix multiplication is not supported for float16 on CPU")
-    input_features = 32
-    hidden_features = 10
-    output_features = 128
-    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
-    quantize(model, weights=weights)
-    inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device)
-    with Calibration():
-        model(inputs)
-    freeze(model)
-    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
-    model_reloaded = MLP(input_features, hidden_features, output_features)
-    requantize(model_reloaded, state_dict)
-    model_reloaded.to(device)
-    for name, module in model.named_modules():
-        if isinstance(module, QModuleMixin):
-            module_reloaded = getattr(model_reloaded, name)
-            assert torch.equal(module_reloaded.weight._data, module.weight._data)
-            assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
-            assert torch.equal(module_reloaded.input_scale, module.input_scale)
-            assert torch.equal(module_reloaded.output_scale, module.output_scale)
-
-
-@pytest.mark.skip_device("cpu")
-@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
-@pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"])
-@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
-def test_requantized_mlp_device_memory(weights, dtype, weights_only, device, serialization):
-    # We might not start from a clean state
-    input_features = 1024
-    hidden_features = 2048
-    output_features = 1024
-    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
-    full_precision_memory = get_device_memory(device)
-    quantize(model, weights=weights)
-    freeze(model)
-    quantized_memory = get_device_memory(device)
-    assert quantized_memory < full_precision_memory
-    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
-    # Free device memory
-    del model
-    reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)
-    requantize(reloaded_model, state_dict)
-    # Free device memory
-    del state_dict
-    reloaded_model.to(device)
-    requantized_memory = get_device_memory(device)
-    assert requantized_memory <= quantized_memory
diff --git a/test/model/test_requantize_mlp.py b/test/model/test_requantize_mlp.py
new file mode 100644
index 00000000..31848762
--- /dev/null
+++ b/test/model/test_requantize_mlp.py
@@ -0,0 +1,82 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from helpers import get_device_memory, random_tensor
+
+from quanto import (
+    Calibration,
+    freeze,
+    qint8,
+    quantize,
+    requantize,
+)
+from quanto.nn import QModuleMixin
+from test_quantize_mlp import MLP, save_and_reload_state_dict
+
+@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
+@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
+def test_serialize_requantized_mlp(weights, dtype, serialization, device):
+    if dtype == torch.float16 and device.type == "cpu":
+        pytest.skip("Matrix multiplication is not supported for float16 on CPU")
+    input_features = 32
+    hidden_features = 10
+    output_features = 128
+    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
+    quantize(model, weights=weights)
+    inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device)
+    with Calibration():
+        model(inputs)
+    freeze(model)
+    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
+    model_reloaded = MLP(input_features, hidden_features, output_features)
+    requantize(model_reloaded, state_dict)
+    model_reloaded.to(device)
+    for name, module in model.named_modules():
+        if isinstance(module, QModuleMixin):
+            module_reloaded = getattr(model_reloaded, name)
+            assert torch.equal(module_reloaded.weight._data, module.weight._data)
+            assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
+            assert torch.equal(module_reloaded.input_scale, module.input_scale)
+            assert torch.equal(module_reloaded.output_scale, module.output_scale)
+
+
+@pytest.mark.skip_device("cpu")
+@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
+@pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"])
+@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
+def test_requantized_mlp_device_memory(weights, dtype, weights_only, device, serialization):
+    # We might not start from a clean state
+    input_features = 1024
+    hidden_features = 2048
+    output_features = 1024
+    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
+    full_precision_memory = get_device_memory(device)
+    quantize(model, weights=weights)
+    freeze(model)
+    quantized_memory = get_device_memory(device)
+    assert quantized_memory < full_precision_memory
+    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
+    # Free device memory
+    del model
+    reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)
+    requantize(reloaded_model, state_dict)
+    # Free device memory
+    del state_dict
+    reloaded_model.to(device)
+    requantized_memory = get_device_memory(device)
+    assert requantized_memory <= quantized_memory
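
For context, here is a minimal sketch (not part of the patch) of the save-then-requantize round-trip that the relocated tests exercise. It uses the quanto API that appears in the diff (quantize, freeze, requantize) and the MLP test model; the file path and the torch.save/torch.load persistence are illustrative stand-ins for the parametrized save_and_reload_state_dict helper.

import torch

from quanto import freeze, qint8, quantize, requantize
from test_quantize_mlp import MLP  # test model used throughout the diff

# Quantize the weights to int8 and freeze, so the state dict holds the
# quantized data and the associated scales.
model = MLP(32, 10, 128)
quantize(model, weights=qint8)
freeze(model)
torch.save(model.state_dict(), "mlp_qint8.pt")  # illustrative path

# Later: rebuild a float model, then restore the quantized modules and
# their state in a single step with requantize().
reloaded = MLP(32, 10, 128)
state_dict = torch.load("mlp_qint8.pt", weights_only=False)
requantize(reloaded, state_dict)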