-
Notifications
You must be signed in to change notification settings - Fork 14
/
Export.py
61 lines (52 loc) · 2.08 KB
/
Export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import fire
import sys
import os
import json
from pathlib import Path
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../..")
import llama.modeling.Loader as Loader
from ModelParams import ModelParams
import ConvertParamsToOpmx
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    export_path: str,
    friendly_gqa: bool = False,  # done gqa by repeating key and value by key_value_cache op
    fused_qkv: bool = True,  # fuse qkv linear
    fused_kvcache: bool = True,  # fuse key_value_cache and multi_head_attention
    fused_ffn_glu: bool = True,  # fuse feed forward gate linear unit
    auto_causal: bool = True,  # causal mask is auto done by attention op, no need to pass additional mask to the model
    quantized_cache: bool = True,  # 8bit kv cache quantization
    cache_layout: int = 0,  # change kv cache layout for hardware performance friendly
    cache_mode: int = 0,  # change kv cache indexing mode for memory management friendly, only affected when dynamic_batching == True
    dynamic_batching: bool = True,  # use dynamic batching scheduling
) -> None:
    """Load a checkpoint with ``Loader.load`` and export it to ``export_path``.

    The model parameters are read from ``<ckpt_dir>/opmx_params.json``. When
    that file does not exist, ``ConvertParamsToOpmx.main`` is invoked first
    and is expected to create it.

    Args:
        ckpt_dir: Checkpoint directory; also where ``opmx_params.json`` is
            looked up.
        tokenizer_path: Only forwarded to ``ConvertParamsToOpmx.main`` when
            the parameter file has to be generated.
        export_path: Destination handed to ``generator.export``.
        friendly_gqa, fused_qkv, fused_kvcache, fused_ffn_glu, auto_causal,
        quantized_cache, cache_layout, cache_mode, dynamic_batching:
            Forwarded unchanged to ``Loader.load``; see the inline comments
            on the signature for the meaning of each switch.

    Raises:
        FileNotFoundError: If ``opmx_params.json`` is still missing after
            the automatic conversion step.
    """
    params_path = Path(ckpt_dir) / "opmx_params.json"
    if not os.path.exists(params_path):
        print("Info: opmx_params.json not found, do auto param conversion")
        ConvertParamsToOpmx.main(ckpt_dir, tokenizer_path)

    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # locale encoding.
    with open(params_path, "r", encoding="utf-8") as f:
        raw_params = json.load(f)
    params: ModelParams = ModelParams(**raw_params)

    # The options below that are not taken from the signature (alibi, rope,
    # bias terms, load_to_cpu, rotary_dim) are fixed by this script rather
    # than exposed on the command line.
    generator = Loader.load(
        ckpt_dir, params,
        friendly_gqa=friendly_gqa,
        fused_qkv=fused_qkv,
        fused_kvcache=fused_kvcache,
        fused_ffn_glu=fused_ffn_glu,
        fused_alibi=False,
        auto_causal=auto_causal,
        with_rope=True,
        with_alibi=False,
        quantized_cache=quantized_cache,
        cache_layout=cache_layout,
        cache_mode=cache_mode,
        dynamic_batching=dynamic_batching,
        attn_wqkv_bias_term=False,
        attn_wo_bias_term=False,
        ffn_linear_bias_term=False,
        load_to_cpu=True,
        rotary_dim=0,
    )

    generator.export(export_path)
# Script entry point: hand `main` to python-fire, which drives it from the
# command-line arguments.
if __name__ == "__main__":
    fire.Fire(main)