implement training for HPU
HPU cards (Gaudi 2 and 3) can't use the Accelerate code path, so this
contribution adds the training setup and loop for FSDP-only training.

Minor modifications are required specifically for HPUs.

Signed-off-by: James Kunstle <[email protected]>
JamesKunstle committed Nov 12, 2024
1 parent 4bcb77e commit f9192d1
Showing 3 changed files with 476 additions and 15 deletions.
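The diff for the third changed file (the instructlab.training.train_hpu_fsdp module that the new train_hpu wrapper imports) is not shown below. As background on what "FSDP-only training" on Gaudi cards typically involves, here is a minimal sketch assuming the standard Intel Gaudi PyTorch integration (habana_frameworks) and plain torch FSDP; none of this code comes from the commit itself:

# Hedged sketch only -- not code from this commit. It shows the usual
# Gaudi pattern: initialize the "hccl" collective backend instead of
# "nccl", move the model to the "hpu" device, and shard it with plain
# torch FSDP rather than going through Accelerate.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumed to be available when Habana's PyTorch bridge is installed;
# importing it registers the "hccl" backend with torch.distributed.
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu


def init_hpu_process_group() -> torch.device:
    # Gaudi cards use the "hccl" collective backend, not "nccl".
    world_size, rank, _local_rank = initialize_distributed_hpu()
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
    return torch.device("hpu")


def shard_model(model: torch.nn.Module, device: torch.device) -> FSDP:
    # FSDP-only sharding; the Accelerate code path is bypassed entirely.
    return FSDP(model.to(device))

The key differences from the CUDA path are the "hccl" backend and the "hpu" device; sharding goes through torch FSDP directly rather than through Accelerate.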
src/instructlab/training/constants.py (9 additions, 0 deletions)
@@ -0,0 +1,9 @@
+SUPPORTED_MODEL_ARCHITECTURES: list[str] = [
+    "MistralForCausalLM",
+    "GPTDolomiteForCausalLM",
+    "LlamaForCausalLM",
+    "Starcoder2ForCausalLM",
+    "GemmaForCausalLM",
+    "MixtralForCausalLM",
+    "GraniteForCausalLM",
+]
src/instructlab/training/main_ds.py (38 additions, 15 deletions)
@@ -35,6 +35,7 @@
     TorchrunArgs,
     TrainingArgs,
 )
+from instructlab.training.constants import SUPPORTED_MODEL_ARCHITECTURES
 from instructlab.training.multipack_sampler import (
     find_packing_max_batch_len_and_grad_accum,
 )
@@ -170,15 +171,9 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
     )
     model.config.eos_token_id = tokenizer.eos_token_id
 
-    assert model.__class__.__name__ in [
-        "MistralForCausalLM",
-        "GPTDolomiteForCausalLM",
-        "LlamaForCausalLM",
-        "Starcoder2ForCausalLM",
-        "GemmaForCausalLM",
-        "MixtralForCausalLM",
-        "GraniteForCausalLM",
-    ], f"Model class name: {model.__class__.__name__} is not supported."
+    assert (
+        model.__class__.__name__ in SUPPORTED_MODEL_ARCHITECTURES
+    ), f"Model class name: {model.__class__.__name__} is not supported."
 
     model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite)
     model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)
@@ -683,7 +678,7 @@ def calculate_samples_per_gpu(
     return samples_per_gpu_step
 
 
-def init_distributed_training(local_rank: int, world_size: int, hpu: bool = False):
+def init_distributed_training(local_rank: int, world_size: int):
     torch.cuda.set_device(LOCAL_RANK)
     torch.distributed.init_process_group("nccl")
 
@@ -699,6 +694,8 @@ def check_hpu_compatible(
     """
     Using flash-attention (and by consequence, Dolomite models) is not supported
     if trying to train with Gaudi 2/3 cards.
+    Raises: RuntimeError
     """
 
     if using_hpu and any([using_dolomite, using_flash_attention]):
@@ -707,6 +704,27 @@
         )
 
 
+def train_hpu(
+    args,
+    model,
+    tokenizer,
+    data_loader,
+    grad_accum_steps,
+    metric_logger,
+):
+    # First Party
+    from instructlab.training.train_hpu_fsdp import main as main_hpu
+
+    main_hpu(
+        args=args,
+        model_name_or_path=model,
+        tokenizer=tokenizer,
+        grad_accum_steps=grad_accum_steps,
+        metric_logger=metric_logger,
+        data_loader=data_loader,
+    )
+
+
 def main(args):
     """
     Distributed training setup and execution.
@@ -730,9 +748,9 @@ def main(args):
         using_dolomite=args.use_dolomite,
     )
 
-    init_distributed_training(
-        local_rank=LOCAL_RANK, world_size=WORLD_SIZE, hpu=args.hpu
-    )
+    if not args.hpu:
+        # HPU will do its own initialization later.
+        init_distributed_training(local_rank=LOCAL_RANK, world_size=WORLD_SIZE)
 
     flash_enabled = check_flash_attn_enabled(
         disable_flash_attn=args.disable_flash_attn, use_dolomite=args.use_dolomite
@@ -815,8 +833,13 @@ def main(args):
         )
 
     else:
-        raise NotImplementedError(
-            "Training on Intel Gaudi 2/3 cards is not supported yet."
+        train_hpu(
+            args=args,
+            model=model,
+            tokenizer=tokenizer,
+            data_loader=train_loader,
+            grad_accum_steps=grad_accum,
+            metric_logger=metric_logger,
         )
 
     torch.distributed.barrier()
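To illustrate the guard added in check_hpu_compatible above, a hypothetical call site follows; the keyword names match the hunks shown, while the import and the try/except wrapper are assumptions for the example:

# Assumed usage example for the guard added in this commit; it raises
# rather than silently disabling flash-attention on Gaudi cards.
from instructlab.training.main_ds import check_hpu_compatible

try:
    check_hpu_compatible(
        using_hpu=True,
        using_flash_attention=True,  # unsupported on Gaudi 2/3
        using_dolomite=False,
    )
except RuntimeError as err:
    print(f"Incompatible HPU configuration: {err}")

Raising a RuntimeError keeps the failure explicit instead of quietly falling back when HPU training is requested with an unsupported option.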
