From 6439bf6f852a476b443a4f6595e062fe3c7f2bd6 Mon Sep 17 00:00:00 2001 From: Yusuke Oda Date: Fri, 20 Dec 2024 05:04:01 +0900 Subject: [PATCH] Add --save-iter-patterns option (#17) * add --save-iter-patterns option * Update megatron/training/training.py Co-authored-by: Kouta Nakayama --------- Co-authored-by: Kouta Nakayama --- megatron/training/arguments.py | 3 +++ megatron/training/training.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 51f1c794..e8bb3c63 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1690,5 +1690,8 @@ def _add_experimental_args(parser): help = 'Config file to add additional arguments') group.add_argument('--force-stop-iter', type=int, default=None, help="Stop training process at this iteration regardless of any other configs.") + group.add_argument('--save-iter-patterns', type=str, default=None, nargs='*', + help='List of regex patterns of step numbers (non-0-padding integer) to save checkpoint. ' + 'E.g., "123.." will save all checkpoints between 12300 and 12399.') return parser diff --git a/megatron/training/training.py b/megatron/training/training.py index 1fe791cf..fa8372dc 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -8,6 +8,7 @@ import math import logging import os +import re import sys from .log_handler import CustomHandler # Make default logging level INFO, but filter out all log messages not from MCore. @@ -205,6 +206,16 @@ def is_save_iteration(iteration: int) -> bool: Returns: True if we should save checkpoint, False otherwise. """ + args = get_args() + + if args.save_iter_patterns is not None: + save_iter_patterns = [re.compile(p) for p in args.save_iter_patterns] + iter_str = str(iteration) + for p in save_iter_patterns: + if p.fullmatch(iter_str): + return True + + if iteration < 10: # 0, 1, ..., 9 return True