[Document] Add hybrid 2D parallelism example
Showing 7 changed files with 226 additions and 107 deletions.
### Hybrid tensor parallelism and data parallelism training

Support for hybrid 3D parallelism with 🤗 `transformers` will be available in the upcoming weeks (the core implementation is done, but it doesn't support 🤗 `transformers` yet).

`nproc-per-node` must equal `tensor_parallel_size * pipeline_parallel_size * data_parallel_size`.

```bash
torchrun --standalone --nnodes=1 --nproc-per-node=4 hybrid_parallelism.py
```
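For this example the arithmetic works out as follows (constants copied from the script below; this snippet is just an illustrative check, not part of the example):

```python
# World size for this example: 2 tensor * 1 pipeline * 2 data = 4,
# which is why the torchrun command above passes --nproc-per-node=4.
TENSOR_PARALLEL_SIZE = 2
PIPELINE_PARALLEL_SIZE = 1
DATA_PARALLEL_SIZE = 2

nproc_per_node = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE
assert nproc_per_node == 4
```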
The example script (`hybrid_parallelism.py` in the command above):

```python
from datasets import load_dataset
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn import DataParallel, TensorParallel

if __name__ == "__main__":
    DATA_PARALLEL_SIZE = 2
    TENSOR_PARALLEL_SIZE = 2
    PIPELINE_PARALLEL_SIZE = 1
    BATCH_SIZE = 4

    # Build the 2D process grid (2 tensor x 2 data x 1 pipeline = 4 processes).
    parallel_context = ParallelContext.from_torch(
        data_parallel_size=DATA_PARALLEL_SIZE,
        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
    )
    rank = parallel_context.get_global_rank()

    dataset = load_dataset("imdb", split="train[:100]")
    dataset = dataset.map(lambda x: {"text": x["text"][:30]})  # truncate for a quick demo

    # Each data-parallel rank draws a disjoint shard of the dataset.
    dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)
    sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=69)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, shuffle=False, sampler=sampler)

    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    tokenizer.pad_token = tokenizer.eos_token

    # Shard the weights across tensor-parallel ranks, then wrap the result
    # so gradients are all-reduced across data-parallel ranks.
    model = TensorParallel(model, parallel_context).parallelize()
    model = DataParallel(model, parallel_context).parallelize()
    optim = SGD(model.parameters(), lr=1e-3)
    model.to("cuda")
    device = next(model.parameters()).device

    print(f"rank={rank}, moved to device: {device}")

    for epoch in range(100):
        sampler.set_epoch(epoch)  # reshuffle so each epoch sees a different split

        for batch in dataloader:
            inputs = tokenizer(batch["text"], padding=True, truncation=True, max_length=1024, return_tensors="pt")
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
            # Causal LM objective: transformers shifts the labels internally.
            labels = inputs["input_ids"]

            outputs = model(**inputs, labels=labels)

            optim.zero_grad()
            outputs.loss.backward()
            optim.step()

            print(f"rank={rank}, loss={outputs.loss}")

    model.cpu()
```
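To see the data-parallel split in isolation, here is a standalone sketch (torch only, no process group needed) of how `DistributedSampler` shards the 100-example subset between the two data-parallel replicas:

```python
from torch.utils.data.distributed import DistributedSampler

toy_dataset = list(range(100))  # stands in for the 100 IMDB examples

for dp_rank in range(2):
    sampler = DistributedSampler(toy_dataset, num_replicas=2, rank=dp_rank, seed=69)
    sampler.set_epoch(0)  # same call the training loop makes each epoch
    indices = list(sampler)
    print(f"dp_rank={dp_rank}: {len(indices)} examples, first 5 indices: {indices[:5]}")

# Each replica receives 50 disjoint indices; set_epoch reshuffles them every epoch.
```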
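To confirm the grid layout after launching, a minimal sketch along these lines can be run with the same `torchrun` command; it assumes `ParallelMode` also exposes a `TENSOR` member (only `ParallelMode.DATA` appears in the example above):

```python
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

parallel_context = ParallelContext.from_torch(
    tensor_parallel_size=2, pipeline_parallel_size=1, data_parallel_size=2
)
tp_rank = parallel_context.get_local_rank(ParallelMode.TENSOR)  # assumed member
dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)

# Each of the 4 processes should report a distinct (tp, dp) coordinate.
print(f"global={parallel_context.get_global_rank()} tp={tp_rank} dp={dp_rank}")
```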