Skip to content

Commit

Permalink
Merge pull request #20 from gibchikafa/fix_dependancy
Browse files Browse the repository at this point in the history
Remove tritonclient dependancy
  • Loading branch information
gibchikafa authored Jan 31, 2024
2 parents 2cd017e + 999c2af commit 790dc37
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 2 deletions.
127 changes: 126 additions & 1 deletion python/kserve/kserve/protocol/infer_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,142 @@

from typing import Optional, List, Dict

import struct
import numpy
import numpy as np
import pandas as pd
from tritonclient.utils import raise_error, serialize_byte_tensor

from ..constants.constants import GRPC_CONTENT_DATATYPE_MAPPINGS
from ..errors import InvalidInput
from ..protocol.grpc.grpc_predict_v2_pb2 import ModelInferRequest, InferTensorContents, ModelInferResponse
from ..utils.numpy_codec import to_np_dtype, from_np_dtype


def raise_error(msg):
"""
Raise error with the provided message
"""
raise InferenceServerException(msg=msg) from None


def serialize_byte_tensor(input_tensor):
"""
Serializes a bytes tensor into a flat numpy array of length prepended
bytes. The numpy array should use dtype of np.object. For np.bytes,
numpy will remove trailing zeros at the end of byte sequence and because
of this it should be avoided.
Parameters
----------
input_tensor : np.array
The bytes tensor to serialize.
Returns
-------
serialized_bytes_tensor : np.array
The 1-D numpy array of type uint8 containing the serialized bytes in row-major form.
Raises
------
InferenceServerException
If unable to serialize the given tensor.
"""

if input_tensor.size == 0:
return np.empty([0], dtype=np.object_)

# If the input is a tensor of string/bytes objects, then must flatten those into
# a 1-dimensional array containing the 4-byte byte size followed by the
# actual element bytes. All elements are concatenated together in row-major
# order.

if (input_tensor.dtype != np.object_) and (input_tensor.dtype.type != np.bytes_):
raise_error("cannot serialize bytes tensor: invalid datatype")

flattened_ls = []
# 'C' order is row-major.
for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"):
# If directly passing bytes to BYTES type,
# don't convert it to str as Python will encode the
# bytes which may distort the meaning
if input_tensor.dtype == np.object_:
if type(obj.item()) == bytes:
s = obj.item()
else:
s = str(obj.item()).encode("utf-8")
else:
s = obj.item()
flattened_ls.append(struct.pack("<I", len(s)))
flattened_ls.append(s)
flattened = b"".join(flattened_ls)
flattened_array = np.asarray(flattened, dtype=np.object_)
if not flattened_array.flags["C_CONTIGUOUS"]:
flattened_array = np.ascontiguousarray(flattened_array, dtype=np.object_)
return flattened_array


class InferenceServerException(Exception):
"""Exception indicating non-Success status.
Parameters
----------
msg : str
A brief description of error
status : str
The error code
debug_details : str
The additional details on the error
"""

def __init__(self, msg, status=None, debug_details=None):
self._msg = msg
self._status = status
self._debug_details = debug_details

def __str__(self):
msg = super().__str__() if self._msg is None else self._msg
if self._status is not None:
msg = "[" + self._status + "] " + msg
return msg

def message(self):
"""Get the exception message.
Returns
-------
str
The message associated with this exception, or None if no message.
"""
return self._msg

def status(self):
"""Get the status of the exception.
Returns
-------
str
Returns the status of the exception
"""
return self._status

def debug_details(self):
"""Get the detailed information about the exception
for debugging purposes
Returns
-------
str
Returns the exception details
"""
return self._debug_details


class InferInput:
_name: str
_shape: List[int]
Expand Down
1 change: 0 additions & 1 deletion python/kserve/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ prometheus-client = "^0.13.1"
orjson = "^3.8.0"
httpx = "^0.23.0"
timing-asgi = "^0.3.0"
tritonclient = "^2.18.0"
tabulate = "^0.9.0"
pandas = ">=1.3.5"

Expand Down

0 comments on commit 790dc37

Please sign in to comment.