From 5174914781a557dcbc5ae7bbed702df95bd71200 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 26 Sep 2023 05:49:19 +0000 Subject: [PATCH] Do not change torch module types --- examples/llama/model_test.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index b07d5d3fc..a57e766a2 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -65,13 +65,13 @@ def run_ark( def run_pt( module: torch.nn.Module, - state_dict: Dict[str, np.ndarray], + state_dict: Dict[str, torch.Tensor], inputs: list = [], ) -> List[np.ndarray]: # Update the current state_dict with the given one cur_state_dict = module.state_dict() for k, v in state_dict.items(): - cur_state_dict[k] = torch.from_numpy(v) + cur_state_dict[k] = v module.load_state_dict(cur_state_dict) # Load input data to GPU @@ -110,35 +110,30 @@ def test_module( if os.path.exists(pth_path): prefix = module_name_prefix + "." if module_name_prefix else "" # Load the state_dict from the given path - state_dict = torch.load(pth_path) - state_dict = { - k[len(prefix) :]: v.float().numpy().astype(dtype) - for k, v in state_dict.items() + state_dict_pt = torch.load(pth_path) + state_dict_pt = { + k[len(prefix) :]: v + for k, v in state_dict_pt.items() if k[len(prefix) :] in param_names and k.startswith(prefix) } + state_dict_ark = { + k: v.float().numpy().astype(dtype) + for k, v in state_dict_pt.items() + } else: # Create a random state_dict - state_dict = { - k: np.random.uniform(low=-0.1, high=0.1, size=v.size()).astype( - dtype - ) - for k, v in module_ark.params_dict().items() - } + raise NotImplementedError # Run the ARK module output_ark = run_ark( - module_ark, state_dict, ark_inputs if ark_inputs else inputs + module_ark, state_dict_ark, ark_inputs if ark_inputs else inputs ) # PyTorch module module_pt: torch.nn.Module = module_class_pt(*module_args_pt) - # - for _, param in module_pt.named_parameters(): - param.data = param.data.to(numpy_dtype_to_torch_dtype[dtype]) - # Run the PyTorch module - output_pt = run_pt(module_pt, state_dict, inputs) + output_pt = run_pt(module_pt, state_dict_pt, inputs) # Compare the outputs eps = np.finfo(np.float64).eps