Skip to content

Commit

Permalink
Do not change torch module types
Browse files: browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 26, 2023
1 parent 3fa1550 commit 5174914
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions examples/llama/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ def run_ark(

def run_pt(
module: torch.nn.Module,
state_dict: Dict[str, np.ndarray],
state_dict: Dict[str, torch.Tensor],
inputs: list = [],
) -> List[np.ndarray]:
# Update the current state_dict with the given one
cur_state_dict = module.state_dict()
for k, v in state_dict.items():
cur_state_dict[k] = torch.from_numpy(v)
cur_state_dict[k] = v
module.load_state_dict(cur_state_dict)

# Load input data to GPU
Expand Down Expand Up @@ -110,35 +110,30 @@ def test_module(
if os.path.exists(pth_path):
prefix = module_name_prefix + "." if module_name_prefix else ""
# Load the state_dict from the given path
state_dict = torch.load(pth_path)
state_dict = {
k[len(prefix) :]: v.float().numpy().astype(dtype)
for k, v in state_dict.items()
state_dict_pt = torch.load(pth_path)
state_dict_pt = {
k[len(prefix) :]: v
for k, v in state_dict_pt.items()
if k[len(prefix) :] in param_names and k.startswith(prefix)
}
state_dict_ark = {
k: v.float().numpy().astype(dtype)
for k, v in state_dict_pt.items()
}
else:
# Create a random state_dict
state_dict = {
k: np.random.uniform(low=-0.1, high=0.1, size=v.size()).astype(
dtype
)
for k, v in module_ark.params_dict().items()
}
raise NotImplementedError

# Run the ARK module
output_ark = run_ark(
module_ark, state_dict, ark_inputs if ark_inputs else inputs
module_ark, state_dict_ark, ark_inputs if ark_inputs else inputs
)

# PyTorch module
module_pt: torch.nn.Module = module_class_pt(*module_args_pt)

#
for _, param in module_pt.named_parameters():
param.data = param.data.to(numpy_dtype_to_torch_dtype[dtype])

# Run the PyTorch module
output_pt = run_pt(module_pt, state_dict, inputs)
output_pt = run_pt(module_pt, state_dict_pt, inputs)

# Compare the outputs
eps = np.finfo(np.float64).eps
Expand Down

0 comments on commit 5174914

Please sign in to comment.