
Commit: weights loading
chhwang committed Sep 18, 2023
1 parent 8532d8d commit bdb6cfb
Showing 2 changed files with 78 additions and 28 deletions.
62 changes: 50 additions & 12 deletions examples/llama/model_test.py
@@ -16,6 +16,8 @@
from model import ModelArgs, ModelArgs7B, ModelArgs13B, ModelArgs70B


pth_path: str = "/mnt/7B/consolidated.00.pth"

numpy_dtype_to_torch_dtype: dict = {
np.float16: torch.float16,
np.float32: torch.float32,
@@ -98,16 +100,28 @@ def test_module(
module_class_pt: torch.nn.Module,
module_args_pt: list,
ark_inputs: List[np.ndarray] = [], # used when ARK needs different inputs
module_name_prefix: str = "",
):
# ARK module
module_ark: ark.Module = module_class_ark(*module_args_ark)

# Create a random state_dict
state_dict = module_ark.state_dict()
state_dict = {
k: np.random.uniform(low=-0.1, high=0.1, size=v.shape).astype(dtype)
for k, v in state_dict.items()
}
param_names = set(module_ark.params_dict().keys())

if os.path.exists(pth_path):
prefix = module_name_prefix + "." if module_name_prefix else ""
# Load the state_dict from the given path
state_dict = torch.load(pth_path)
state_dict = {
k[len(prefix) :]: v.float().numpy().astype(dtype)
for k, v in state_dict.items()
if k in param_names and k.startswith(prefix)
}
else:
# Create a random state_dict
state_dict = {
k: np.random.uniform(low=-0.1, high=0.1, size=v.shape).astype(dtype)
for k, v in module_ark.params_dict().items()
}

# Run the ARK module
output_ark = run_ark(
@@ -167,9 +181,14 @@ def test_rmsnorm(
inputs,
dtype,
module_class_ark=model_ark.RMSNorm,
module_args_ark=[args.dim, 1e-6, ark.DataType.from_numpy(dtype)],
module_args_ark=[
args.dim,
args.norm_eps,
ark.DataType.from_numpy(dtype),
],
module_class_pt=model_pt.RMSNorm,
module_args_pt=[args.dim],
module_name_prefix="norm",
)


@@ -185,7 +204,9 @@ def test_row_parallel_linear(
# Create random input data
inputs = [
np.random.uniform(
low=-0.1, high=0.1, size=(batch_size, seq_len, args.dim)
low=-0.1,
high=0.1,
size=(batch_size, seq_len, args.dim // args.n_heads * args.n_heads),
).astype(dtype)
]

@@ -195,14 +216,20 @@
dtype,
module_class_ark=model_ark.RowParallelLinear,
module_args_ark=[
args.dim,
args.dim // args.n_heads * args.n_heads,
args.dim,
ark.DataType.from_numpy(dtype),
0,
1,
],
module_class_pt=fairscale.nn.model_parallel.RowParallelLinear,
module_args_pt=[args.dim, args.dim, False, lambda x: x],
module_args_pt=[
args.dim // args.n_heads * args.n_heads,
args.dim,
False,
lambda x: x,
],
module_name_prefix="layers.0.attention.wo",
)


@@ -229,13 +256,19 @@ def test_column_parallel_linear(
module_class_ark=model_ark.ColumnParallelLinear,
module_args_ark=[
args.dim,
args.dim,
args.dim // args.n_heads * args.n_heads,
ark.DataType.from_numpy(dtype),
0,
1,
],
module_class_pt=fairscale.nn.model_parallel.ColumnParallelLinear,
module_args_pt=[args.dim, args.dim, False, lambda x: x],
module_args_pt=[
args.dim,
args.dim // args.n_heads * args.n_heads,
False,
lambda x: x,
],
module_name_prefix="layers.0.attention.wq",
)


@@ -286,6 +319,7 @@ def test_attention(
module_class_pt=model_pt.Attention,
module_args_pt=[args],
ark_inputs=ark_inputs,
module_name_prefix="layers.0.attention",
)


@@ -326,6 +360,7 @@ def test_transformer_block(
module_class_pt=model_pt.TransformerBlock,
module_args_pt=[0, args],
ark_inputs=ark_inputs,
module_name_prefix="layers.0",
)


@@ -402,6 +437,9 @@ def test(args, batch_size, seq_len, dtype, world_size):
# Default from HuggingFace
args.vocab_size = 32000

# For debugging
# args.n_layers = 8

# Verify the configurations
assert batch_size <= args.max_batch_size
assert seq_len <= args.max_seq_len
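
In plain terms, the model_test.py change makes test_module prefer real LLaMA weights over random ones: if the checkpoint at pth_path (/mnt/7B/consolidated.00.pth) exists, it is loaded with torch.load, the entries belonging to the module under test are selected via the new module_name_prefix argument (e.g. "norm" or "layers.0.attention"), the prefix is stripped from each key, and the tensors are converted to NumPy arrays in the test dtype; only when no checkpoint is present does the test fall back to random parameters built from module_ark.params_dict(). Below is a minimal standalone sketch of that selection logic; make_test_state_dict, its param_shapes argument, and matching keys after stripping the prefix are illustrative assumptions rather than code from the commit.

import os
from typing import Dict, Tuple

import numpy as np
import torch


def make_test_state_dict(
    pth_path: str,
    param_shapes: Dict[str, Tuple[int, ...]],
    dtype=np.float16,
    module_name_prefix: str = "",
) -> Dict[str, np.ndarray]:
    """Hypothetical helper mirroring the weight selection in test_module."""
    prefix = module_name_prefix + "." if module_name_prefix else ""
    if os.path.exists(pth_path):
        # Checkpoint available: keep only this module's entries, strip the
        # module name prefix, and convert to NumPy in the requested dtype.
        checkpoint = torch.load(pth_path, map_location="cpu")
        return {
            k[len(prefix):]: v.float().numpy().astype(dtype)
            for k, v in checkpoint.items()
            if k.startswith(prefix) and k[len(prefix):] in param_shapes
        }
    # No checkpoint: fall back to small random weights of the right shapes.
    return {
        k: np.random.uniform(low=-0.1, high=0.1, size=shape).astype(dtype)
        for k, shape in param_shapes.items()
    }

For the RMSNorm test, for instance, this would be called with module_name_prefix="norm" and shapes taken from module_ark.params_dict(), matching the prefixes wired up in the individual test_* functions above.
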
44 changes: 28 additions & 16 deletions python/ark/module.py
@@ -48,31 +48,43 @@ def register_parameter(self, name: str, param: Tensor) -> None:
raise TypeError("param must be a Tensor")
self.parameters[name] = param

def load_state_dict(self, state_dict, prefix=""):
def params_dict(self, prefix=""):
params_dict = {}
for name, module in self.sub_modules.items():
if module is not None:
params_dict.update(
module.params_dict(prefix=prefix + name + ".")
)
for name, param in self.parameters.items():
params_dict[prefix + name] = param
return params_dict

def load_state_dict(
self, state_dict: Dict[str, np.ndarray], prefix: str = ""
):
"""
Loads a model from a state_dict and copies the parameters to the device GPU.
Must be called after the executor is launched.
"""
logging.info("Loading model from state_dict")
for name, module in self.sub_modules.items():
if module is not None:
module.load_state_dict(state_dict, prefix=prefix + name + ".")
for name, param in self.parameters.items():
param.from_numpy(state_dict[prefix + name])

def state_dict(self, prefix="") -> Dict[str, np.ndarray]:
all_keys = set(state_dict.keys())
pd = self.params_dict(prefix)
for name, param in pd.items():
param.from_numpy(state_dict[name])
all_keys.remove(name)
if all_keys:
logging.warning(
f"{len(all_keys)} unused parameter(s) in state_dict"
)

def state_dict(self, prefix: str = "") -> Dict[str, np.ndarray]:
"""
Copies the parameters from the device GPU to the host and saves the model to a state_dict.
Copies the parameters from the device GPU to the host and saves the
model to a state_dict.
Must be called after the executor is launched.
"""
state_dict = {}
for name, module in self.sub_modules.items():
if module is not None:
state_dict.update(module.state_dict(prefix=prefix + name + "."))
for name, param in self.parameters.items():
param_np = param.to_numpy()
state_dict[prefix + name] = param_np
return state_dict
return {k: v.to_numpy() for k, v in self.params_dict(prefix).items()}

def forward(self, *args: Any, **kwargs: Any) -> Any:
...
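
The module.py side introduces params_dict, which walks sub-modules recursively and joins parameter names with dots, so load_state_dict and state_dict now operate on a single flat mapping; checkpoint keys that match no registered parameter are counted and reported through logging.warning instead of failing silently. The following sketch reproduces just the prefix composition with a stand-in class and plain NumPy arrays; FakeModule is not part of ARK, it only illustrates the key layout that params_dict produces.

from typing import Dict

import numpy as np


class FakeModule:
    """Stand-in for ark.Module: same params_dict recursion, NumPy arrays as parameters."""

    def __init__(self):
        self.parameters: Dict[str, np.ndarray] = {}
        self.sub_modules: Dict[str, "FakeModule"] = {}

    def params_dict(self, prefix: str = "") -> Dict[str, np.ndarray]:
        params = {}
        # Sub-module parameters are collected under "<prefix><name>."
        for name, module in self.sub_modules.items():
            params.update(module.params_dict(prefix=prefix + name + "."))
        # This module's own parameters sit directly under the prefix.
        for name, param in self.parameters.items():
            params[prefix + name] = param
        return params


# Build a tiny module tree: block -> attention -> wq, as in the LLaMA tests.
wq = FakeModule()
wq.parameters["weight"] = np.zeros((8, 8), dtype=np.float16)
attention = FakeModule()
attention.sub_modules["wq"] = wq
block = FakeModule()
block.sub_modules["attention"] = attention

print(sorted(block.params_dict().keys()))
# ['attention.wq.weight'] -- the same dotted key a LLaMA checkpoint uses,
# which is what lets test_module select weights by module_name_prefix.

Loading then reduces to pairing these keys with the checkpoint: load_state_dict looks up each params_dict key in the given state_dict, copies the array onto the device with from_numpy, and warns about any leftover keys.
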
