# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import os.path
import sys
import tempfile
import time
from datetime import datetime
import gradio as gr
import numpy as np
import paddle
from PIL import Image
# Select which GPU device to use
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Model configuration
model_path = "PaddleMIX/PPDocBee-2B-1129"
dtype = "bfloat16" # V100请改成float16
# Global variables
model = None
processor = None
min_pixels = 256 * 28 * 28  # minimum number of image pixels
max_pixels = 48 * 48 * 28 * 28  # maximum number of image pixels
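# Rough intuition behind the pixel budget (Qwen2-VL works on 28x28 vision patches):
#   min_pixels = 256 * 28 * 28     = 200,704 px   ~ a 448 x 448 image
#   max_pixels = 48 * 48 * 28 * 28 = 1,806,336 px ~ a 1344 x 1344 image
# The processor rescales images that fall outside this budget.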
SERVER_NAME = "localhost"
SERVER_PORT = 8080
def check_paddlemix():
    """Verify that the Qwen2VL model code from PaddleMIX is importable."""
    try:
        from paddlemix.models.qwen2_vl.modeling_qwen2_vl import (
            Qwen2VLForConditionalGeneration,
        )
        print("Required Qwen2VL model is available")
    except ImportError:
        print("Required Qwen2VL model is not available; please install PaddleMIX first")
        sys.exit(1)
# Verify the required model code is importable before continuing
check_paddlemix()
from paddlemix.models.qwen2_vl import MIXQwen2Tokenizer
from paddlemix.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
from paddlemix.processors.qwen2_vl_processing import (
Qwen2VLImageProcessor,
Qwen2VLProcessor,
process_vision_info,
)
# Demo examples: (question, image path) pairs
EXAMPLES = [
[
"维修保养、其他注意事项的注意点中,电池需为什么型号的?",
"paddlemix/demo_images/shuomingshu_20.png",
],
[
"产品期限是多久?",
"paddlemix/demo_images/shuomingshu_39.png",
],
]
class ImageCache:
"""图片缓存管理类"""
def __init__(self):
"""初始化图片缓存"""
self.temp_dir = tempfile.mkdtemp()
self.current_image = None
        self.is_example = False  # whether the current image is a bundled example
print(f"Created temporary directory for image cache: {self.temp_dir}")
def cleanup_previous(self):
"""清理之前的缓存图片"""
if self.current_image and os.path.exists(self.current_image) and not self.is_example:
try:
os.unlink(self.current_image)
print(f"Cleaned up previous image: {self.current_image}")
except Exception as e:
print(f"Error cleaning up previous image: {e}")
def cache_image(self, image_path, is_example=False):
"""
缓存图片并返回缓存路径
Args:
image_path: 图片文件路径
is_example: 是否为示例图片
Returns:
缓存后的图片路径
"""
if not image_path:
return None
try:
            # If this example image is already in use, return it as-is
if is_example and self.current_image == image_path and self.is_example:
return self.current_image
            # Build a collision-safe file name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_hash = hashlib.md5(str(time.time()).encode()).hexdigest()[:8]
_, ext = os.path.splitext(image_path)
if not ext:
ext = ".jpg" # 默认扩展名
new_filename = f"image_{timestamp}_{file_hash}{ext}"
            # Uploads get a new path in the temp dir; example images are used in place
new_path = os.path.join(self.temp_dir, new_filename) if not is_example else image_path
if not is_example:
                # Process the uploaded image file
with Image.open(image_path) as img:
                    # Convert to RGB if needed
if img.mode != "RGB":
img = img.convert("RGB")
img.save(new_path)
            # Clean up the previous image before updating the current one
self.cleanup_previous()
self.current_image = new_path
self.is_example = is_example
return new_path
except Exception as e:
print(f"Error caching image: {e}")
return image_path
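# Illustrative usage of ImageCache (the upload path below is hypothetical):
#   cache = ImageCache()
#   cache.cache_image("/tmp/upload.png")                  # copied into the temp dir as RGB
#   cache.cache_image(EXAMPLES[0][1], is_example=True)    # example images are used in place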
# Global image cache manager
image_cache = ImageCache()
def load_model():
"""加载模型并进行内存优化"""
global model, processor
if model is None:
        # Load the model and processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
dtype=dtype,
)
image_processor = Qwen2VLImageProcessor()
tokenizer = MIXQwen2Tokenizer.from_pretrained(model_path)
processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels)
        # Switch to evaluation mode
model.eval()
del tokenizer
return model, processor
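# Note: load_model() acts as a lazy singleton; once the global `model` is set,
# repeated calls return the cached objects instead of reloading the weights.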
def clear_cache():
"""清理GPU缓存"""
if paddle.device.cuda.memory_allocated() > 0:
paddle.device.cuda.empty_cache()
import gc
gc.collect()
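# Note: paddle.device.cuda.empty_cache() only returns idle cached allocator
# blocks to the device; live tensors are untouched, so calling it between
# requests is safe.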
def multimodal_understanding(image, question, seed=42, top_p=0.95, temperature=0.1):
"""
多模态理解主函数
Args:
image: 输入图片
question: 问题文本
seed: 随机种子
top_p: 采样参数
temperature: 温度参数
Yields:
处理状态和结果
"""
    # Input validation
if not image:
yield "⚠️ 请上传图片后再开始对话。"
return
if not question or question.strip() == "":
yield "⚠️ 请输入您的问题后再开始对话。"
return
try:
start_time = time.time()
yield "🔄 正在处理您的请求,请稍候..."
        # Busy-server guard; note that as written it can never trigger, since it
        # runs immediately after start_time is set
        if time.time() - start_time > 200:
            yield "⏳ 系统当前用户繁多,请等待10分钟后再次尝试。感谢您的理解!"
            return
clear_cache()
        # Set random seeds for reproducibility
paddle.seed(seed)
np.random.seed(seed)
        # Cache the image (example images are reused in place)
is_example = any(image == example[1] for example in EXAMPLES)
cached_image = image_cache.cache_image(image, is_example=is_example)
        if not cached_image:
            # This is a generator, so the message must be yielded, not returned
            yield "图片处理失败,请检查图片格式是否正确。"
            return
        # Build the prompt text
prompts = question + "\n请用图片中完整出现的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,并保持格式、单位、符号和标点都与图片中的文字内容完全一致。"
        # Build the chat messages
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": cached_image,
},
{"type": "text", "text": prompts},
],
}
]
yield "模型正在分析图片内容..."
        # Preprocess the vision inputs
image_inputs, video_inputs = process_vision_info(messages)
image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>"
text = f"<|im_start|>system\n你是一个非常棒的多模态理解的AI助手。<|im_end|>\n<|im_start|>user\n{image_pad_token}{prompts}<|im_end|>\n<|im_start|>assistant\n"
        # Generate the answer
with paddle.no_grad():
inputs = processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pd"
)
yield "正在生成回答..."
generated_ids = model.generate(
**inputs,
max_new_tokens=1024,
top_p=top_p,
temperature=temperature,
num_beams=1,
do_sample=True,
use_cache=True,
)
output_text = processor.batch_decode(
generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
        # Free memory
del inputs, generated_ids
clear_cache()
yield output_text
    except Exception as e:
        error_message = f"处理过程中出现错误: {str(e)}\n请重试或在评论区留下你的问题。"
        # This is a generator, so the error must be yielded; a plain `return value`
        # inside a generator would silently discard the message
        yield error_message
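# multimodal_understanding is a generator: Gradio streams each yielded status
# string into the response textbox until the final answer arrives, e.g.
#   for chunk in multimodal_understanding(EXAMPLES[0][1], EXAMPLES[0][0]):
#       print(chunk)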
def process_example(question, image):
"""处理示例图片的包装函数"""
cached_path = image_cache.cache_image(image, is_example=True)
return multimodal_understanding(cached_path, question)
def handle_image_upload(image):
"""处理图片上传"""
if image is None:
return None
try:
cached_path = image_cache.cache_image(image, is_example=False)
return cached_path
except Exception as e:
print(f"Error handling image upload: {e}")
return None
# Local smoke test (kept commented out):
# model, processor = load_model()
# for chunk in multimodal_understanding(EXAMPLES[1][1], EXAMPLES[1][0]):
#     print(chunk)
# Gradio UI configuration
with gr.Blocks() as demo:
gr.Markdown(
value="""
# 🤖 PP-DocBee(2B): Multimodal Document Understanding Demo
📚 原始模型来自 [PaddleMIX](https://github.com/PaddlePaddle/PaddleMIX) (🌟 一个基于飞桨PaddlePaddle框架构建的多模态大模型套件)
"""
)
with gr.Row():
image_input = gr.Image(type="filepath", label="📷 Upload Image or Input URL")
with gr.Column():
question_input = gr.Textbox(label="💭 Question", placeholder="Enter your question here...")
und_seed_input = gr.Number(label="🎲 Seed", precision=0, value=42)
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="📊 Top P")
temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="🌡️ Temperature")
image_input.upload(fn=handle_image_upload, inputs=[image_input], outputs=[image_input])
understanding_button = gr.Button("💬 Chat", variant="primary")
understanding_output = gr.Textbox(label="🤖 Response", interactive=False)
gr.Examples(
examples=EXAMPLES,
inputs=[question_input, image_input],
outputs=understanding_output,
fn=process_example,
cache_examples=True,
run_on_click=True,
)
    # Load the model at startup
clear_cache()
model, processor = load_model()
clear_cache()
understanding_button.click(
fn=multimodal_understanding,
inputs=[image_input, question_input, und_seed_input, top_p, temperature],
outputs=understanding_output,
api_name="chat",
)
if __name__ == "__main__":
    # Enable request queuing
    demo.queue()
    demo.launch(server_name=SERVER_NAME, server_port=SERVER_PORT, share=True, ssr_mode=False, max_threads=1)  # limit concurrent requests
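# The Chat button is exposed under api_name="chat", so the endpoint can also be
# called programmatically (a sketch, assuming the `gradio_client` package is
# installed; the image path below is hypothetical):
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:8080")
#   print(client.predict(handle_file("path/to/image.png"), "你的问题", 42, 0.95, 0.1, api_name="/chat"))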