Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a tensorrt backend #33

Draft
wants to merge 22 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vcap/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
"scikit-learn==0.22.2",
"numpy>=1.16,<2",
"tensorflow-gpu==1.15.4",
"pycuda>=2019.1.1",
"tensorrt==7.2.3.4",
],
extras_require={
"tests": test_packages,
Expand Down
11 changes: 5 additions & 6 deletions vcap/vcap/modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def __init__(self, frame: np.ndarray):

def resize(self, resize_width: int, resize_height: int,
resize_type: ResizeType):

frame_width = self.frame.shape[1]
frame_height = self.frame.shape[0]

Expand Down Expand Up @@ -251,11 +250,11 @@ def resize(self, resize_width: int, resize_height: int,
# Account for scaling
scale_width = new_width / frame_width
scale_height = new_height / frame_height
self._operations.append(
(self._OperationType.SCALE, (scale_width, scale_height))
)

self.frame = cv2.resize(self.frame, (new_width, new_height))
if new_width != frame_width or new_height != frame_height:
self._operations.append(
(self._OperationType.SCALE, (scale_width, scale_height))
)
self.frame = cv2.resize(self.frame, (new_width, new_height))

return self

Expand Down
1 change: 1 addition & 0 deletions vcap_utils/vcap_utils/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .openface_encoder import OpenFaceEncoder
from .base_encoder import BaseEncoderBackend
from .backend_rpc_process import BackendRpcProcess
from .base_tensorrt import BaseTensorRTBackend
from .load_utils import parse_dataset_metadata_bytes, parse_tf_model_bytes
from .predictions import (
EncodingPrediction,
Expand Down
274 changes: 274 additions & 0 deletions vcap_utils/vcap_utils/backends/base_tensorrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
import numpy as np

import pycuda.driver as cuda
import tensorrt as trt
import cv2

from typing import Dict, List, Tuple, Optional, Any

from vcap import (
Resize,
DETECTION_NODE_TYPE,
OPTION_TYPE,
BaseStreamState,
BaseBackend,
rect_to_coords,
DetectionNode,
)


class HostDeviceMem(object):
def __init__(self, host_mem, device_mem):
self.host = host_mem
self.device = device_mem

def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

def __repr__(self):
return self.__str__()


class AllocatedBuffer:
def __init__(self, inputs_, outputs_, bindings_, stream_):
self.inputs = inputs_
self.outputs = outputs_
self.bindings = bindings_
self.stream = stream_


class BaseTensorRTBackend(BaseBackend):
def __init__(self, engine_bytes, width, height, device_id):
apockill marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
gpu_devide_id = int(device_id[4:])
cuda.init()
apockill marked this conversation as resolved.
Show resolved Hide resolved
dev = cuda.Device(gpu_devide_id)
self.ctx = dev.make_context()
apockill marked this conversation as resolved.
Show resolved Hide resolved
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
apockill marked this conversation as resolved.
Show resolved Hide resolved
self.trt_runtime = trt.Runtime(TRT_LOGGER)
# load the engine
self.trt_engine = self.trt_runtime.deserialize_cuda_engine(engine_bytes)
# create execution context
self.context = self.trt_engine.create_execution_context()
# create buffers for inference
self.buffers = {}
for batch_size in range(1, self.trt_engine.max_batch_size + 1):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm leaving a comment here to remind me:

We need to do some memory measurement to figure out if all of these buffers are necessary. I wonder if allocating a buffer for Batch-Size [1, 2, 5, 10] or other combinations might be better.

THings to test:

  1. How many tensorrt models can the NX hold?
  2. How much extra memory does this allocate (relating to Use an s3 bucket to store large files instead of Git LFS #1)
  3. What's the speed performance if we do [1, 2, 10] vs [1, 2, 5, 10], vs [1, 2, 3, 4, 5, 6,...10]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another question we'll have to figure out: Should this be configurable via the init?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do some tests to figure how much memory is needed for those buffers. Another thought is that if we don't get performance improvement with a larger batch size, we don't have to do that. Based on my tests, larger batch size will improve the inference time by 10% but lower the preprocessing performance, the overall performance is even a little lower than a small batch size.

inputs, outputs, bindings, stream = self.allocate_buffers(
batch_size=batch_size)
self.buffers[batch_size] = AllocatedBuffer(inputs, outputs, bindings,
stream)

self.engine_width = width
self.engine_height = height

# preallocate resources for post process
# todo: post process is only need for detectors
self._prepare_post_process()

def batch_predict(self, input_data_list: List[Any]) -> List[Any]:
apockill marked this conversation as resolved.
Show resolved Hide resolved
task_size = len(input_data_list)
curr_index = 0
while curr_index < task_size:
Copy link
Contributor

@apockill apockill Apr 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic may need to be revisited if we decide not to have buffers [0->10], and instead have combinations of [1, 2, 5, 10], for example

if curr_index + self.trt_engine.max_batch_size <= task_size:
end_index = curr_index + self.trt_engine.max_batch_size
else:
end_index = task_size
batch = input_data_list[curr_index:end_index]
curr_index = end_index
for result in self._process_batch(batch):
yield result

def _process_batch(self, input_data: List[np.array]) -> List[List[float]]:
batch_size = len(input_data)
prepared_buffer = self.buffers[batch_size]
inputs = prepared_buffer.inputs
outputs = prepared_buffer.outputs
bindings = prepared_buffer.bindings
stream = prepared_buffer.stream
# todo: get dtype from engine
inputs[0].host = np.ascontiguousarray(input_data, dtype=np.float32)

detections = self.do_inference(
bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size=batch_size
)
return detections
apockill marked this conversation as resolved.
Show resolved Hide resolved

def process_frame(self, frame: np.ndarray, detection_node: DETECTION_NODE_TYPE,
apockill marked this conversation as resolved.
Show resolved Hide resolved
options: Dict[str, OPTION_TYPE],
state: BaseStreamState) -> DETECTION_NODE_TYPE:
pass

def prepare_inputs(self, frame: np.ndarray, transpose: bool, normalize: bool,
mean_subtraction: Optional[Tuple] = None) -> \
Tuple[np.array, Resize]:
resize = Resize(frame).resize(self.engine_width, self.engine_height,
Resize.ResizeType.EXACT)
if transpose:
resize.frame = np.transpose(resize.frame, (2, 0, 1))
if normalize:
resize.frame = (1.0 / 255.0) * resize.frame
if mean_subtraction is not None:
if len(mean_subtraction) != 3:
raise RuntimeError("Invalid mean subtraction")
resize.frame = resize.frame.astype("float64")
resize.frame[..., 0] -= mean_subtraction[0]
resize.frame[..., 1] -= mean_subtraction[1]
resize.frame[..., 2] -= mean_subtraction[2]
return resize.frame, resize

def allocate_buffers(self, batch_size: int = 1) -> \
Tuple[List[HostDeviceMem], List[HostDeviceMem], List[int], cuda.Stream]:
"""Allocates host and device buffer for TRT engine inference.
Args:
batch_size: batch size for the input/output memory
Returns:
inputs [HostDeviceMem]: engine input memory
outputs [HostDeviceMem]: engine output memory
bindings [int]: buffer to device bindings
stream (cuda.Stream): cuda stream for engine inference synchronization
"""
inputs = []
outputs = []
bindings = []
stream = cuda.Stream()
for binding in self.trt_engine:
size = trt.volume(self.trt_engine.get_binding_shape(binding)) * batch_size
dtype = trt.nptype(self.trt_engine.get_binding_dtype(binding))
# Allocate host and device buffers
host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
# Append the device buffer to device bindings.
bindings.append(int(device_mem))
# Append to the appropriate list.
if self.trt_engine.binding_is_input(binding):
inputs.append(HostDeviceMem(host_mem, device_mem))
else:
outputs.append(HostDeviceMem(host_mem, device_mem))
return inputs, outputs, bindings, stream

def do_inference(self, bindings: List[int], inputs: List[HostDeviceMem], outputs: List[HostDeviceMem],
Copy link
Contributor

@apockill apockill Apr 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def do_inference(self, bindings: List[int], inputs: List[HostDeviceMem], outputs: List[HostDeviceMem],
def _do_inference(self, bindings: List[int],
inputs: List[HostDeviceMem],
outputs: List[HostDeviceMem],
stream: cuda.Stream,
batch_size: int = 1) -> List[List[float]]:

stream: cuda.Stream, batch_size: int = 1) -> List[List[float]]:
# Transfer input data to the GPU.
self.ctx.push()
[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
# Run inference.
# todo: use async or sync api?
# According to https://docs.nvidia.com/deeplearning/tensorrt/best-practices/index.html#optimize-python
# the performance should be almost identical
self.context.execute(
batch_size=batch_size, bindings=bindings
)
# Transfer predictions back from the GPU.
[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
# Synchronize the stream
stream.synchronize()
# Return only the host outputs.
batch_outputs = []
for out in outputs:
entire_out_array = np.array(out.host)
out_array_by_batch = np.split(entire_out_array, batch_size)
out_lists = [out_array.tolist() for out_array in out_array_by_batch]
batch_outputs.append(out_lists)
final_outputs = []
for i in range(len(batch_outputs[0])):
final_output = []
for batch_output in batch_outputs:
final_output.append(batch_output[i])
final_outputs.append(final_output)
apockill marked this conversation as resolved.
Show resolved Hide resolved
self.ctx.pop()
return final_outputs

def _prepare_post_process(self):
Copy link
Contributor

@apockill apockill Apr 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm starting to think that are too many constants and GridNet specific functions here, and it might be easier to make a separate class specifically for parsing GridNet bounding boxes.

For now, let's clean up the rest of the code first, then discuss how that would work.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These constants are only necessary for detectors, maybe we need another parameter like is_detector in the constructor to indicate if this capsule a detector or classifier?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we can check if these constants exist before we call the post process function

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but I'm thinking that this is super duper specific to GridNet detectors particularly. Maybe we can just offer a function that for parsing GridNet detector outputs, and name it as such.

class GridNetParser:
   def __init__(parameters):
     ...
   def parse_detection_results(prediction):
     ...

class BaseTensorRTBackend:
   ...

The benefit would be to separate all of these GridNet specific parameters out of the BaseTensorRTBackend 🤔

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea, we should have separate parsers for different architectures.

self.stride = 16
apockill marked this conversation as resolved.
Show resolved Hide resolved
self.box_norm = 35.0
self.grid_h = int(self.engine_height / self.stride)
self.grid_w = int(self.engine_width / self.stride)
self.grid_size = self.grid_h * self.grid_w

self.grid_centers_w = []
self.grid_centers_h = []

for i in range(self.grid_h):
value = (i * self.stride + 0.5) / self.box_norm
self.grid_centers_h.append(value)

for i in range(self.grid_w):
value = (i * self.stride + 0.5) / self.box_norm
self.grid_centers_w.append(value)
apockill marked this conversation as resolved.
Show resolved Hide resolved

def _apply_box_norm(self, o1: float, o2: float, o3: float, o4: float, x: int, y: int) -> \
Tuple[float, float, float, float]:
"""
Applies the GridNet box normalization
Args:
o1 (float): first argument of the result
o2 (float): second argument of the result
o3 (float): third argument of the result
o4 (float): fourth argument of the result
x: row index on the grid
y: column index on the grid

Returns:
float: rescaled first argument
float: rescaled second argument
float: rescaled third argument
float: rescaled fourth argument
"""
o1 = (o1 - self.grid_centers_w[x]) * -self.box_norm
o2 = (o2 - self.grid_centers_h[y]) * -self.box_norm
o3 = (o3 + self.grid_centers_w[x]) * self.box_norm
o4 = (o4 + self.grid_centers_h[y]) * self.box_norm
return o1, o2, o3, o4
BestDriverCN marked this conversation as resolved.
Show resolved Hide resolved

def parse_detection_results(
self, results: List[List[float]],
resize: Resize,
label_map: Dict[int, str],
min_confidence: float = 0.0,
) -> List[DetectionNode]:
bbs = []
class_ids = []
scores = []
for c in label_map.keys():

x1_idx = c * 4 * self.grid_size
y1_idx = x1_idx + self.grid_size
x2_idx = y1_idx + self.grid_size
y2_idx = x2_idx + self.grid_size

boxes = results[0]
for h in range(self.grid_h):
for w in range(self.grid_w):
i = w + h * self.grid_w
score = results[1][c * self.grid_size + i]
if score >= min_confidence:
o1 = boxes[x1_idx + w + h * self.grid_w]
o2 = boxes[y1_idx + w + h * self.grid_w]
o3 = boxes[x2_idx + w + h * self.grid_w]
o4 = boxes[y2_idx + w + h * self.grid_w]
o1, o2, o3, o4 = self._apply_box_norm(o1, o2, o3, o4, w, h)
xmin = int(o1)
ymin = int(o2)
xmax = int(o3)
ymax = int(o4)
bbs.append([xmin, ymin, xmax - xmin, ymax - ymin])
class_ids.append(c)
scores.append(float(score))
indexes = cv2.dnn.NMSBoxes(bbs, scores, min_confidence, 0.5)
detections = []
for idx in indexes:
idx = int(idx)
xmin, ymin, w, h = bbs[idx]
class_id = class_ids[idx]
class_name = label_map[class_id]
detections.append(
DetectionNode(
name=class_name,
coords=rect_to_coords(
[xmin, ymin, (xmin + w), (ymin + h)]
),
extra_data={"detection_confidence": scores[idx]},
)
)
resize.scale_and_offset_detection_nodes(detections)
return detections
apockill marked this conversation as resolved.
Show resolved Hide resolved