quantize.py
import argparse
from pathlib import Path

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import (
    SparseAutoModelForCausalLM,
    oneshot,
)
from transformers import AutoTokenizer

from config.config_utils import load_model_config
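
# Allocate new tensors on the GPU by default.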
torch.set_default_device("cuda")
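
# Command-line interface: a positional model alias and quantization scheme,
# plus optional output and config paths.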
parser = argparse.ArgumentParser()
parser.add_argument(
    "model_alias",
    type=str,
    help="The alias of the model to quantize.",
)
parser.add_argument(
    "scheme",
    type=str,
    choices=["W8A8", "W4A16"],
    help="The quantization scheme to use.",
)
parser.add_argument(
    "--models_dir",
    type=str,
    default="./quantized_models/",
    help="The directory to save the quantized model and tokenizer.",
)
parser.add_argument(
    "--config_path",
    type=str,
    default="./config/models.toml",
    help="The path to the model configuration file.",
)
args = parser.parse_args()
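
# Resolve the model alias to a Hugging Face model ID or local path via the
# TOML config; trust_remote_code is taken from the per-model kwargs.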
model_config = load_model_config(
    name=args.model_alias,
    config_path=args.config_path,
)
model_name_or_path = model_config.name_or_path
trust_remote_code = model_config.model_kwargs.get("trust_remote_code", False)

model = SparseAutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=trust_remote_code,
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=trust_remote_code,
)
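
# W8A8: symmetric INT8 weights with static per-channel scales, and INT8
# activations quantized dynamically per token at runtime.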
W8A8 = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
    ),
)
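
# W4A16: symmetric INT4 weights with per-channel scales; activations are left
# in 16-bit precision, so no activation arguments are needed.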
W4A16 = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
)

schemes = {
    "W8A8": W8A8,
    "W4A16": W4A16,
}
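
# Apply the selected scheme to every Linear layer except lm_head, which is
# commonly kept in full precision to protect output quality.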
recipe = QuantizationModifier(
    config_groups={
        "group_0": schemes[args.scheme],
    },
    ignore=["lm_head"],
)
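
# Name the output directory after the model (dots replaced with underscores)
# with the scheme as a suffix.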
save_path = Path(
    args.models_dir,
    model_name_or_path.split("/")[-1].replace(".", "_") + f"-{args.scheme}",
)
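
# Run one-shot post-training quantization. No calibration dataset is passed
# here: weight scales come directly from the weights, and the W8A8 activations
# are quantized dynamically at runtime.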
oneshot(
    model=model,
    recipe=recipe,
    output_dir=save_path,
)
tokenizer.save_pretrained(save_path)