From bdb6cfb9b237f91e763b7e2394594236f18fffd6 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 18 Sep 2023 13:28:43 +0000 Subject: [PATCH] weights loading --- examples/llama/model_test.py | 62 +++++++++++++++++++++++++++++------- python/ark/module.py | 44 +++++++++++++++---------- 2 files changed, 78 insertions(+), 28 deletions(-) diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index 8c90a9d56..ff18eebb8 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -16,6 +16,8 @@ from model import ModelArgs, ModelArgs7B, ModelArgs13B, ModelArgs70B +pth_path: str = "/mnt/7B/consolidated.00.pth" + numpy_dtype_to_torch_dtype: dict = { np.float16: torch.float16, np.float32: torch.float32, @@ -98,16 +100,28 @@ def test_module( module_class_pt: torch.nn.Module, module_args_pt: list, ark_inputs: List[np.ndarray] = [], # used when ARK needs different inputs + module_name_prefix: str = "", ): # ARK module module_ark: ark.Module = module_class_ark(*module_args_ark) - # Create a random state_dict - state_dict = module_ark.state_dict() - state_dict = { - k: np.random.uniform(low=-0.1, high=0.1, size=v.shape).astype(dtype) - for k, v in state_dict.items() - } + param_names = set(module_ark.params_dict().keys()) + + if os.path.exists(pth_path): + prefix = module_name_prefix + "." if module_name_prefix else "" + # Load the state_dict from the given path + state_dict = torch.load(pth_path) + state_dict = { + k[len(prefix) :]: v.float().numpy().astype(dtype) + for k, v in state_dict.items() + if k in param_names and k.startswith(prefix) + } + else: + # Create a random state_dict + state_dict = { + k: np.random.uniform(low=-0.1, high=0.1, size=v.shape).astype(dtype) + for k, v in module_ark.params_dict().items() + } # Run the ARK module output_ark = run_ark( @@ -167,9 +181,14 @@ def test_rmsnorm( inputs, dtype, module_class_ark=model_ark.RMSNorm, - module_args_ark=[args.dim, 1e-6, ark.DataType.from_numpy(dtype)], + module_args_ark=[ + args.dim, + args.norm_eps, + ark.DataType.from_numpy(dtype), + ], module_class_pt=model_pt.RMSNorm, module_args_pt=[args.dim], + module_name_prefix="norm", ) @@ -185,7 +204,9 @@ def test_row_parallel_linear( # Create random input data inputs = [ np.random.uniform( - low=-0.1, high=0.1, size=(batch_size, seq_len, args.dim) + low=-0.1, + high=0.1, + size=(batch_size, seq_len, args.dim // args.n_heads * args.n_heads), ).astype(dtype) ] @@ -195,14 +216,20 @@ def test_row_parallel_linear( dtype, module_class_ark=model_ark.RowParallelLinear, module_args_ark=[ - args.dim, + args.dim // args.n_heads * args.n_heads, args.dim, ark.DataType.from_numpy(dtype), 0, 1, ], module_class_pt=fairscale.nn.model_parallel.RowParallelLinear, - module_args_pt=[args.dim, args.dim, False, lambda x: x], + module_args_pt=[ + args.dim // args.n_heads * args.n_heads, + args.dim, + False, + lambda x: x, + ], + module_name_prefix="layers.0.attention.wo", ) @@ -229,13 +256,19 @@ def test_column_parallel_linear( module_class_ark=model_ark.ColumnParallelLinear, module_args_ark=[ args.dim, - args.dim, + args.dim // args.n_heads * args.n_heads, ark.DataType.from_numpy(dtype), 0, 1, ], module_class_pt=fairscale.nn.model_parallel.ColumnParallelLinear, - module_args_pt=[args.dim, args.dim, False, lambda x: x], + module_args_pt=[ + args.dim, + args.dim // args.n_heads * args.n_heads, + False, + lambda x: x, + ], + module_name_prefix="layers.0.attention.wq", ) @@ -286,6 +319,7 @@ def test_attention( module_class_pt=model_pt.Attention, module_args_pt=[args], ark_inputs=ark_inputs, + module_name_prefix="layers.0.attention", ) @@ -326,6 +360,7 @@ def test_transformer_block( module_class_pt=model_pt.TransformerBlock, module_args_pt=[0, args], ark_inputs=ark_inputs, + module_name_prefix="layers.0", ) @@ -402,6 +437,9 @@ def test(args, batch_size, seq_len, dtype, world_size): # Default from HuggingFace args.vocab_size = 32000 + # For debugging + # args.n_layers = 8 + # Verify the configurations assert batch_size <= args.max_batch_size assert seq_len <= args.max_seq_len diff --git a/python/ark/module.py b/python/ark/module.py index 7e6a0a1ef..76635e8d3 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -48,31 +48,43 @@ def register_parameter(self, name: str, param: Tensor) -> None: raise TypeError("param must be a Tensor") self.parameters[name] = param - def load_state_dict(self, state_dict, prefix=""): + def params_dict(self, prefix=""): + params_dict = {} + for name, module in self.sub_modules.items(): + if module is not None: + params_dict.update( + module.params_dict(prefix=prefix + name + ".") + ) + for name, param in self.parameters.items(): + params_dict[prefix + name] = param + return params_dict + + def load_state_dict( + self, state_dict: Dict[str, np.ndarray], prefix: str = "" + ): """ Loads a model from a state_dict and copy the parameters to the device GPU. Must be called after the executor is launched. """ logging.info("Loading model from state_dict") - for name, module in self.sub_modules.items(): - if module is not None: - module.load_state_dict(state_dict, prefix=prefix + name + ".") - for name, param in self.parameters.items(): - param.from_numpy(state_dict[prefix + name]) - def state_dict(self, prefix="") -> Dict[str, np.ndarray]: + all_keys = set(state_dict.keys()) + pd = self.params_dict(prefix) + for name, param in pd.items(): + param.from_numpy(state_dict[name]) + all_keys.remove(name) + if all_keys: + logging.warning( + f"{len(all_keys)} unused parameter(s) in state_dict" + ) + + def state_dict(self, prefix: str = "") -> Dict[str, np.ndarray]: """ - Copies the parameters from the device GPU to the host and saves the model to a state_dict. + Copies the parameters from the device GPU to the host and saves the + model to a state_dict. Must be called after the executor is launched. """ - state_dict = {} - for name, module in self.sub_modules.items(): - if module is not None: - state_dict.update(module.state_dict(prefix=prefix + name + ".")) - for name, param in self.parameters.items(): - param_np = param.to_numpy() - state_dict[prefix + name] = param_np - return state_dict + return {k: v.to_numpy() for k, v in self.params_dict(prefix).items()} def forward(self, *args: Any, **kwargs: Any) -> Any: ...