add sharding for gpt3 (#1064)
* add sharding for gpt-3

* del debug

* add sharding save model

* update model save

* fix seed func

* set control in tensor parallel

Co-authored-by: Zhong Hui <[email protected]>
zhaoyinglia and ZHUI authored Oct 11, 2021
1 parent e05aed8 commit 91d81c9
Showing 3 changed files with 62 additions and 32 deletions.
examples/language_model/gpt-3/dygraph/args.py (5 changes: 3 additions & 2 deletions)
@@ -26,9 +26,10 @@ def process_batch_size(args):
"global_batch_size[{}] should be divided by local_batch_size[{}] when dp_degree is [{}]"\
.format(args.global_batch_size, args.local_batch_size, args.dp_degree)
elif args.global_batch_size is not None and args.local_batch_size is None:
- args.local_batch_size = args.global_batch_size // args.dp_degree
+ args.local_batch_size = args.global_batch_size // (args.dp_degree *
+     args.sharding_degree)
else:
- args.global_batch_size = args.local_batch_size * args.dp_degree
+ args.global_batch_size = args.local_batch_size * args.dp_degree * args.sharding_degree
assert args.local_batch_size % args.micro_batch_size == 0


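As a reading aid for the hunk above: with sharding enabled, the global batch is divided across dp_degree * sharding_degree data streams rather than dp_degree alone. A minimal sketch of the same arithmetic, using illustrative values that are not taken from the commit:

# Minimal sketch of the batch-size arithmetic in process_batch_size().
# The degree and batch values below are illustrative assumptions.
dp_degree, sharding_degree, micro_batch_size = 2, 2, 4
global_batch_size = 64

# With sharding, the global batch is split across dp_degree * sharding_degree
# data-parallel workers, not just dp_degree.
local_batch_size = global_batch_size // (dp_degree * sharding_degree)  # 16

# Each local batch must still be an integer number of micro batches
# (these become the pipeline accumulate_steps).
assert local_batch_size % micro_batch_size == 0
accumulate_steps = local_batch_size // micro_batch_size  # 4
print(local_batch_size, accumulate_steps)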
examples/language_model/gpt-3/dygraph/run.sh (1 change: 1 addition & 0 deletions)
@@ -20,6 +20,7 @@ python -m paddle.distributed.launch --log_dir $log_dir --gpus "0,1,2,3,4,5,6,7"
--dp_degree 2\
--mp_degree 2\
--pp_degree 2\
+ --sharding_degree 1\
--use_amp True\
--use_recompute False
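
The product of the four parallel degrees has to match the number of launched GPUs, here 2 * 2 * 2 * 1 = 8 for --gpus "0,1,2,3,4,5,6,7". A tiny sanity-check sketch; the helper function is illustrative, not part of the repo:

# Illustrative check: total ranks = dp * mp * pp * sharding.
def required_world_size(dp_degree, mp_degree, pp_degree, sharding_degree):
    return dp_degree * mp_degree * pp_degree * sharding_degree

assert required_world_size(2, 2, 2, 1) == 8  # matches the 8-GPU launch in run.sh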

examples/language_model/gpt-3/dygraph/run_pretrain.py (88 changes: 58 additions & 30 deletions)
@@ -30,23 +30,24 @@
import lr
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+ from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer

MODEL_CLASSES = {
"gpt": (GPTForPretraining, GPTTokenizer),
"gpt-cn": (GPTForPretraining, GPTChineseTokenizer),
}


- def set_hyrbid_parallel_seed(basic_seed, dp_rank, mp_rank, pp_rank):
+ def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank):
assert args.device != "cpu"

- random.seed(basic_seed + dp_rank)
- np.random.seed(basic_seed + dp_rank)
- paddle.seed(basic_seed + dp_rank)
+ random.seed(basic_seed + data_world_rank)
+ np.random.seed(basic_seed + data_world_rank)
+ paddle.seed(basic_seed + data_world_rank)

# local_seed/ global_seed is used to control dropout in ModelParallel
local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
- global_seed = basic_seed + dp_rank
+ global_seed = basic_seed + data_world_rank
tracker = get_rng_state_tracker()
tracker.add('global_seed', global_seed)
tracker.add('local_seed', local_seed)
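
The intent of the seeding scheme above: ranks that consume different data (different data_world_rank) get different global seeds, while local_seed varies with the mp/pp coordinates so that dropout inside model-parallel regions differs per shard. A standalone sketch of that mapping in plain Python, not the Paddle RNG tracker itself:

# Sketch of the seed layout used by set_hyrbid_parallel_seed (illustrative only).
def seeds_for_rank(basic_seed, data_world_rank, mp_rank, pp_rank):
    # Same data_world_rank -> same global seed, so replicas of the same
    # model shard draw identical "global" dropout masks.
    global_seed = basic_seed + data_world_rank
    # local_seed differs per tensor-/pipeline-parallel coordinate, so each
    # model shard gets its own "local" dropout stream.
    local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
    return global_seed, local_seed

# Two mp ranks holding halves of the same layer share the global seed
# but not the local one:
assert seeds_for_rank(1234, 0, 0, 0)[0] == seeds_for_rank(1234, 0, 1, 0)[0]
assert seeds_for_rank(1234, 0, 0, 0)[1] != seeds_for_rank(1234, 0, 1, 0)[1]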
@@ -92,14 +93,18 @@ def do_train(args):
strategy.hybrid_configs = {
"dp_degree": args.dp_degree,
"mp_degree": args.mp_degree,
"pp_degree": args.pp_degree
"pp_degree": args.pp_degree,
"sharding_degree": args.sharding_degree
}

strategy.pipeline_configs = {
"accumulate_steps": args.local_batch_size // args.micro_batch_size,
"micro_batch_size": args.micro_batch_size
}

+ # set control in tensor parallel
+ strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

fleet.init(is_collective=True, strategy=strategy)

# obtain rank message of hybrid parallel
@@ -108,10 +113,15 @@
mp_rank = hcg.get_model_parallel_rank()
pp_rank = hcg.get_stage_id()
dp_rank = hcg.get_data_parallel_rank()
+ sharding_rank = hcg.get_sharding_parallel_rank()

+ sharding_size = hcg.get_sharding_parallel_world_size()
+ data_world_rank = dp_rank * sharding_size + sharding_rank
+ data_world_size = args.dp_degree * args.sharding_degree
local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

# seed control in hybrid parallel
- set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)
+ set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

default_global_tokens_num = args.global_batch_size * args.max_seq_len
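
data_world_rank above flattens the (dp_rank, sharding_rank) pair into a single index over all workers that consume distinct data. A short sketch enumerating that grid; the degree values are illustrative:

# Illustrative enumeration of the dp x sharding grid behind data_world_rank.
dp_degree, sharding_degree = 2, 2
data_world_size = dp_degree * sharding_degree  # 4 distinct data streams

ranks = {}
for dp_rank in range(dp_degree):
    for sharding_rank in range(sharding_degree):
        # Same formula as in do_train(): row-major over (dp, sharding).
        data_world_rank = dp_rank * sharding_degree + sharding_rank
        ranks[(dp_rank, sharding_rank)] = data_world_rank

assert sorted(ranks.values()) == list(range(data_world_size))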

@@ -183,15 +193,31 @@ def do_train(args):
if not any(nd in n for nd in ["bias", "norm"])
]

- optimizer = paddle.optimizer.AdamW(
-     learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
-     beta1=args.adam_beta1,
-     beta2=args.adam_beta2,
-     epsilon=args.adam_epsilon,
-     parameters=model.parameters(),
-     weight_decay=args.weight_decay,
-     grad_clip=clip,
-     apply_decay_param_fun=lambda x: x in decay_params)
+ if args.sharding_degree > 1:
+     optimizer = DygraphShardingOptimizer(
+         hcg=fleet.get_hybrid_communicate_group(),
+         user_defined_strategy=strategy,
+         params=model.parameters(),
+         inner_optimizer_class=paddle.optimizer.AdamW,
+         learning_rate=lr_scheduler
+         if lr_scheduler is not None else args.max_lr,
+         beta1=args.adam_beta1,
+         beta2=args.adam_beta2,
+         epsilon=args.adam_epsilon,
+         weight_decay=args.weight_decay,
+         grad_clip=clip,
+         apply_decay_param_fun=lambda x: x in decay_params)
+ else:
+     optimizer = paddle.optimizer.AdamW(
+         learning_rate=lr_scheduler
+         if lr_scheduler is not None else args.max_lr,
+         beta1=args.adam_beta1,
+         beta2=args.adam_beta2,
+         epsilon=args.adam_epsilon,
+         parameters=model.parameters(),
+         weight_decay=args.weight_decay,
+         grad_clip=clip,
+         apply_decay_param_fun=lambda x: x in decay_params)

if paddle.distributed.get_world_size() > 1:
model = fleet.distributed_model(model)
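
When sharding_degree > 1, DygraphShardingOptimizer is used so that optimizer state is partitioned across the sharding group instead of being replicated on every rank. The snippet below is only a conceptual sketch of that partitioning idea in plain Python; the parameter names and round-robin assignment are assumptions for illustration, not Paddle's actual implementation:

# Conceptual sketch of optimizer-state sharding (illustrative only).
def partition_params(param_names, sharding_degree):
    buckets = [[] for _ in range(sharding_degree)]
    for i, name in enumerate(param_names):
        # Each sharding rank keeps optimizer moments only for its own bucket,
        # then the updated parameters are synchronized across the group.
        buckets[i % sharding_degree].append(name)
    return buckets

print(partition_params(["embed", "q_proj", "k_proj", "v_proj"], 2))
# [['embed', 'k_proj'], ['q_proj', 'v_proj']]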
@@ -227,8 +253,8 @@ def do_train(args):
args,
data_file,
local_rank=local_rank,
- data_world_size=args.dp_degree,
- data_world_rank=dp_rank,
+ data_world_size=data_world_size,
+ data_world_rank=data_world_rank,
eos_id=tokenizer.eos_token_id)
# Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
# many times. and start a new random dataloader.
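
Passing data_world_size and data_world_rank into dataset creation means each dp x sharding worker reads a disjoint slice of the corpus. A minimal sketch of that kind of strided split; the helper is illustrative, not the repo's create_pretrained_dataset:

# Illustrative strided split of sample indices across data-parallel workers.
def shard_indices(num_samples, data_world_rank, data_world_size):
    return list(range(data_world_rank, num_samples, data_world_size))

# With 4 data streams (dp=2 * sharding=2), worker 1 sees samples 1, 5, 9, ...
print(shard_indices(12, 1, 4))  # [1, 5, 9]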
@@ -309,6 +335,7 @@ def do_train(args):
args.eval_iters, log_writer, global_step,
epoch, "valid")

+ # TODO: 1. merge paramters while saving model. 2. ensure that the model is saved and loaded correctly
# only dp_rank = 0 save model
if (global_step % args.save_steps == 0 or
global_step >= args.max_steps) and dp_rank == 0:
@@ -322,24 +349,25 @@ def do_train(args):
logger.info("Save model to %s" % output_dir)

if args.pp_degree > 1:
- model_to_save.save_state_dict(output_dir)
- if mp_rank * pp_rank == 1:
+ if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
tokenizer.save_pretrained(output_dir)
+ model_to_save.save_state_dict(output_dir)
paddle.save(
optimizer.state_dict(),
os.path.join(
output_dir,
"model_state_mp_{:0>2d}_pp_{:0>2d}.pdopt".
format(mp_rank, pp_rank)))
"model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
format(mp_rank, sharding_rank, pp_rank)))
else:
- path = os.path.join(output_dir,
-     'model_{:0>2d}'.format(mp_rank))
- os.makedirs(path, exist_ok=True)
- model_to_save.save_pretrained(path)
-
- paddle.save(optimizer.state_dict(),
-     os.path.join(path, "model_state.pdopt"))
- tokenizer.save_pretrained(path)
+ if mp_rank == 0 and sharding_rank == 0:
+     tokenizer.save_pretrained(output_dir)
+ model_to_save.save_pretrained(output_dir)
+ paddle.save(
+     optimizer.state_dict(),
+     os.path.join(
+         output_dir,
+         "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
+         format(mp_rank, sharding_rank)))

if global_step >= args.max_steps:
run_evaluate(args, test_data_loader, model, criterion,
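
With sharding in the picture, each (mp, sharding, pp) coordinate writes its own optimizer-state file, which is what the filename templates in the save block above encode. A small sketch of the resulting layout; the degrees and the output_dir value are made-up examples:

# Illustrative listing of per-rank optimizer checkpoints after this change.
import os

output_dir = "output/model_1000"  # hypothetical save directory
mp_degree, sharding_degree, pp_degree = 2, 2, 2

files = []
for mp in range(mp_degree):
    for sh in range(sharding_degree):
        for pp in range(pp_degree):
            files.append(os.path.join(
                output_dir,
                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt"
                .format(mp, sh, pp)))

print(len(files))  # 8 optimizer shards, one per (mp, sharding, pp) rank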
