Skip to content

Commit

Permalink
Merge pull request opendatalab#1284 from IMSUVEN/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
myhloli authored Dec 13, 2024
2 parents 8ccfff6 + be01039 commit 55e9bb9
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 26 deletions.
72 changes: 62 additions & 10 deletions magic_pdf/model/batch_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,80 @@ def __init__(self, model: CustomPEKModel, batch_ratio: int):
self.batch_ratio = batch_ratio

def __call__(self, images: list) -> list:
images_layout_res = []

layout_start_time = time.time()
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
images_layout_res = []
for image in images:
layout_res = self.model.layout_model(image, ignore_catids=[])
images_layout_res.append(layout_res)
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
images_layout_res = self.model.layout_model.batch_predict(
images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images = []
modified_images = []
for image_index, image in enumerate(images):
pil_img = Image.fromarray(image)
width, height = pil_img.size
if height > width:
input_res = {"poly": [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
)
layout_images.append(new_image)
modified_images.append([image_index, useful_list])
else:
layout_images.append(pil_img)

images_layout_res += self.model.layout_model.batch_predict(
layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
)

for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res["poly"])):
if i % 2 == 0:
res["poly"][i] = (
res["poly"][i] - useful_list[0] + useful_list[2]
)
else:
res["poly"][i] = (
res["poly"][i] - useful_list[1] + useful_list[3]
)
logger.info(
f"layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}"
)

if self.model.apply_formula:
# 公式检测
mfd_start_time = time.time()
images_mfd_res = self.model.mfd_model.batch_predict(
images, self.batch_ratio * MFD_BASE_BATCH_SIZE
)
logger.info(
f"mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}"
)

# 公式识别
mfr_start_time = time.time()
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
logger.info(
f"mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}"
)

# 清理显存
clean_vram(self.model.device, vram_threshold=8)

ocr_time = 0
ocr_count = 0
table_time = 0
table_count = 0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)):
layout_res = images_layout_res[index]
Expand Down Expand Up @@ -99,12 +143,8 @@ def __call__(self, images: list) -> list:
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list)

ocr_cost = round(time.time() - ocr_start, 2)
if self.model.apply_ocr:
logger.info(f"ocr time: {ocr_cost}")
else:
logger.info(f"det time: {ocr_cost}")
ocr_time += time.time() - ocr_start
ocr_count += len(ocr_res_list)

# 表格识别 table recognition
if self.model.apply_table:
Expand Down Expand Up @@ -146,7 +186,17 @@ def __call__(self, images: list) -> list:
logger.warning(
"table recognition processing fails, not get html return"
)
logger.info(f"table time: {round(time.time() - table_start, 2)}")
table_time += time.time() - table_start
table_count += len(table_res_list)

if self.model.apply_ocr:
logger.info(f"ocr time: {round(ocr_time, 2)}, image num: {ocr_count}")
else:
logger.info(f"det time: {round(ocr_time, 2)}, image num: {ocr_count}")
if self.model.apply_table:
logger.info(f"table time: {round(table_time, 2)}, image num: {table_count}")

return images_layout_res


def doc_batch_analyze(
Expand Down Expand Up @@ -223,6 +273,8 @@ def doc_batch_analyze(
model_json.append(page_dict)

# TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time()
clean_memory()
logger.info(f"clean memory time: {round(time.time() - clean_memory_start_time, 2)}")

return InferenceResult(model_json, dataset)
19 changes: 11 additions & 8 deletions magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ def predict(self, image):
def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = []
for index in range(0, len(images), batch_size):
doclayout_yolo_res = self.model.predict(
images[index : index + batch_size],
imgsz=1024,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
).cpu()
doclayout_yolo_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size],
imgsz=1024,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
)
]
for image_res in doclayout_yolo_res:
layout_res = []
for xyxy, conf, cla in zip(
Expand Down
19 changes: 11 additions & 8 deletions magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ def predict(self, image):
def batch_predict(self, images: list, batch_size: int) -> list:
images_mfd_res = []
for index in range(0, len(images), batch_size):
mfd_res = self.mfd_model.predict(
images[index : index + batch_size],
imgsz=1888,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
).cpu()
mfd_res = [
image_res.cpu()
for image_res in self.mfd_model.predict(
images[index : index + batch_size],
imgsz=1888,
conf=0.25,
iou=0.45,
verbose=True,
device=self.device,
)
]
for image_res in mfd_res:
images_mfd_res.append(image_res)
return images_mfd_res

0 comments on commit 55e9bb9

Please sign in to comment.