def raise_error(msg):
    """
    Raise an InferenceServerException built from the given message.

    The exception is raised with ``from None``, so an exception that is
    being handled at the call site is not reported as its context.

    Parameters
    ----------
    msg
        Message forwarded to the exception as its ``msg`` keyword argument.

    Raises
    ------
    InferenceServerException
        Always; this function never returns normally.
    """
    # NOTE(review): InferenceServerException is not among the imports visible
    # in this hunk -- confirm it is in scope in the full module.
    error = InferenceServerException(msg=msg)
    raise error from None
+ for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"): + # If directly passing bytes to BYTES type, + # don't convert it to str as Python will encode the + # bytes which may distort the meaning + if input_tensor.dtype == np.object_: + if type(obj.item()) == bytes: + s = obj.item() + else: + s = str(obj.item()).encode("utf-8") + else: + s = obj.item() + flattened_ls.append(struct.pack("