Skip to content

Commit

Permalink
feat: mnn matting model (#73)
Browse files Browse the repository at this point in the history
* 基于mnn框架推理

* mnn导入方式修改

* Update requirements.txt

---------

Co-authored-by: Ze-Yi LIN <[email protected]>
  • Loading branch information
zjkhahah and Zeyi-Lin authored Sep 7, 2024
1 parent 6617dca commit 9740752
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
36 changes: 36 additions & 0 deletions hivision/creator/human_matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
"weights",
"modnet_photographic_portrait_matting.onnx",
),
"mnn_hivision_modnet": os.path.join(
os.path.dirname(__file__),
"weights",
"mnn_hivision_modnet.mnn",
)
}


Expand All @@ -40,6 +45,32 @@ def extract_human(ctx: Context):
ctx.matting_image = ctx.processing_image.copy()


def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):
try:
import MNN.expr as expr
import MNN.nn as nn
except ImportError as e:
raise ImportError("MNN模块未安装或导入错误。请确保已安装MNN库,使用命令 'pip install mnn' 安装。") from e
config = {}
config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理
config['backend'] = 0 # CPU
config['numThread'] = 4 # 线程数
im, width, length = read_modnet_image(input_image, ref_size=512)
rt = nn.create_runtime_manager((config,))
net = nn.load_module_from_file(checkpoint_path, ['input1'], ['output1'], runtime_manager=rt)
input_var = expr.convert(im, expr.NCHW)
output_var = net.forward(input_var)
matte = expr.convert(output_var, expr.NCHW)
matte = matte.read()#var转换为np
matte = (matte * 255).astype("uint8")
matte = np.squeeze(matte)
mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
b, g, r = cv2.split(np.uint8(input_image))

output_image = cv2.merge((b, g, r, mask))

return output_image

def extract_human_modnet_photographic_portrait_matting(ctx: Context):
"""
人像抠图
Expand All @@ -53,6 +84,11 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()

def extract_human_mnn_modnet(ctx: Context):
matting_image = get_mnn_modnet_matting(ctx.processing_image, WEIGHTS["mnn_hivision_modnet"])
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()


def hollow_out_fix(src: np.ndarray) -> np.ndarray:
"""
Expand Down
5 changes: 4 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from hivision.creator.human_matting import (
extract_human_modnet_photographic_portrait_matting,
extract_human,
extract_human_mnn_modnet,
)

parser = argparse.ArgumentParser(description="HivisionIDPhotos 证件照制作推理程序。")
Expand All @@ -24,7 +25,7 @@
"add_background",
"generate_layout_photos",
]
MATTING_MODEL = ["hivision_modnet", "modnet_photographic_portrait_matting"]
MATTING_MODEL = ["hivision_modnet", "modnet_photographic_portrait_matting", "mnn_hivision_modnet"]
RENDER = [0, 1, 2]

parser.add_argument(
Expand Down Expand Up @@ -64,6 +65,8 @@
creator.matting_handler = extract_human
elif args.matting_model == "modnet_photographic_portrait_matting":
creator.matting_handler = extract_human_modnet_photographic_portrait_matting
elif args.matting_model == "mnn_hivision_modnet":
creator.matting_handler = extract_human_mnn_modnet

root_dir = os.path.dirname(os.path.abspath(__file__))
input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ opencv-python>=4.8.1.78
onnxruntime>=1.15.0
numpy<=1.26.4
requests
mtcnn-runtime
mtcnn-runtime

0 comments on commit 9740752

Please sign in to comment.