You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I tried to use batch inference with TensorRT and modified some parts of the original code. I hope the code below helps anyone who wants to use batch inference with TensorRT. If I missed something, please correct me. Thank you!
########################################
## trt with batch inference ##
########################################
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image

# NOTE(review): the script below also uses cv2, matplotlib.pyplot (as `plt`)
# and a loaded inference callable bound to the name `Yolo` (presumably the
# TF-TRT converted model's serving function); none of these is defined in this
# snippet -- import/load them before running it.
from yolov3.utils import image_preprocess, postprocess_boxes, nms, draw_bbox, read_class_names
input_size = 416          # square network input resolution, in pixels
score_threshold = 0.3     # minimum score passed to postprocess_boxes
iou_threshold = 0.5       # IoU threshold passed to nms
batch_inference = 4       # number of images sent to the model in one call
image_dir = "... image dir (0.png, 1.png, 2.png, 3.png)..."


def load_image_batch(directory, batch_size, size):
    """Read ``0.png`` .. ``{batch_size - 1}.png`` from *directory* as one batch.

    Args:
        directory: folder containing the numbered PNG files.
        batch_size: number of images to read (files 0..batch_size-1).
        size: side length passed to ``image_preprocess``.

    Returns:
        ``(batch, originals)`` where ``batch`` is a float32 array of shape
        ``(batch_size, size, size, 3)`` holding the preprocessed images, and
        ``originals`` is the list of RGB images as read from disk (they are
        later passed to ``postprocess_boxes`` and ``draw_bbox``).

    Raises:
        FileNotFoundError: if an image is missing or unreadable;
            ``cv2.imread`` returns ``None`` in that case instead of raising.
    """
    batch = np.zeros((batch_size, size, size, 3), dtype=np.float32)
    originals = []
    for i in range(batch_size):
        image_path = os.path.join(directory, f"{i}.png")
        print(image_path)
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        originals.append(rgb)
        # Write the preprocessed image straight into slot i; adding a batch
        # axis first only worked because numpy drops the leading 1-dim on
        # assignment. Assignment into the float32 buffer performs the cast.
        batch[i] = image_preprocess(np.copy(rgb), [size, size])
    return batch, originals


batched_input, list_original_image = load_image_batch(image_dir, batch_inference, input_size)
batched_input = tf.constant(batched_input)
print('batched_input.shape', batched_input.shape)
# Time one batched forward pass. time.perf_counter is monotonic and
# high-resolution; time.time() follows the wall clock and can jump if the
# system clock is adjusted while measuring.
# NOTE(review): if Yolo is a TF-TRT converted function whose engines were not
# pre-built, this first call may include engine construction -- add a warm-up
# call before timing if steady-state latency is what you want to measure.
start_time = time.perf_counter()
result = Yolo(batched_input)
elapsed = time.perf_counter() - start_time
print(f"inference time for {batch_inference} images: {elapsed:.4f} s")
# Copy every model output to host memory once for the whole batch; the
# per-image loop previously called .numpy() on each output again for every
# image. Order follows result's dict iteration order, as before.
batch_outputs = [value.numpy() for value in result.values()]

for num, original_image in enumerate(list_original_image):
    # Take this image's slice of every output, flatten each to
    # (rows, last_dim) and stack them into a single prediction table.
    pred_bbox = tf.concat(
        [tf.reshape(output[num], (-1, output.shape[-1])) for output in batch_outputs],
        axis=0,
    )
    bboxes = postprocess_boxes(pred_bbox, original_image, input_size, score_threshold)
    bboxes = nms(bboxes, iou_threshold, method='nms')
    # Named `annotated` (not `image`) so the imported
    # tensorflow.keras.preprocessing `image` module is not shadowed.
    annotated = draw_bbox(
        original_image,
        bboxes,
        CLASSES="model_data/license_plate_names.txt",
        rectangle_colors=(255, 0, 0),
    )
    plt.imshow(annotated)
    plt.show()
    print('------')
The text was updated successfully, but these errors were encountered:
I tried to use batch inference with TensorRT and modified some parts of the original code. I hope the code below helps anyone who wants to use batch inference with TensorRT. If I missed something, please correct me. Thank you!
The text was updated successfully, but these errors were encountered: