diff --git a/.gitignore b/.gitignore index 58489ffb..ce6b78e1 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ build *.pth *.pt *.onnx +*.mnn test/temp/* !test/temp/.gitkeep diff --git a/README.md b/README.md index 51538466..7249448b 100644 --- a/README.md +++ b/README.md @@ -116,8 +116,9 @@ python scripts/download_model.py **方式二:直接下载** 存到项目的`hivision/creator/weights`目录下: -- `modnet_photographic_portrait_matting.onnx` (24.7MB): [MODNet](https://github.com/ZHKKKe/MODNet)官方权重,[下载](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR) -- `hivision_modnet.onnx` (24.7MB):对纯色换底适配性更好的抠图模型,[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/tag/pretrained-model) +- `modnet_photographic_portrait_matting.onnx` (24.7MB): [MODNet](https://github.com/ZHKKKe/MODNet)官方权重,[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/download/pretrained-model/modnet_photographic_portrait_matting.onnx) +- `hivision_modnet.onnx` (24.7MB): 对纯色换底适配性更好的抠图模型,[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/download/pretrained-model/hivision_modnet.onnx) +- `mnn_hivision_modnet.mnn` (24.7MB): mnn转换后的抠图模型 by [zjkhahah](https://github.com/zjkhahah),[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/download/pretrained-model/mnn_hivision_modnet.mnn) ## 4. 人脸检测模型配置 diff --git a/app.py b/app.py index 7ed02172..b25746cd 100644 --- a/app.py +++ b/app.py @@ -152,6 +152,8 @@ def idphoto_inference( creator = IDCreator() if matting_model_option == "modnet_photographic_portrait_matting": creator.matting_handler = extract_human_modnet_photographic_portrait_matting + elif matting_model_option == "mnn_hivision_modnet": + creator.matting_handler = extract_human_mnn_modnet else: creator.matting_handler = extract_human @@ -315,7 +317,7 @@ def idphoto_inference( matting_model_list = [ os.path.splitext(file)[0] for file in os.listdir(os.path.join(root_dir, "hivision/creator/weights")) - if file.endswith(".onnx") + if file.endswith(".onnx") or file.endswith(".mnn") ] DEFAULT_MATTING_MODEL = "modnet_photographic_portrait_matting" if DEFAULT_MATTING_MODEL in matting_model_list: diff --git a/demo/images/test5.jpg b/demo/images/test5.jpg deleted file mode 100644 index 4c295450..00000000 Binary files a/demo/images/test5.jpg and /dev/null differ diff --git a/hivision/creator/human_matting.py b/hivision/creator/human_matting.py index fdad0ba1..0031e431 100644 --- a/hivision/creator/human_matting.py +++ b/hivision/creator/human_matting.py @@ -29,7 +29,7 @@ os.path.dirname(__file__), "weights", "mnn_hivision_modnet.mnn", - ) + ), } @@ -50,18 +50,22 @@ def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512): import MNN.expr as expr import MNN.nn as nn except ImportError as e: - raise ImportError("MNN模块未安装或导入错误。请确保已安装MNN库,使用命令 'pip install mnn' 安装。") from e + raise ImportError( + "The MNN module is not installed or there was an import error. Please ensure that the MNN library is installed by using the command 'pip install mnn'." + ) from e config = {} - config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理 - config['backend'] = 0 # CPU - config['numThread'] = 4 # 线程数 + config["precision"] = "low" # 当硬件支持(armv8.2)时使用fp16推理 + config["backend"] = 0 # CPU + config["numThread"] = 4 # 线程数 im, width, length = read_modnet_image(input_image, ref_size=512) rt = nn.create_runtime_manager((config,)) - net = nn.load_module_from_file(checkpoint_path, ['input1'], ['output1'], runtime_manager=rt) + net = nn.load_module_from_file( + checkpoint_path, ["input1"], ["output1"], runtime_manager=rt + ) input_var = expr.convert(im, expr.NCHW) output_var = net.forward(input_var) matte = expr.convert(output_var, expr.NCHW) - matte = matte.read()#var转换为np + matte = matte.read() # var转换为np matte = (matte * 255).astype("uint8") matte = np.squeeze(matte) mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA) @@ -71,6 +75,7 @@ def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512): return output_image + def extract_human_modnet_photographic_portrait_matting(ctx: Context): """ 人像抠图 @@ -84,8 +89,11 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context): ctx.processing_image = hollow_out_fix(matting_image) ctx.matting_image = ctx.processing_image.copy() + def extract_human_mnn_modnet(ctx: Context): - matting_image = get_mnn_modnet_matting(ctx.processing_image, WEIGHTS["mnn_hivision_modnet"]) + matting_image = get_mnn_modnet_matting( + ctx.processing_image, WEIGHTS["mnn_hivision_modnet"] + ) ctx.processing_image = hollow_out_fix(matting_image) ctx.matting_image = ctx.processing_image.copy()