diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py index f7d93ad..62c2c16 100644 --- a/flashmodels/patch/patch.py +++ b/flashmodels/patch/patch.py @@ -19,8 +19,11 @@ def rewrite_load(): """Rewrite `torch.load` in `from_pretrain` in case to use mmap to reduce the CPU memory pressure of loading multiple copies of data under multiple processes""" source = inspect.getsource(transformers.modeling_utils) - modified = re.sub(r"torch\.load\((?![^)]*mmap[^)]*\))([^)]*)\)", - r"torch.load(\g<1>, mmap=True)", source) + modified = re.sub( + r"torch\.load\((?![^)]*mmap[^)]*\))([^)]*)\)", + r"torch.load(\g<1>, mmap=True)" + if version.parse(transformers.__version__) <= version.parse('4.36') + else r"torch.load(\g<1> mmap=True)", source) modified = re.sub(r"partial\(torch.load,(?![^)]*mmap[^)]*\))([^)]*)\)", r"partial(torch.load,\g<1>, mmap=True)", modified) if (int(os.environ.get("LOCAL_RANK", 0)) == 0):