Add perf test
chhwang committed Oct 13, 2023
1 parent fdf14e2 commit 1768bcf
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions examples/llama/model_test.py
@@ -15,6 +15,7 @@
 import model as model_ark
 import numpy as np
 from typing import Dict, List
+from dataclasses import dataclass
 from model import ModelArgs, ModelArgs7B, ModelArgs13B, ModelArgs70B
 from generator import precompute_freqs_cis

@@ -27,11 +28,16 @@
     np.int32: torch.int32,
 }
 
+@dataclass
+class RunResults:
+    outputs: List[np.ndarray] = None
+    runtime: float = 0.0  # in seconds
 
 def run_ark(
     module: ark.Module,
     state_dict: Dict[str, np.ndarray],
     inputs: list = [],
+    iterations: int = 1,
     rank: int = 0,
     world_size: int = 1,
 ) -> List[np.ndarray]:
@@ -58,18 +64,26 @@ def run_ark(
     for tensor, ndarray in zip(tensors, tensor_data):
         tensor.from_numpy(ndarray)
 
+    start_time = time.time()
+
     # Run the model
-    runtime.run()
+    runtime.run(iter=iterations)
 
+    end_time = time.time()
+
     if isinstance(output, list) or isinstance(output, tuple):
-        return [o.to_numpy() for o in output]
-    return [output.to_numpy()]
+        outputs = [o.to_numpy() for o in output]
+    else:
+        outputs = [output.to_numpy()]
+
+    return RunResults(outputs=outputs, runtime=end_time - start_time)
 
 
 @torch.inference_mode()
 def run_pt(
     module: torch.nn.Module,
     state_dict: Dict[str, torch.Tensor],
     inputs: list = [],
+    iterations: int = 1,
 ) -> List[np.ndarray]:
     # Update the current state_dict with the given one
     cur_state_dict = module.state_dict()
@@ -86,13 +100,20 @@ def run_pt(
     # Load the module to GPU
     module = module.to("cuda:0")
 
+    start_time = time.time()
+
     # Run the module
     with torch.no_grad():
-        output = module(*input_tensors)
+        for _ in range(iterations):
+            output = module(*input_tensors)
 
+    end_time = time.time()
+
     if isinstance(output, list) or isinstance(output, tuple):
-        return [o.detach().to("cpu").numpy() for o in output]
-    return [output.detach().to("cpu").numpy()]
+        outputs = [o.detach().to("cpu").numpy() for o in output]
+    else:
+        outputs = [output.detach().to("cpu").numpy()]
+
+    return RunResults(outputs=outputs, runtime=end_time - start_time)
 
 
 def test_module(
@@ -103,7 +124,14 @@
     module_args_pt: list,
     inputs_pt: List[np.ndarray],
     module_name_prefix: str = "",
+    test_thru: bool = False,
+    test_thru_iterations: int = 100,
 ):
+    if test_thru:
+        print(f"Throughput test (iterations: {test_thru_iterations})")
+    else:
+        print(f"Correctness test")
+
     # ARK module
     module_ark: ark.Module = module_class_ark(*module_args_ark)
 
@@ -127,17 +155,23 @@ def test_module(
         raise ValueError(f"Cannot find the given path: {pth_path}")
 
     # Run the ARK module
-    output_ark = run_ark(module_ark, state_dict_ark, inputs_ark)
+    res_ark = run_ark(module_ark, state_dict_ark, inputs_ark,
+                      iterations=test_thru_iterations if test_thru else 1)
 
     # PyTorch module
     module_pt: torch.nn.Module = module_class_pt(*module_args_pt)
 
     # Run the PyTorch module
-    output_pt = run_pt(module_pt, state_dict_pt, inputs_pt)
+    res_pt = run_pt(module_pt, state_dict_pt, inputs_pt,
+                    iterations=test_thru_iterations if test_thru else 1)
+
+    if test_thru:
+        print(f" PyTorch: {res_pt.runtime:.4f} seconds, ARK: {res_ark.runtime:.4f} seconds")
+        return
 
     # Compare the outputs
     eps = np.finfo(np.float64).eps
-    for i, (o_ark, o_pt) in enumerate(zip(output_ark, output_pt)):
+    for i, (o_ark, o_pt) in enumerate(zip(res_ark.outputs, res_pt.outputs)):
         shape = o_ark.shape
         o_ark = o_ark.flatten().astype(np.float64)
         o_pt = o_pt.flatten().astype(np.float64)
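
Aside, not part of this commit: the seconds reported above are wall-clock measurements taken around the iteration loops in run_ark and run_pt. Since PyTorch launches CUDA kernels asynchronously, a variant of the run_pt timing that synchronizes the GPU before reading the clock could look like the sketch below (same variable names as in run_pt; the torch.cuda.synchronize() calls are an editorial addition, not committed code):

# Sketch only: bracket the timed loop with GPU synchronization so that
# end_time reflects kernel completion rather than just kernel launch.
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
    for _ in range(iterations):
        output = module(*input_tensors)
torch.cuda.synchronize()
end_time = time.time()
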
@@ -430,13 +464,13 @@ def test(args, batch_size, seq_len, dtype, world_size):
     args = ModelArgs7B()
     batch_size = 1
     seq_len = 1024
-    dtype = np.float16
+    dtype = np.float32
     world_size = 1
 
     # Default from HuggingFace
     args.vocab_size = 32000
 
-    # PyTorch model cannot run all layers due to OOM
+    # Reduce max_seq_len due to OOM from the PyTorch model
     args.max_seq_len = 1024
 
     # Verify the configurations
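
For illustration only, a hypothetical call into the new throughput mode. The keyword names match the test_module parameters visible in this diff, but every value below (module classes, arguments, inputs) is an assumed placeholder, not taken from this file:

# Hypothetical usage sketch; all values are placeholders.
x_np = np.random.rand(1, 1024, 4096).astype(np.float32)  # assumed input batch
test_module(
    module_class_ark=model_ark.RMSNorm,   # assumed ARK module under test
    module_args_ark=[4096],
    inputs_ark=[x_np],
    module_class_pt=ReferenceRMSNorm,     # assumed PyTorch reference module
    module_args_pt=[4096],
    inputs_pt=[x_np],
    test_thru=True,                       # exercise the throughput path added here
    test_thru_iterations=100,             # time 100 forward passes per framework
)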