feat: making torchserve usable without torch #3351

Open · wants to merge 1 commit into base: master
8 changes: 8 additions & 0 deletions docs/api/ts.protocol.rst
@@ -12,6 +12,14 @@ ts.protocol.otf\_message\_handler module
:undoc-members:
:show-inheritance:

ts.protocol.otf\_torch\_message\_handler module
-----------------------------------------------

.. automodule:: ts.protocol.otf_torch_message_handler
:members:
:undoc-members:
:show-inheritance:


Module contents
---------------
8 changes: 8 additions & 0 deletions docs/api/ts.torch_handler.rst
@@ -13,6 +13,14 @@ Subpackages
Submodules
----------

ts.torch\_handler.abstract\_handler module
------------------------------------------

.. automodule:: ts.torch_handler.abstract_handler
:members:
:undoc-members:
:show-inheritance:

ts.torch\_handler.base\_handler module
--------------------------------------

12 changes: 10 additions & 2 deletions docs/custom_service.md
@@ -23,8 +23,16 @@ Following is applicable to all types of custom handlers
* **context** - Is the TorchServe [context](https://github.com/pytorch/serve/blob/master/ts/context.py). You can use the following information for customization:
  model_name, model_dir, manifest, batch_size, gpu, etc.

### Start with BaseHandler!
[BaseHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py) implements most of the functionality you need. You can derive a new class from it, as shown in the examples and default handlers. Most of the time, you'll only need to override `preprocess` or `postprocess`.
### Start with BaseHandler or AbstractHandler!
[BaseHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py)
implements most of the functionality you need for torch models.
You can derive a new class from it, as shown in the examples and default handlers.
Most of the time, you'll only need to override `preprocess` or `postprocess`.

[AbstractHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/abstract_handler.py)
implements most of the handling functionality without being torch-specific. You need to implement model loading
yourself, but it lets you use TorchServe and its features without installing torch, which makes it useful for
scikit-learn or TensorFlow models when you need the flexibility of executing Python during serving.
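
For example, a torch-free scikit-learn handler might look like the sketch below. This assumes `AbstractHandler` exposes the same `initialize`/`preprocess`/`inference`/`postprocess` hooks as `BaseHandler`; check `ts/torch_handler/abstract_handler.py` for the actual interface, and treat the joblib model file name as a placeholder.

```python
import json
import os

import joblib  # model loading without torch

from ts.torch_handler.abstract_handler import AbstractHandler


class SklearnHandler(AbstractHandler):
    """Hypothetical handler serving a scikit-learn model without torch installed."""

    def initialize(self, context):
        # Model loading is the handler's responsibility when torch is not available.
        model_dir = context.system_properties.get("model_dir")
        self.model = joblib.load(os.path.join(model_dir, "model.joblib"))
        self.initialized = True

    def preprocess(self, data):
        # Each batch entry carries its payload under "body" or "data".
        return [json.loads(row.get("body") or row.get("data")) for row in data]

    def inference(self, rows):
        return self.model.predict(rows).tolist()

    def postprocess(self, predictions):
        # TorchServe expects one response entry per request in the batch.
        return predictions
```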

### Custom handler with `module` level entry point

2 changes: 1 addition & 1 deletion docs/default_handlers.md
@@ -48,7 +48,7 @@ For a more comprehensive list of available handlers make sure to check out the [
- [object_detector](https://github.com/pytorch/serve/tree/master/examples/object_detector/index_to_name.json)

### Contributing
We welcome new contributed handlers, if your usecase isn't covered by one of the existing default handlers please follow the below steps to contribute it
We welcome new contributed handlers. If your use case isn't covered by one of the existing default handlers, please follow the steps below to contribute it
1. Write a new class derived from [BaseHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py). Add it as a separate file in `ts/torch_handler/` (see the sketch after this list)
2. Update `model-archiver/model_packaging.py` to add in your class name
3. Run and update the unit tests in [unit_tests](https://github.com/pytorch/serve/tree/master/ts/torch_handler/unit_tests). As always, make sure to run [torchserve_sanity.py](https://github.com/pytorch/serve/tree/master/torchserve_sanity.py) before submitting.
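
For step 1, a hedged skeleton of such a handler (placed at a hypothetical `ts/torch_handler/my_handler.py`) could look like this; the preprocessing and output mapping are placeholders and will differ per model family:

```python
import json

import torch

from ts.torch_handler.base_handler import BaseHandler


class MyHandler(BaseHandler):
    """Hypothetical default handler contributed for a new model family."""

    def preprocess(self, data):
        # Convert each request payload into a tensor and batch along dim 0.
        rows = [json.loads(row.get("body") or row.get("data")) for row in data]
        return torch.stack([torch.as_tensor(row, dtype=torch.float32) for row in rows])

    def postprocess(self, output):
        # Return one entry per request in the batch.
        return output.tolist()
```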
2 changes: 1 addition & 1 deletion examples/LLM/llama/chat_app/docker/llama_cpp_handler.py
@@ -5,7 +5,7 @@
import torch
from llama_cpp import Llama

from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.protocol.otf_torch_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion ts/handler_utils/utils.py
@@ -2,7 +2,7 @@
import os

from ts.context import Context
from ts.protocol.otf_message_handler import create_predict_response
from ts.protocol.otf_torch_message_handler import create_predict_response


def import_class(class_name: str, module_prefix=None):
3 changes: 2 additions & 1 deletion ts/model_service_worker.py
@@ -1,5 +1,6 @@
"""
ModelServiceWorker is the worker that is started by the MMS front-end.
TorchModelServiceWorker is the worker that is
started by the MMS front-end.
Communication message format: binary encoding
"""

120 changes: 1 addition & 119 deletions ts/protocol/otf_message_handler.py
@@ -2,18 +2,13 @@
OTF Codec
"""

import io
import json
import logging
import os
import struct
import sys
import time
from builtins import bytearray, bytes

import torch

from ts.utils.util import deprecated
from builtins import bytearray

bool_size = 1
int_size = 4
@@ -53,108 +48,6 @@ def encode_response_headers(resp_hdr_map):
return msg


def create_predict_response(
ret, req_id_map, message, code, context=None, ts_stream_next=False
):
"""
Create inference response.

:param context:
:param ret:
:param req_id_map:
:param message:
:param code:
:return:
"""
if str(os.getenv("LOCAL_RANK", 0)) != "0":
return None

msg = bytearray()
msg += struct.pack("!i", code)

buf = message.encode("utf-8")
msg += struct.pack("!i", len(buf))
msg += buf

for idx in req_id_map:
req_id = req_id_map.get(idx).encode("utf-8")
msg += struct.pack("!i", len(req_id))
msg += req_id

if context is None:
# Encoding Content-Type
msg += struct.pack("!i", 0) # content_type

# Encoding the per prediction HTTP response code
# status code and reason phrase set to none
msg += struct.pack("!i", code)
msg += struct.pack("!i", 0) # No code phrase is returned
# Response headers none
msg += struct.pack("!i", 0)
else:
if ts_stream_next is True:
context.set_response_header(idx, "ts_stream_next", "true")
elif context.stopping_criteria:
is_stop = context.stopping_criteria[idx](ret[idx])
if is_stop is not None:
ts_stream_next = "false" if is_stop else "true"
context.set_response_header(idx, "ts_stream_next", ts_stream_next)
elif "true" == context.get_response_headers(idx).get("ts_stream_next"):
context.set_response_header(idx, "ts_stream_next", "false")

content_type = context.get_response_content_type(idx)
if content_type is None or len(content_type) == 0:
msg += struct.pack("!i", 0) # content_type
else:
msg += struct.pack("!i", len(content_type))
msg += content_type.encode("utf-8")

sc, phrase = context.get_response_status(idx)
http_code = sc if sc is not None else 200
http_phrase = phrase if phrase is not None else ""

msg += struct.pack("!i", http_code)
msg += struct.pack("!i", len(http_phrase))
msg += http_phrase.encode("utf-8")
# Response headers
msg += encode_response_headers(context.get_response_headers(idx))

if ret is None:
buf = b"error"
msg += struct.pack("!i", len(buf))
msg += buf
else:
val = ret[idx]
# NOTE: Process bytes/bytearray case before processing the string case.
if isinstance(val, (bytes, bytearray)):
msg += struct.pack("!i", len(val))
msg += val
elif isinstance(val, str):
buf = val.encode("utf-8")
msg += struct.pack("!i", len(buf))
msg += buf
elif isinstance(val, torch.Tensor):
buff = io.BytesIO()
torch.save(val, buff)
buff.seek(0)
val_bytes = buff.read()
msg += struct.pack("!i", len(val_bytes))
msg += val_bytes
else:
try:
json_value = json.dumps(val, indent=2).encode("utf-8")
msg += struct.pack("!i", len(json_value))
msg += json_value
except TypeError:
logging.warning("Unable to serialize model output.", exc_info=True)
return create_predict_response(
None, req_id_map, "Unsupported model output data type.", 503
)

msg += struct.pack("!i", -1) # End of list
return msg


def create_load_model_response(code, message):
"""
Create load model response.
@@ -359,14 +252,3 @@ def _retrieve_input_data(conn):
model_input["value"] = value

return model_input


@deprecated(
version=1.0,
replacement="ts.handler_utils.utils.send_intermediate_predict_response",
)
def send_intermediate_predict_response(ret, req_id_map, message, code, context=None):
if str(os.getenv("LOCAL_RANK", 0)) != "0":
return None
msg = create_predict_response(ret, req_id_map, message, code, context, True)
context.cl_socket.sendall(msg)
128 changes: 128 additions & 0 deletions ts/protocol/otf_torch_message_handler.py
@@ -0,0 +1,128 @@
"""
OTF Codec for functionality requiring importing torch
"""

import io
import json
import logging
import os
import struct
from builtins import bytearray, bytes

import torch

from ts.protocol.otf_message_handler import encode_response_headers
from ts.utils.util import deprecated


def create_predict_response(
ret, req_id_map, message, code, context=None, ts_stream_next=False
):
"""
Create inference response.

:param context:
:param ret:
:param req_id_map:
:param message:
:param code:
:return:
"""
if str(os.getenv("LOCAL_RANK", 0)) != "0":
return None

msg = bytearray()
msg += struct.pack("!i", code)

buf = message.encode("utf-8")
msg += struct.pack("!i", len(buf))
msg += buf

for idx in req_id_map:
req_id = req_id_map.get(idx).encode("utf-8")
msg += struct.pack("!i", len(req_id))
msg += req_id

if context is None:
# Encoding Content-Type
msg += struct.pack("!i", 0) # content_type

# Encoding the per prediction HTTP response code
# status code and reason phrase set to none
msg += struct.pack("!i", code)
msg += struct.pack("!i", 0) # No code phrase is returned
# Response headers none
msg += struct.pack("!i", 0)
else:
if ts_stream_next is True:
context.set_response_header(idx, "ts_stream_next", "true")
elif context.stopping_criteria:
is_stop = context.stopping_criteria[idx](ret[idx])
if is_stop is not None:
ts_stream_next = "false" if is_stop else "true"
context.set_response_header(idx, "ts_stream_next", ts_stream_next)
elif "true" == context.get_response_headers(idx).get("ts_stream_next"):
context.set_response_header(idx, "ts_stream_next", "false")

content_type = context.get_response_content_type(idx)
if content_type is None or len(content_type) == 0:
msg += struct.pack("!i", 0) # content_type
else:
msg += struct.pack("!i", len(content_type))
msg += content_type.encode("utf-8")

sc, phrase = context.get_response_status(idx)
http_code = sc if sc is not None else 200
http_phrase = phrase if phrase is not None else ""

msg += struct.pack("!i", http_code)
msg += struct.pack("!i", len(http_phrase))
msg += http_phrase.encode("utf-8")
# Response headers
msg += encode_response_headers(context.get_response_headers(idx))

if ret is None:
buf = b"error"
msg += struct.pack("!i", len(buf))
msg += buf
else:
val = ret[idx]
# NOTE: Process bytes/bytearray case before processing the string case.
if isinstance(val, (bytes, bytearray)):
msg += struct.pack("!i", len(val))
msg += val
elif isinstance(val, str):
buf = val.encode("utf-8")
msg += struct.pack("!i", len(buf))
msg += buf
elif isinstance(val, torch.Tensor):
buff = io.BytesIO()
torch.save(val, buff)
buff.seek(0)
val_bytes = buff.read()
msg += struct.pack("!i", len(val_bytes))
msg += val_bytes
else:
try:
json_value = json.dumps(val, indent=2).encode("utf-8")
msg += struct.pack("!i", len(json_value))
msg += json_value
except TypeError:
logging.warning("Unable to serialize model output.", exc_info=True)
return create_predict_response(
None, req_id_map, "Unsupported model output data type.", 503
)

msg += struct.pack("!i", -1) # End of list
return msg


@deprecated(
version=1.0,
replacement="ts.handler_utils.utils.send_intermediate_predict_response",
)
def send_intermediate_predict_response(ret, req_id_map, message, code, context=None):
if str(os.getenv("LOCAL_RANK", 0)) != "0":
return None
msg = create_predict_response(ret, req_id_map, message, code, context, True)
context.cl_socket.sendall(msg)
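
As a usage note, a streaming handler calls `send_intermediate_predict_response` from its inference loop and must now import it from `ts.protocol.otf_torch_message_handler` (or from `ts.handler_utils.utils`, the replacement named in the deprecation above). A hedged sketch, assuming a generator-style model API and the argument order shown above:

```python
from ts.protocol.otf_torch_message_handler import send_intermediate_predict_response


def stream_tokens(model, prompt, context):
    """Push partial results to the client, then return the final text."""
    text = ""
    for token in model.generate(prompt):  # hypothetical token generator
        text += token
        # Each intermediate frame is built by create_predict_response with
        # ts_stream_next=True, so the frontend keeps the connection open.
        send_intermediate_predict_response(
            [text], context.request_ids, "Intermediate prediction", 200, context
        )
    return [text]  # the final response goes back through the normal handle() path
```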
2 changes: 1 addition & 1 deletion ts/service.py
@@ -8,7 +8,7 @@

import ts
from ts.context import Context, RequestProcessor
from ts.protocol.otf_message_handler import create_predict_response
from ts.protocol.otf_torch_message_handler import create_predict_response
from ts.utils.util import PredictionException, get_yaml_config

PREDICTION_METRIC = "PredictionTime"