From 4640605e2430ee76c5ee4fd71ba2fe9e1239c66a Mon Sep 17 00:00:00 2001 From: luzhan Date: Mon, 23 Sep 2024 14:44:18 +0800 Subject: [PATCH] fix: different replacements for different transformer versions --- flashmodels/patch/patch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py index f7d93ad..62c2c16 100644 --- a/flashmodels/patch/patch.py +++ b/flashmodels/patch/patch.py @@ -19,8 +19,11 @@ def rewrite_load(): """Rewrite `torch.load` in `from_pretrain` in case to use mmap to reduce the CPU memory pressure of loading multiple copies of data under multiple processes""" source = inspect.getsource(transformers.modeling_utils) - modified = re.sub(r"torch\.load\((?![^)]*mmap[^)]*\))([^)]*)\)", - r"torch.load(\g<1>, mmap=True)", source) + modified = re.sub( + r"torch\.load\((?![^)]*mmap[^)]*\))([^)]*)\)", + r"torch.load(\g<1>, mmap=True)" + if version.parse(transformers.__version__) <= version.parse('4.36') + else r"torch.load(\g<1> mmap=True)", source) modified = re.sub(r"partial\(torch.load,(?![^)]*mmap[^)]*\))([^)]*)\)", r"partial(torch.load,\g<1>, mmap=True)", modified) if (int(os.environ.get("LOCAL_RANK", 0)) == 0):