diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 51f1c794..e8bb3c63 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1690,5 +1690,8 @@ def _add_experimental_args(parser): help = 'Config file to add additional arguments') group.add_argument('--force-stop-iter', type=int, default=None, help="Stop training process at this iteration regardless of any other configs.") + group.add_argument('--save-iter-patterns', type=str, default=None, nargs='*', + help='List of regex patterns of step numbers (non-0-padding integer) to save checkpoint. ' + 'E.g., "123.." will save all checkpoints between 12300 and 12399.') return parser diff --git a/megatron/training/training.py b/megatron/training/training.py index 1fe791cf..fa8372dc 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -8,6 +8,7 @@ import math import logging import os +import re import sys from .log_handler import CustomHandler # Make default logging level INFO, but filter out all log messages not from MCore. @@ -205,6 +206,16 @@ def is_save_iteration(iteration: int) -> bool: Returns: True if we should save checkpoint, False otherwise. """ + args = get_args() + + if args.save_iter_patterns is not None: + save_iter_patterns = [re.compile(p) for p in args.save_iter_patterns] + iter_str = str(iteration) + for p in save_iter_patterns: + if p.fullmatch(iter_str): + return True + + if iteration < 10: # 0, 1, ..., 9 return True