Feat: Added requantize function and tests #171

Merged: 5 commits, Apr 18, 2024
Changes from 1 commit
20 changes: 19 additions & 1 deletion quanto/quantize.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from torch import device as torch_device

from .nn import QModuleMixin, quantize_module


__all__ = ["quantize", "freeze"]
__all__ = ["quantize", "freeze", "requantize"]


def set_module_by_name(parent_module, name, child_module):
@@ -48,6 +50,22 @@ def quantize(model, modules=None, **kwargs):
                del param


def requantize(model, state_dict):
    # you shouldn't move models that were distributed with accelerate
    if hasattr(model, "hf_device_map"):
        raise ValueError(
            "Model is distributed with accelerate, cannot requantize. Please use an un-distributed model."
        )

    # empty the model params by moving to the meta device, then quantize
    model.to(torch_device("meta"))
    quantize(model)

    # move the quantized but empty model to cpu then load the state_dict
    model.to_empty(device=torch_device("cpu"))
    model.load_state_dict(state_dict)


def freeze(model):
    for name, m in model.named_modules():
        if isinstance(m, QModuleMixin):
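For context, the flow this enables is: quantize and freeze a model once, persist its state dict, then rebuild a fresh instance and restore it with requantize without re-running quantization from scratch on materialized weights. Below is a minimal sketch of that flow, not code from this PR; MyModel, the checkpoint path, and the target device are placeholders, and the quanto import path mirrors the test imports further down.

import torch
from quanto import freeze, qint8, quantize, requantize

# First pass: quantize the float model, freeze the quantized weights, save the state dict.
model = MyModel()  # hypothetical nn.Module standing in for a real model
quantize(model, weights=qint8)
freeze(model)
torch.save(model.state_dict(), "quantized.pt")  # placeholder path

# Later (possibly in another process): rebuild the same architecture and restore it.
# requantize() quantizes the empty model on the meta device, materializes it on CPU,
# then loads the saved state dict, so move it to the target device afterwards.
model_reloaded = MyModel()
requantize(model_reloaded, torch.load("quantized.pt"))
model_reloaded.to("cuda")  # placeholder target device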
101 changes: 101 additions & 0 deletions test/model/test_quantize_mlp.py
@@ -31,6 +31,7 @@
    qint4,
    qint8,
    quantize,
    requantize,
    safe_load,
    safe_save,
)
@@ -204,3 +205,103 @@ def test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device
def test_quantize_mlp_wrong_optimizer(weights, optimizer, device):
    with pytest.raises(ValueError):
        _test_quantize_mlp(weights, None, optimizer, False, device)


def _test_requantize_mlp(weights, activations, optimizer, frozen, device, serialization):
    model = MLP(32, 10, 128).to(device)
    output = get_outputs(model, 1, 32, device)
    quantize(model, weights=weights, activations=activations, optimizer=optimizer)
    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
    requantize(model, state_dict)
    model.to(device)

    with Calibration():
        qoutput = get_outputs(model, 1, 32, device)
    if activations is not None:
        assert isinstance(qoutput, QTensor)
    # Don't expect more than a 0.99 similarity
    assert_similar(output, qoutput, atol=1e-2)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantize_mlp_weights_only(weights, frozen, device, serialization):
    _test_requantize_mlp(weights, None, None, frozen, device, serialization)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
@pytest.mark.skip_device("mps")
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantize_mlp_int8_activations(weights, frozen, device, serialization):
    _test_requantize_mlp(weights, qint8, None, frozen, device, serialization)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize(
    "activations",
    [None, qint8, qfloat8_e5m2, qfloat8_e4m3fn],
    ids=["a-float", "a-qint8", "a-qfloat8-e5m2", "a-qfloat8-e4m3"],
)
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
@pytest.mark.skip_device("mps")
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantize_mlp_float8_activations(weights, activations, frozen, device, serialization):
    _test_requantize_mlp(weights, activations, None, frozen, device, serialization)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_serialize_requantized_mlp(weights, dtype, serialization, device):
    if dtype == torch.float16 and device.type == "cpu":
        pytest.skip("Matrix multiplication is not supported for float16 on CPU")
    input_features = 32
    hidden_features = 10
    output_features = 128
    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
    quantize(model, weights=weights)
    inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device)
    with Calibration():
        model(inputs)
    freeze(model)
    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
    model_reloaded = MLP(input_features, hidden_features, output_features)
    requantize(model_reloaded, state_dict)
    model_reloaded.to(device)
    for name, module in model.named_modules():
        if isinstance(module, QModuleMixin):
            module_reloaded = getattr(model_reloaded, name)
            assert torch.equal(module_reloaded.weight._data, module.weight._data)
            assert torch.equal(module_reloaded.weight._scale, module.weight._scale)
            assert torch.equal(module_reloaded.input_scale, module.input_scale)
            assert torch.equal(module_reloaded.output_scale, module.output_scale)


@pytest.mark.skip_device("cpu")
@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"])
@pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"])
def test_requantized_mlp_device_memory(weights, dtype, weights_only, device, serialization):
    # We might not start from a clean state
    input_features = 1024
    hidden_features = 2048
    output_features = 1024
    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
    full_precision_memory = get_device_memory(device)
    quantize(model, weights=weights)
    freeze(model)
    quantized_memory = get_device_memory(device)
    assert quantized_memory < full_precision_memory
    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
    # Free device memory
    del model
    reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)
    requantize(reloaded_model, state_dict)
    # Free device memory
    del state_dict
    reloaded_model.to(device)
    requantized_memory = get_device_memory(device)
    assert requantized_memory <= quantized_memory
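The serialization parameter above round-trips the state dict three ways: a plain pickle torch.save/torch.load, a weights-only load, and safetensors. The helper save_and_reload_state_dict is defined elsewhere in the test suite and is not shown in this diff, so the sketch below is only an assumption of what each mode roughly maps to; the safe_save/safe_load signatures are likewise assumed from the test imports above, not confirmed by this PR.

import io

import torch
from quanto import safe_load, safe_save  # import path mirrors the test imports above

def roundtrip_state_dict(state_dict, serialization, path="state_dict.safetensors"):
    # Hypothetical stand-in for save_and_reload_state_dict (not shown in this diff).
    if serialization == "safetensors":
        # Assumed signatures: safe_save(state_dict, path) and safe_load(path).
        safe_save(state_dict, path)
        return safe_load(path)
    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    buffer.seek(0)
    # weights_only=True restricts unpickling to tensors and simple containers.
    return torch.load(buffer, weights_only=(serialization == "weights_only"))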