diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 51e9752e..d47e0416 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -203,3 +203,6 @@ class TrainingArgs(BaseModel): # This field defines whether or not data processing will occur inside of `run_training()` process_data: Optional[bool] = True + + # switch to use Intel Gaudi 2/3 training code if True. + hpu: bool = False diff --git a/src/instructlab/training/constants.py b/src/instructlab/training/constants.py new file mode 100644 index 00000000..fc3bc2e8 --- /dev/null +++ b/src/instructlab/training/constants.py @@ -0,0 +1,9 @@ +SUPPORTED_MODEL_ARCHITECTURES: list[str] = [ + "MistralForCausalLM", + "GPTDolomiteForCausalLM", + "LlamaForCausalLM", + "Starcoder2ForCausalLM", + "GemmaForCausalLM", + "MixtralForCausalLM", + "GraniteForCausalLM", +] diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 4c04da0f..19e31d93 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -10,6 +10,7 @@ import re import subprocess import time +import typing # Third Party from accelerate import Accelerator @@ -43,9 +44,10 @@ from transformers import AutoModelForCausalLM, get_scheduler import torch import torch.distributed +import yaml # First Party -from instructlab.training import config +from instructlab.training import config, data_process from instructlab.training.async_logger import AsyncStructuredLogger # pylint: disable=no-name-in-module @@ -55,11 +57,17 @@ TorchrunArgs, TrainingArgs, ) +from instructlab.training.constants import SUPPORTED_MODEL_ARCHITECTURES from instructlab.training.multipack_sampler import ( find_packing_max_batch_len_and_grad_accum, ) from instructlab.training.setup_accelerator import setup_accelerator -from instructlab.training.token_dataset import setup_dataloader, setup_dataset +from instructlab.training.token_dataset import ( + MockDataset, + TokenDataset, + setup_dataloader, + setup_dataset, +) from instructlab.training.tokenizer_utils import setup_tokenizer from instructlab.training.utils import ( StreamablePopen, @@ -79,7 +87,12 @@ set_random_seed, setup_logger, ) -import instructlab.training.data_process as dp + +# GLOBAL VARIABLES FROM ENVIRONMENT +# Will emit a key error if these aren't available. +RANK = int(os.environ["RANK"]) +LOCAL_RANK = int(os.environ["LOCAL_RANK"]) +WORLD_SIZE = int(os.environ["WORLD_SIZE"]) def setup_optimizer(args, model): @@ -180,15 +193,9 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled): ) model.config.eos_token_id = tokenizer.eos_token_id - assert model.__class__.__name__ in [ - "MistralForCausalLM", - "GPTDolomiteForCausalLM", - "LlamaForCausalLM", - "Starcoder2ForCausalLM", - "GemmaForCausalLM", - "MixtralForCausalLM", - "GraniteForCausalLM", - ], f"Model class name: {model.__class__.__name__} is not supported." + assert ( + model.__class__.__name__ in SUPPORTED_MODEL_ARCHITECTURES + ), f"Model class name: {model.__class__.__name__} is not supported." 
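Side note on the new `constants.py`: centralizing the supported-architecture list means the same check used in `setup_model` can be reused elsewhere, for example to reject a checkpoint before its weights are downloaded or loaded. The sketch below is illustrative only and is not part of this diff; `validate_architecture` is a hypothetical helper, and it assumes a Hugging Face checkpoint whose config exposes an `architectures` list.

```python
# Minimal sketch (assumed helper, not part of this PR): validate a checkpoint's
# architecture against the shared constant without instantiating the model.
from transformers import AutoConfig

from instructlab.training.constants import SUPPORTED_MODEL_ARCHITECTURES


def validate_architecture(model_name_or_path: str) -> str:
    """Return the checkpoint's architecture name if it is supported, otherwise raise."""
    conf = AutoConfig.from_pretrained(model_name_or_path)
    # HF configs usually expose a one-element list, e.g. ["LlamaForCausalLM"]
    arch = (conf.architectures or ["<unknown>"])[0]
    if arch not in SUPPORTED_MODEL_ARCHITECTURES:
        raise ValueError(f"Model class name: {arch} is not supported.")
    return arch
```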
model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite) model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha) @@ -403,18 +410,22 @@ def train( if local_rank == 0: inner_pb.update(1) continue + start = time.time() num_loss_counted_tokens = float( torch.tensor([batch.pop("num_loss_counted_tokens")]) ) micro_batch_size = float(torch.tensor([batch.pop("num_samples")])) + if not args.use_dolomite: for k in batch: batch[k] = batch[k].to(local_rank) + output = model( **batch, use_cache=False, ) + loss = output.loss log_loss = loss.detach().item() @@ -536,9 +547,33 @@ def train( ) -def main(args): - # Third Party - import yaml +def configure_tokenizer(chat_template_path: str, model_name_or_path: str): + """ + Loads tokenizer for input model. Replaces chat template and special tokens with + those from repo's chat template and special tokens. + """ + CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(chat_template_path) + tokenizer = setup_tokenizer(model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE) + return tokenizer + + +def read_model_type_from_config(model_name_or_path: str) -> str: + """ + Reads 'model_type' value from model configuration. + Config file named `config.json` by convention. + """ + with open(os.path.join(model_name_or_path, "config.json")) as conf_json: + model_conf = json.load(conf_json) + + return model_conf["model_type"] + + +def setup_metric_logger( + args, log_output_dir: str, rank: int, local_rank: int +) -> AsyncStructuredLogger: + """ + Instantiates AsyncStructuredLogger, prints and logs args on rank=0 process. + """ if args.distributed_training_framework == "deepspeed" and not FusedAdam: raise ImportError( @@ -555,101 +590,255 @@ def main(args): ) metric_logger = AsyncStructuredLogger( - args.output_dir - + f"/training_params_and_metrics_global{os.environ['RANK']}.jsonl" + file_name=os.path.join( + log_output_dir, f"training_params_and_metrics_global{rank}.jsonl" + ) ) - if os.environ["LOCAL_RANK"] == "0": + + if local_rank == 0: print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m") metric_logger.log_sync({"script_params": vars(args)}) - setup_logger(args.log_level) - CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path) - tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE) - # device = torch.device("cuda", args.local_rank) + return metric_logger - with open(Path(args.model_name_or_path) / "config.json") as conf_json: - model_conf = json.load(conf_json) - args.model_type = model_conf["model_type"] - #### distributed init ##### - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - args.local_rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group("nccl") - args.global_rank = torch.distributed.get_rank() - tensor = torch.ByteTensor([False]).cuda() - torch.distributed.all_reduce(tensor) - torch.distributed.barrier() +def calculate_percard_packing_params( + local_rank: int, + world_size: int, + dataset: TokenDataset | MockDataset, + effective_batch_size: int, + max_batch_len: int, + is_padding: bool, + seed: int, +) -> typing.Tuple[int, int, str]: + """ + Given the effective_batch_size (number of samples per batch required) and + the max_batch_len (maximum number of tokens allowed in a batch), calculate + the number of gradient accumulation steps required to satisfy the effective_batch_size + given batches of size "packing_max_batch_len." 
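To make the packing arithmetic described in this docstring concrete, here is a rough worked example with assumed numbers; it mirrors the intent of `find_packing_max_batch_len_and_grad_accum`, not its exact algorithm.

```python
# Assumed numbers for illustration only.
import math

avg_sample_len = 1_000        # average tokens per sample
max_batch_len = 10_000        # token budget per card per micro-batch
world_size = 4                # participating cards
effective_batch_size = 128    # samples per optimizer step, across all cards

samples_per_card = max_batch_len // avg_sample_len         # ~10 samples fit on one card
samples_per_micro_batch = samples_per_card * world_size    # ~40 samples per micro-batch
grad_accum = math.ceil(effective_batch_size / samples_per_micro_batch)  # ceil(128 / 40) = 4
packing_max_batch_len = (effective_batch_size // grad_accum // world_size) * avg_sample_len

print(grad_accum, packing_max_batch_len)  # 4 8000
```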
+ + It may be the case that there are too few samples to correctly distribute the data + evenly across the available cards. If this is the case, gradient accumulation (and subsequent + optimizations that we can make) are ignored and we default to the `distributed data sampler`. + """ - flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite) - - dataset = setup_dataset( - args.data_path, - mock=args.mock_data, - mock_len=args.mock_len, - ) + # will try to make multipack work if possible. + sampler_type: str = "multipack" try: packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum( - num_gpus=torch.distributed.get_world_size(), + num_gpus=world_size, avg_sample_len=dataset.get_lengths().mean(), - effective_batch_size=args.effective_batch_size, - max_batch_len_per_gpu=args.max_batch_len, - is_padding=not (args.use_dolomite or flash_enabled), + effective_batch_size=effective_batch_size, + max_batch_len_per_gpu=max_batch_len, + is_padding=is_padding, dataset=dataset, - seed=args.seed, + seed=seed, ) - args.sampler = "multipack" except RuntimeError as e: - if os.environ["LOCAL_RANK"] == "0": + if local_rank == 0: print(f"\033[38;5;120m{e}\033[0m") # fallback to grad accum = 1 # NOTE: packing max batch len will not be used packing_max_batch_len = None grad_accum = 1 - args.sampler = "distributed" + sampler_type = "distributed" - args.samples_per_gpu = ( - args.effective_batch_size // grad_accum // torch.distributed.get_world_size() - ) + return packing_max_batch_len, grad_accum, sampler_type - train_loader = setup_dataloader( - dataset, - tokenizer.pad_token_id, + +def setup_data_loader_with_fallback( + dataset: TokenDataset | MockDataset, + tokenizer, + sampler, + use_dolomite: bool, + flash_enabled: bool, + max_batch_len: int, + packing_max_batch_len: int, + samples_per_gpu: int, + seed, +) -> typing.Tuple[DataLoader | MockDataset, str]: + """ + Builds the training DataLoader with the requested sampler, falling back to the + distributed sampler when multipack produces an empty loader. This happens sometimes + when we have more GPUs than data to process. In this case we should either alert + the user to switch samplers, or do it automatically and warn them about it happening. + """ + data_loader = setup_dataloader( + dataset=dataset, + pad_token_id=tokenizer.pad_token_id, num_workers=8, - use_dolomite=args.use_dolomite, + use_dolomite=use_dolomite, flash_enabled=flash_enabled, - max_batch_len=args.max_batch_len, + max_batch_len=max_batch_len, packing_max_batch_len=packing_max_batch_len, - samples_per_gpu=args.samples_per_gpu, - sampler=args.sampler, - seed=args.seed, + samples_per_gpu=samples_per_gpu, + sampler=sampler, + seed=seed, ) - if len(train_loader) == 0: - # this happens sometimes when we have more GPUs than data to process. In this case - # we should either alert the user to switch samplers, or do it automatically and - # warn them about it happening + + if len(data_loader) == 0: print( "\033[93mThe dataset is too small for multipack to distribute all of the samples across GPUs.
Falling back to the distributed sampler!\033[0m" ) - args.sampler = "distributed" - train_loader = setup_dataloader( - dataset, - tokenizer.pad_token_id, + sampler = "distributed" + data_loader = setup_dataloader( + dataset=dataset, + pad_token_id=tokenizer.pad_token_id, num_workers=8, - use_dolomite=args.use_dolomite, + use_dolomite=use_dolomite, flash_enabled=flash_enabled, - max_batch_len=args.max_batch_len, + max_batch_len=max_batch_len, packing_max_batch_len=packing_max_batch_len, - samples_per_gpu=args.samples_per_gpu, - sampler=args.sampler, - seed=args.seed, + samples_per_gpu=samples_per_gpu, + sampler=sampler, + seed=seed, ) - if args.local_rank == 0: + return data_loader, sampler + + +def calculate_samples_per_gpu( + world_size: int, effective_batch_size: int, grad_accum_steps: int +) -> int: + """ + Given the effective_batch_size (total batch size across all GPUs), the world_size (number of participating GPUs), + and grad_accum_steps (number of gradient-producing forward passes before a .backward() call), calculate + the number of samples per card, per batch of training. + """ + + samples_per_grad_accum_step = effective_batch_size // grad_accum_steps + samples_per_gpu_step = samples_per_grad_accum_step // world_size + + return samples_per_gpu_step + + +def init_distributed_training(local_rank: int, world_size: int): + torch.cuda.set_device(LOCAL_RANK) + torch.distributed.init_process_group("nccl") + + # check that communication works between all participating cards + tensor = torch.ByteTensor([False]).cuda() + torch.distributed.all_reduce(tensor) + torch.distributed.barrier() + + +def check_hpu_compatible( + using_hpu: bool, using_flash_attention: bool, using_dolomite: bool +) -> None: + """ + Using flash-attention (and by consequence, Dolomite models) is not supported + if trying to train with Gaudi 2/3 cards. + + Raises: RuntimeError + """ + + if using_hpu and any([using_dolomite, using_flash_attention]): + raise RuntimeError( + "Attempting to train with Gaudi HPUs with unsupported Dolomite or Flash Attention." + ) + + +def train_hpu( + args, + model, + tokenizer, + data_loader, + grad_accum_steps, + metric_logger, +): + # First Party + from instructlab.training.train_hpu_fsdp import main as main_hpu + + main_hpu( + args=args, + model_name_or_path=model, + tokenizer=tokenizer, + grad_accum_steps=grad_accum_steps, + metric_logger=metric_logger, + data_loader=data_loader, + ) + + +def main(args): + """ + Distributed training setup and execution. + """ + + metric_logger = setup_metric_logger( + args, log_output_dir=args.output_dir, rank=RANK, local_rank=LOCAL_RANK + ) + setup_logger(args.log_level) + + args.model_type = read_model_type_from_config( + model_name_or_path=args.model_name_or_path + ) + args.local_rank = LOCAL_RANK + args.global_rank = RANK + + if args.hpu: + check_hpu_compatible( + using_hpu=args.hpu, + using_flash_attention=not args.disable_flash_attn, + using_dolomite=args.use_dolomite, + ) + + if not args.hpu: + # HPU will do its own initialization later. 
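On the module-level `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` globals consumed throughout this refactor: they rely on the environment that `torchrun` exports for each worker, which is why the diff notes that a KeyError is raised when they are missing. The sketch below is illustrative only; `require_torchrun_env` is a hypothetical wrapper with a friendlier error, not something this PR adds.

```python
# Sketch (assumed helper, not part of this diff): same lookup as the module-level
# globals, but with a clearer failure message when the script is run without torchrun.
import os


def require_torchrun_env(name: str) -> int:
    try:
        return int(os.environ[name])
    except KeyError as e:
        raise RuntimeError(
            f"{name} is not set; launch this script with torchrun so that "
            "RANK, LOCAL_RANK and WORLD_SIZE are exported."
        ) from e


# e.g. `torchrun --nnodes=1 --nproc-per-node=8 main_ds.py ...` exports, per worker:
#   RANK=0..7, LOCAL_RANK=0..7, WORLD_SIZE=8
RANK = require_torchrun_env("RANK")
LOCAL_RANK = require_torchrun_env("LOCAL_RANK")
WORLD_SIZE = require_torchrun_env("WORLD_SIZE")
```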
+ init_distributed_training(local_rank=LOCAL_RANK, world_size=WORLD_SIZE) + + flash_enabled = check_flash_attn_enabled( + disable_flash_attn=args.disable_flash_attn, use_dolomite=args.use_dolomite + ) + + tokenizer = configure_tokenizer( + chat_template_path=args.chat_tmpl_path, + model_name_or_path=args.model_name_or_path, + ) + + dataset = setup_dataset( + data_path=args.data_path, + mock=args.mock_data, + mock_len=args.mock_len, + ) + + using_padding = not (args.use_dolomite or flash_enabled) + packing_max_batch_len, grad_accum, sampler_guess = calculate_percard_packing_params( + local_rank=LOCAL_RANK, + world_size=WORLD_SIZE, + dataset=dataset, + effective_batch_size=args.effective_batch_size, + max_batch_len=args.max_batch_len, + is_padding=using_padding, + seed=args.seed, + ) + args.sampler = sampler_guess + + args.samples_per_gpu = calculate_samples_per_gpu( + world_size=WORLD_SIZE, + effective_batch_size=args.effective_batch_size, + grad_accum_steps=grad_accum, + ) + + train_loader, confirmed_sampler = setup_data_loader_with_fallback( + dataset=dataset, + tokenizer=tokenizer, + sampler=args.sampler, + use_dolomite=args.use_dolomite, + flash_enabled=flash_enabled, + max_batch_len=args.max_batch_len, + packing_max_batch_len=packing_max_batch_len, + samples_per_gpu=args.samples_per_gpu, + seed=args.seed, + ) + args.sampler = confirmed_sampler + + if LOCAL_RANK == 0: metric_logger.log_sync( { - "num_gpus": torch.distributed.get_world_size(), + "num_gpus": WORLD_SIZE, "avg_sample_len": dataset.get_lengths().mean(), "effective_batch_size": args.effective_batch_size, "max_batch_len_per_gpu": args.max_batch_len, @@ -662,23 +851,34 @@ def main(args): - model, lr_scheduler, optimizer, accelerator = setup_model( - args, tokenizer, train_loader, grad_accum, flash_enabled - ) + if not args.hpu: + model, lr_scheduler, optimizer, accelerator = setup_model( + args, tokenizer, train_loader, grad_accum, flash_enabled + ) - load_latest_full_state(args=args, accelerator=accelerator) + load_latest_full_state(args=args, accelerator=accelerator) - train( - args, - model, - optimizer, - lr_scheduler, - accelerator, - tokenizer, - train_loader, - grad_accum, - metric_logger, - ) + train( + args, + model, + optimizer, + lr_scheduler, + accelerator, + tokenizer, + train_loader, + grad_accum, + metric_logger, + ) + + else: + train_hpu( + args=args, + model=args.model_name_or_path, + tokenizer=tokenizer, + data_loader=train_loader, + grad_accum_steps=grad_accum, + metric_logger=metric_logger, + ) torch.distributed.barrier() torch.distributed.destroy_process_group() @@ -692,7 +892,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: check_valid_train_args(train_args) if train_args.process_data: - dp.main( + data_process.main( DataProcessArgs( # XXX(osilkin): make a decision here, either: # 1.
the CLI is fully responsible for managing where the data is written @@ -810,6 +1010,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: f"--fsdp_sharding_strategy={train_args.fsdp_options.sharding_strategy.value}" ) + if train_args.hpu: + command.append("--hpu") + print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m") process = None interrupt: KeyboardInterrupt | Exception | None = None @@ -985,6 +1188,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ), ) parser.add_argument("--disable_flash_attn", action="store_true") + + parser.add_argument( + "--hpu", + action="store_true", + help="If set, uses specialized code path that implements training for Intel Gaudi 2/3 cards. Not compatible with Nvidia or AMD training.", + ) + args = parser.parse_args() set_random_seed(args.seed) main(args) diff --git a/src/instructlab/training/tokenizer_utils.py b/src/instructlab/training/tokenizer_utils.py index d6c55e7e..a79822cd 100644 --- a/src/instructlab/training/tokenizer_utils.py +++ b/src/instructlab/training/tokenizer_utils.py @@ -14,6 +14,7 @@ def setup_tokenizer( if not SPECIAL_TOKENS.pad.token: SPECIAL_TOKENS.pad = SPECIAL_TOKENS.eos + tokenizer.add_special_tokens( { "bos_token": SPECIAL_TOKENS.bos.token, @@ -21,9 +22,11 @@ def setup_tokenizer( "pad_token": SPECIAL_TOKENS.pad.token, } ) + tokenizer.add_special_tokens( {"additional_special_tokens": SPECIAL_TOKENS.get_tokens_to_add()} ) + if getattr(tokenizer, "add_bos_token", False) or getattr( tokenizer, "add_eos_token", False ): diff --git a/src/instructlab/training/train_hpu_fsdp.py b/src/instructlab/training/train_hpu_fsdp.py new file mode 100644 index 00000000..2466e2bd --- /dev/null +++ b/src/instructlab/training/train_hpu_fsdp.py @@ -0,0 +1,429 @@ +# Forcing Intel PyTorch bridge Eager mode. + +# Standard +import contextlib +import functools +import math + +# Standard Library +import os +import time + + +# Third Party +os.environ["PT_HPU_LAZY_MODE"] = "0" +import habana_frameworks.torch as htorch +import habana_frameworks.torch.distributed.hccl + +from torch.distributed import ReduceOp, all_reduce +from torch.distributed.fsdp import BackwardPrefetch, CPUOffload +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from tqdm import tqdm +import tokenizers +import torch +import transformers + +# First Party +from instructlab.training import constants, utils +from instructlab.training.config import DistributedBackend +from instructlab.training.utils import add_noisy_embeddings, convert_loss_to_reduce_sum + +# Constants + +# Will emit a key error if these aren't available. +RANK = int(os.environ["RANK"]) +LOCAL_RANK = int(os.environ["LOCAL_RANK"]) +WORLD_SIZE = int(os.environ["WORLD_SIZE"]) +DEVICE_HPU = torch.device("hpu") + + +def _setup_hpu_torch_distributed(): + """ + Initialized distributed process group. + + Raises: RuntimeError if initialization fails. + """ + + torch.distributed.init_process_group( + backend="hccl", rank=LOCAL_RANK, world_size=WORLD_SIZE + ) + + if not torch.distributed.is_initialized(): + raise RuntimeError( + f"Attempted to initialize torch distributed process group for HPU but failed." 
+ ) + + +def setup_fsdp(model: torch.nn.Module, sharding_strategy: str, cpu_param_offload: bool): + """Wraps model in FSDP class.""" + + block_name = model._no_split_modules[0] + transformer_attention_block_class: torch.nn.Module | None = ( + utils.get_module_class_from_name(model, block_name) + ) + + if transformer_attention_block_class is None: + raise RuntimeError( + f"Transformer block class cannot be derived from transformer module. Cannot correctly wrap block: ({transformer_attention_block_class})" + ) + + model = FSDP( + module=model, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={type(transformer_attention_block_class)}, + ), + limit_all_gathers=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + sharding_strategy=ShardingStrategy[sharding_strategy], + device_id=torch.device("hpu", torch.hpu.current_device()), + cpu_offload=CPUOffload(offload_params=cpu_param_offload), + ) + + return model + + +def setup_optimizer(model: torch.nn.Module, learning_rate: float) -> torch.optim.AdamW: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate, + betas=(0.9, 0.95), + weight_decay=0.0, + ) + + return optimizer + + +def try_load_checkpoint(*args, **kwargs): + raise NotImplementedError() + + +def save_checkpoint(*args, **kwargs): + # save_checkpoint(model, optimizer, lr_scheduler, other_state: dict) + + raise NotImplementedError() + + +def _set_sampler_epoch(sampler_type: str, data_loader, epoch: int): + if sampler_type == "multipack": + data_loader.batch_sampler.set_epoch(epoch) + elif sampler_type == "distributed": + data_loader.sampler.set_epoch(epoch) + else: + raise RuntimeError(f"Sampler type ({sampler_type}) is not supported.") + + +def print_status(loss, num_loss_counted_tokens, global_step, epoch): + print( + f"\033[93mPer-token loss scaled by world size: {(loss/num_loss_counted_tokens) * WORLD_SIZE}\033[0m" + ) + print(f"Epoch: {epoch}, Step: {global_step}, Rank: {RANK}, loss = {loss}") + + +def batch_metric_log( + args, + metric_logger, + epoch, + global_step, + loss, + reduced_loss, + num_loss_counted_tokens, + current_lr, + grad_norm, + samples_seen, + start_time, + last_batch_size, +): + if LOCAL_RANK != 0: + return + + elapsed_time = time.time() - start_time + overall_throughput = args.samples_per_gpu * WORLD_SIZE / elapsed_time + # vmem_allocated = htorch.memory_allocated() / (1024**3) + # vmalloc_retries = htorch.memory_stats()["num_alloc_retries"] + # global_grad_norm = model.get_global_grad_norm() + metric_logger.log_sync( + { + "epoch": epoch, + "step": global_step, + "rank": LOCAL_RANK, + "loss": loss.item(), + "overall_throughput": overall_throughput, + "lr": current_lr, + # "vmem_allocated": vmem_allocated, + # "vmalloc_retries": vmalloc_retries, + # "num_loss_counted_tokens": int(num_loss_counted_tokens), + "batch_size": last_batch_size, + "total_loss": float(reduced_loss / num_loss_counted_tokens), + "gradnorm": grad_norm, + "weight_norm": 0.0, + } + ) + + +def train( + args, + model: torch.nn.Module, + optimizer: torch.optim.AdamW, + data_loader: torch.utils.data.DataLoader, + lr_scheduler, + grad_accum_steps: int, + num_epochs: int, + metric_logger, +): + model.train() + optimizer.zero_grad() + global_step = 1 + global_grad_norm = None + samples_seen = 0 + batch_size = args.effective_batch_size // grad_accum_steps + args.save_samples = (args.save_samples // 
batch_size) * batch_size + + if LOCAL_RANK == 0: + print(f"\033[93mNumber of samples per save: {args.save_samples}\033[0m") + + # (jkunstle) TODO: implement current_epoch + for epoch in range(num_epochs): + _set_sampler_epoch( + sampler_type=args.sampler, data_loader=data_loader, epoch=epoch + ) + + if LOCAL_RANK == 0: + progress_bar = tqdm(total=len(data_loader), desc=f"Epoch {epoch}") + if args.last_step: + progress_bar.update(args.last_step) + + dist_shared_buffer = torch.zeros(3, dtype=torch.float32).to(DEVICE_HPU) + + for batch in data_loader: + start_time = time.time() + dist_shared_buffer[0] = batch.pop("num_loss_counted_tokens") + dist_shared_buffer[1] = len(batch["input_ids"]) + + # batch = {input_ids: ..., labels: ..., attention_mask: ...}, + # each is a torch.Tensor. + for k in batch: + batch[k] = batch[k].to(DEVICE_HPU) + + no_sync = contextlib.nullcontext + if global_step % grad_accum_steps != 0: + no_sync = model.no_sync + + with no_sync(): + output = model(**batch, use_cache=False) + loss = output.loss + + dist_shared_buffer[2] = loss.item() + + all_reduce(tensor=dist_shared_buffer, op=ReduceOp.SUM) + + # These have been summed over all participating cards. + num_loss_counted_tokens = dist_shared_buffer[0] + samples_seen += int(dist_shared_buffer[1]) + + # (jkunstle) TODO: make sure this is correct for FSDP, was originally for DeepSpeed + # dividing by the total number of non-padding tokens and multiplying by the number of GPUs so when FSDP averages by world_size, it will be the correct loss. + loss = loss / num_loss_counted_tokens * WORLD_SIZE + + print_status( + loss=loss, + num_loss_counted_tokens=num_loss_counted_tokens, + global_step=global_step, + epoch=epoch, + ) + + loss.backward() + + if global_step % grad_accum_steps == 0: + global_grad_norm = model.clip_grad_norm_(1.0) + # global_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + global_grad_norm = ( + float(global_grad_norm) if global_grad_norm is not None else None + ) + batch_metric_log( + args=args, + metric_logger=metric_logger, + epoch=epoch, + global_step=global_step, + loss=loss, + reduced_loss=dist_shared_buffer[2], + num_loss_counted_tokens=num_loss_counted_tokens, + current_lr=lr_scheduler.get_last_lr()[0], + grad_norm=global_grad_norm, + samples_seen=samples_seen, + start_time=start_time, + last_batch_size=int( + dist_shared_buffer[1] + ), # sum(len(input_ids) for all cards) + ) + + global_step += 1 + if LOCAL_RANK == 0: + progress_bar.update(1) + + # (jkunstle) TODO: save checkpoint for save_samples, epochs, and final. + + +def _match_model_and_tokenizer_special_tokens( + model: torch.nn.Module, tokenizer: tokenizers.Tokenizer, token_list: list[str] +) -> torch.nn.Module: + """ + Model might have different representations for special tokens, like eos_token and bos_token. + This function matches a model's tokens to that of the tokenizer. + """ + + for tok in token_list: + model_tok = getattr(model.config, tok, None) + tokenizer_tok = getattr(tokenizer, tok, None) + + if ( + model_tok is not None + and tokenizer_tok is not None + and model_tok != tokenizer_tok + ): + print( + f"WARNING: There is a mismatch between {tok} of model ({model_tok}) and tokenizer({tokenizer_tok}). 
Fixing model {tok} to be the same as tokenizer's {tok}" + ) + + setattr(model.config, tok, tokenizer_tok) + + return model + + +def _match_model_and_tokenizer_vocab_lengths( + model: torch.nn.Module, tokenizer: tokenizers.Tokenizer +) -> torch.nn.Module: + tokenizer_len = len(tokenizer) + if tokenizer_len > model.config.vocab_size: + print( + f"WARNING: tokenizer has {tokenizer_len} tokens but model has {model.config.vocab_size} vocab size. Resizing token embeddings." + ) + + model.resize_token_embeddings( + int(8 * math.ceil(tokenizer_len / 8.0)) + ) # make the vocab size multiple of 8 for sharding the embedding layer. + + return model + + +def prepare_model( + model: torch.nn.Module, tokenizer: tokenizers.Tokenizer, noise_alpha: float +) -> torch.nn.Module: + """ + Modifies model so that it works correctly with tokenizer vocab and special tokens, multipack sampler, + and has gradient checkpointing enabled. + """ + + model = _match_model_and_tokenizer_vocab_lengths(model=model, tokenizer=tokenizer) + model = _match_model_and_tokenizer_special_tokens( + model=model, + tokenizer=tokenizer, + token_list=["bos_token_id", "eos_token_id", "pad_token_id"], + ) + + model = convert_loss_to_reduce_sum(model, use_dolomite=False) + model = add_noisy_embeddings(model, noise_alpha=noise_alpha) + + model.gradient_checkpointing_enable() + + return model + + +def load_model(model_name_or_path: str) -> torch.nn.Module: + """Load Transformer model and validate that it's among supported models.""" + + # (jkunstle) TODO: could load model config on its own and check for the class type before + # downloading / loading the entire model into memory. + + model = transformers.AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=model_name_or_path, torch_dtype=torch.bfloat16 + ) + + if model.__class__.__name__ not in constants.SUPPORTED_MODEL_ARCHITECTURES: + raise RuntimeError( + f"Model class name: {model.__class__.__name__} is not supported." + ) + + return model + + +def _raise_exception_for_unsupported_args(args) -> None: + """ + Make sure that user isn't expecting training to be configured for: + 1) LoRA PEFT + 2) Quantization + 3) Distributed backend that's not FSDP + """ + + if args.lora_r > 0: + raise RuntimeError( + f"LoRA rank was set (lora_r={args.lora_r}) but is not supported when training with (--hpu)." + ) + + if args.lora_quant_bits is not None: + raise RuntimeError( + f"QLoRA was set (lora_quant_bits={args.lora_quant_bits}) but is not supported when training with (--hpu)." + ) + + chosen_backend = DistributedBackend(args.distributed_training_framework) + if chosen_backend != DistributedBackend.FSDP: + raise RuntimeError( + f"Distributed backend was set as (distributed_training_framework={chosen_backend.value}) but only ({DistributedBackend.FSDP.value}) is supported with (--hpu)."
+ ) + + +def main( + args, + model_name_or_path: str, + tokenizer: tokenizers.Tokenizer, + data_loader: torch.utils.data.DataLoader, + grad_accum_steps: int, + metric_logger, +): + # (jkunstle) TODO: setup logger for file + + _raise_exception_for_unsupported_args(args) + _setup_hpu_torch_distributed() + + # (jkunstle) TODO: try to load checkpoint + model = load_model(model_name_or_path=model_name_or_path) + model = prepare_model( + model=model, tokenizer=tokenizer, noise_alpha=args.NEFTune_alpha + ) + + model = setup_fsdp( + model=model, + sharding_strategy=args.fsdp_sharding_strategy, + cpu_param_offload=args.cpu_offload_params_fsdp, + ) + + optimizer = setup_optimizer(model=model, learning_rate=args.lr) + + lr_scheduler = transformers.get_scheduler( + name=args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.num_epochs * len(data_loader) // grad_accum_steps, + ) + + train( + args=args, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + metric_logger=metric_logger, + data_loader=data_loader, + grad_accum_steps=grad_accum_steps, + num_epochs=args.num_epochs, + ) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index b6f655bf..57c7c619 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -691,6 +691,7 @@ def _whether_to_checkpoint(submodule: torch.nn.Module) -> bool: def setup_logger(level="DEBUG"): + """sets basic rank logging configuration""" logging.basicConfig( level=level, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] )
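Closing note: a small worked example of the step count that `main()` in `train_hpu_fsdp.py` passes to `transformers.get_scheduler`, using assumed numbers rather than values from this PR. The scheduler is sized in optimizer steps, so the micro-batch count is divided by the gradient-accumulation factor.

```python
# Illustration only (assumed values): 3 epochs, a data loader yielding 400
# micro-batches per epoch, and 4 gradient-accumulation steps per optimizer step
# give the scheduler 300 optimizer steps to plan for.
num_epochs = 3
micro_batches_per_epoch = 400   # i.e. len(data_loader)
grad_accum_steps = 4

num_training_steps = num_epochs * micro_batches_per_epoch // grad_accum_steps
print(num_training_steps)  # 300
```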