Skip to content

Commit

Permalink
docs: mnn model
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyi-Lin committed Sep 7, 2024
1 parent 9740752 commit ad69622
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ build
*.pth
*.pt
*.onnx
*.mnn
test/temp/*
!test/temp/.gitkeep

Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ python scripts/download_model.py
**方式二:直接下载**

存到项目的`hivision/creator/weights`目录下:
- `modnet_photographic_portrait_matting.onnx` (24.7MB): [MODNet](https://github.com/ZHKKKe/MODNet)官方权重,[下载](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR)
- `hivision_modnet.onnx` (24.7MB):对纯色换底适配性更好的抠图模型,[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/tag/pretrained-model)
- `modnet_photographic_portrait_matting.onnx` (24.7MB): [MODNet](https://github.com/ZHKKKe/MODNet)官方权重,[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/download/pretrained-model/modnet_photographic_portrait_matting.onnx)
- `hivision_modnet.onnx` (24.7MB): 对纯色换底适配性更好的抠图模型,[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/download/pretrained-model/hivision_modnet.onnx)
- `mnn_hivision_modnet.mnn` (24.7MB): mnn转换后的抠图模型 by [zjkhahah](https://github.com/zjkhahah)[下载](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/download/pretrained-model/mnn_hivision_modnet.mnn)


## 4. 人脸检测模型配置
Expand Down
4 changes: 3 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def idphoto_inference(
creator = IDCreator()
if matting_model_option == "modnet_photographic_portrait_matting":
creator.matting_handler = extract_human_modnet_photographic_portrait_matting
elif matting_model_option == "mnn_hivision_modnet":
creator.matting_handler = extract_human_mnn_modnet
else:
creator.matting_handler = extract_human

Expand Down Expand Up @@ -315,7 +317,7 @@ def idphoto_inference(
matting_model_list = [
os.path.splitext(file)[0]
for file in os.listdir(os.path.join(root_dir, "hivision/creator/weights"))
if file.endswith(".onnx")
if file.endswith(".onnx") or file.endswith(".mnn")
]
DEFAULT_MATTING_MODEL = "modnet_photographic_portrait_matting"
if DEFAULT_MATTING_MODEL in matting_model_list:
Expand Down
Binary file removed demo/images/test5.jpg
Binary file not shown.
24 changes: 16 additions & 8 deletions hivision/creator/human_matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
os.path.dirname(__file__),
"weights",
"mnn_hivision_modnet.mnn",
)
),
}


Expand All @@ -50,18 +50,22 @@ def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):
import MNN.expr as expr
import MNN.nn as nn
except ImportError as e:
raise ImportError("MNN模块未安装或导入错误。请确保已安装MNN库,使用命令 'pip install mnn' 安装。") from e
raise ImportError(
"The MNN module is not installed or there was an import error. Please ensure that the MNN library is installed by using the command 'pip install mnn'."
) from e
config = {}
config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理
config['backend'] = 0 # CPU
config['numThread'] = 4 # 线程数
config["precision"] = "low" # 当硬件支持(armv8.2)时使用fp16推理
config["backend"] = 0 # CPU
config["numThread"] = 4 # 线程数
im, width, length = read_modnet_image(input_image, ref_size=512)
rt = nn.create_runtime_manager((config,))
net = nn.load_module_from_file(checkpoint_path, ['input1'], ['output1'], runtime_manager=rt)
net = nn.load_module_from_file(
checkpoint_path, ["input1"], ["output1"], runtime_manager=rt
)
input_var = expr.convert(im, expr.NCHW)
output_var = net.forward(input_var)
matte = expr.convert(output_var, expr.NCHW)
matte = matte.read()#var转换为np
matte = matte.read() # var转换为np
matte = (matte * 255).astype("uint8")
matte = np.squeeze(matte)
mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
Expand All @@ -71,6 +75,7 @@ def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):

return output_image


def extract_human_modnet_photographic_portrait_matting(ctx: Context):
"""
人像抠图
Expand All @@ -84,8 +89,11 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()


def extract_human_mnn_modnet(ctx: Context):
matting_image = get_mnn_modnet_matting(ctx.processing_image, WEIGHTS["mnn_hivision_modnet"])
matting_image = get_mnn_modnet_matting(
ctx.processing_image, WEIGHTS["mnn_hivision_modnet"]
)
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()

Expand Down

0 comments on commit ad69622

Please sign in to comment.