review: removed accelerate check. also now moving model back to original device after requantizing.
calmitchell617 authored and dacorvo committed Apr 18, 2024
1 parent 47ff25d commit 544981d
Showing 3 changed files with 8 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ __pycache__
 *.egg-info
 dist
 .venv
+build/
10 changes: 5 additions & 5 deletions quanto/quantize.py
@@ -46,11 +46,8 @@ def quantize(model, modules=None, **kwargs):
 
 
 def requantize(model, state_dict):
-    # you shouldn't move models that were distributed with accelerate
-    if hasattr(model, "hf_device_map"):
-        raise ValueError(
-            "Model is distributed with accelerate, cannot requantize. Please use an un-distributed model."
-        )
+    # find device that model is on
+    device = next(model.parameters()).device
 
     # empty the model params by moving to the meta device, then quantize
     model.to(torch_device("meta"))
@@ -60,6 +57,9 @@ def requantize(model, state_dict):
     model.to_empty(device=torch_device("cpu"))
     model.load_state_dict(state_dict)
 
+    # move the model back to the original device
+    model.to(device)
+
 
 def freeze(model):
     for name, m in model.named_modules():
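The body of requantize relies on PyTorch's meta device to release the old parameter storage before the saved tensors are loaded back in. A minimal, quanto-independent sketch of that pattern, using a throwaway torch.nn.Linear (the module and shapes are illustrative, not from this repo):

    import torch

    model = torch.nn.Linear(8, 8)
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}

    # remember the device the model is on, as the new code above does
    device = next(model.parameters()).device

    # moving to "meta" frees the real storages; to_empty() then reallocates
    # uninitialized tensors on CPU for load_state_dict() to fill in
    model.to(torch.device("meta"))
    model.to_empty(device=torch.device("cpu"))
    model.load_state_dict(state_dict)

    # restore the original placement (the behavior this commit adds)
    model.to(device)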
6 changes: 2 additions & 4 deletions test/model/test_requantize_mlp.py
@@ -42,10 +42,9 @@ def test_serialize_requantized_mlp(weights, dtype, serialization, device):
     with Calibration():
         model(inputs)
     freeze(model)
+    model_reloaded = MLP(input_features, hidden_features, output_features).to(device)
     state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
-    model_reloaded = MLP(input_features, hidden_features, output_features)
     requantize(model_reloaded, state_dict)
-    model_reloaded.to(device)
     for name, module in model.named_modules():
         if isinstance(module, QModuleMixin):
             module_reloaded = getattr(model_reloaded, name)
@@ -77,10 +76,9 @@ def test_requantized_mlp_device_memory(weights, dtype, weights_only, device, serialization):
     state_dict = save_and_reload_state_dict(model.state_dict(), serialization)
     # Free device memory
     del model
-    reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)
+    reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)
     requantize(reloaded_model, state_dict)
     # Free device memory
     del state_dict
-    reloaded_model.to(device)
     requantized_memory = get_device_memory(device)
     assert requantized_memory <= quantized_memory
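For context, an end-to-end sketch of the round trip these tests exercise. The import path is assumed from this repo's layout (quanto/quantize.py), a plain torch.nn.Sequential stands in for the tests' MLP, and activation calibration is skipped for brevity:

    import torch
    from quanto.quantize import freeze, quantize, requantize

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def make_model():
        return torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.Linear(64, 32))

    # quantize and freeze a model, then grab its state dict
    model = make_model().to(device)
    quantize(model)
    freeze(model)
    state_dict = model.state_dict()

    # rebuild a fresh instance on the target device *before* requantizing,
    # as the updated tests do; requantize now returns it to that device
    model_reloaded = make_model().to(device)
    requantize(model_reloaded, state_dict)
    assert next(model_reloaded.parameters()).device.type == device.type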
