Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Oct 23, 2023
1 parent b70ca72 commit da7b36c
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 188 deletions.
10 changes: 2 additions & 8 deletions examples/llama/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@ def __init__(

self.runtime: ark.Runtime = None

def launch(
self,
pth_path: str,
tok_path: str,
):
def launch(self, pth_path: str, tok_path: str):
# Load a pretrained tokenizer
self.tokenizer = Tokenizer(model_path=tok_path)
self.args.vocab_size = self.tokenizer.n_words
Expand Down Expand Up @@ -163,9 +159,7 @@ def run(self, prompt: str):

gen.launch(args.pth_path, args.tok_path)

prompt_list = [
"Where is the capital of France?",
]
prompt_list = ["Where is the capital of France?"]
for i, prompt in enumerate(prompt_list):
output = gen.run(prompt)
print(f"---\nPrompt[{i}]: {prompt}\nOutput[{i}]: {output}")
3 changes: 1 addition & 2 deletions examples/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ def forward(self, x):
for i in range(len(output_shape)):
output_shape_bytes *= output_shape[i]
output_parallel_reshape = ark.reshape(
output_parallel,
[output_shape_bytes],
output_parallel, [output_shape_bytes]
)
output_reshape = ark.all_reduce(
output_parallel_reshape, self.local_rank, self.world_size
Expand Down
35 changes: 7 additions & 28 deletions python/ark/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,13 @@


_REGISTRY_DATA_TYPE = {
"fp32": {
"np": numpy.float32,
"doc": """32-bit floating point.""",
},
"fp16": {
"np": numpy.float16,
"doc": """16-bit floating point.""",
},
"bf16": {
"np": None,
"doc": """bfloat16 floating point.""",
},
"int32": {
"np": numpy.int32,
"doc": """32-bit signed integer.""",
},
"uint32": {
"np": numpy.uint32,
"doc": """32-bit unsigned integer.""",
},
"int8": {
"np": numpy.int8,
"doc": """8-bit signed integer.""",
},
"uint8": {
"np": numpy.uint8,
"doc": """8-bit unsigned integer.""",
},
"fp32": {"np": numpy.float32, "doc": """32-bit floating point."""},
"fp16": {"np": numpy.float16, "doc": """16-bit floating point."""},
"bf16": {"np": None, "doc": """bfloat16 floating point."""},
"int32": {"np": numpy.int32, "doc": """32-bit signed integer."""},
"uint32": {"np": numpy.uint32, "doc": """32-bit unsigned integer."""},
"int8": {"np": numpy.int8, "doc": """8-bit signed integer."""},
"uint8": {"np": numpy.uint8, "doc": """8-bit unsigned integer."""},
"byte": {
"np": numpy.ubyte,
"doc": """
Expand Down
Loading

0 comments on commit da7b36c

Please sign in to comment.