From d3670580f1d41b91beeea485dc154aad52537b1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mat=C4=9Bj=20Ra=C4=8Dinsk=C3=BD?= Date: Mon, 21 Oct 2024 13:39:49 +0200 Subject: [PATCH] feat: making torchserve usable without torch --- docs/api/ts.protocol.rst | 8 + docs/api/ts.torch_handler.rst | 8 + docs/custom_service.md | 12 +- docs/default_handlers.md | 2 +- .../chat_app/docker/llama_cpp_handler.py | 2 +- ts/handler_utils/utils.py | 2 +- ts/model_service_worker.py | 3 +- ts/protocol/otf_message_handler.py | 120 +--------- ts/protocol/otf_torch_message_handler.py | 128 +++++++++++ ts/service.py | 2 +- .../unit_tests/test_otf_codec_protocol.py | 7 +- ts/torch_handler/abstract_handler.py | 211 ++++++++++++++++++ ts/torch_handler/base_handler.py | 140 ++---------- 13 files changed, 390 insertions(+), 255 deletions(-) create mode 100644 ts/protocol/otf_torch_message_handler.py create mode 100644 ts/torch_handler/abstract_handler.py diff --git a/docs/api/ts.protocol.rst b/docs/api/ts.protocol.rst index 3636a212a5..18d42b5b4a 100644 --- a/docs/api/ts.protocol.rst +++ b/docs/api/ts.protocol.rst @@ -12,6 +12,14 @@ ts.protocol.otf\_message\_handler module :undoc-members: :show-inheritance: +ts.protocol.otf\_torch\_message\_handler module +---------------------------------------- + +.. automodule:: ts.protocol.otf_torch_message_handler + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/docs/api/ts.torch_handler.rst b/docs/api/ts.torch_handler.rst index b709df8bf2..7b99a2af8b 100644 --- a/docs/api/ts.torch_handler.rst +++ b/docs/api/ts.torch_handler.rst @@ -13,6 +13,14 @@ Subpackages Submodules ---------- +ts.torch\_handler.abstract\_handler module +-------------------------------------- + +.. automodule:: ts.torch_handler.abstract_handler + :members: + :undoc-members: + :show-inheritance: + ts.torch\_handler.base\_handler module -------------------------------------- diff --git a/docs/custom_service.md b/docs/custom_service.md index 2a40625503..11df9f1fcd 100755 --- a/docs/custom_service.md +++ b/docs/custom_service.md @@ -23,8 +23,16 @@ Following is applicable to all types of custom handlers * **context** - Is the TorchServe [context](https://github.com/pytorch/serve/blob/master/ts/context.py). You can use following information for customization model_name, model_dir, manifest, batch_size, gpu etc. -### Start with BaseHandler! -[BaseHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py) implements most of the functionality you need. You can derive a new class from it, as shown in the examples and default handlers. Most of the time, you'll only need to override `preprocess` or `postprocess`. +### Start with BaseHandler or AbstractHandler! +[BaseHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py) +implements most of the functionality you need for torch models. +You can derive a new class from it, as shown in the examples and default handlers. +Most of the time, you'll only need to override `preprocess` or `postprocess`. + +[AbstractHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/abstract_handler.py) +implements most of the handling functionality without being torch-specific. You need to implement the model loading +yourself, but it lets you use TorchServe and its great features without needing to install torch, so it's useful for +scikit-learn or tensorflow models in case you need something flexible with executing python during the serving. 
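+
+Below is a minimal sketch of a torch-free handler built on `AbstractHandler`. It is a hypothetical example, not a handler shipped with TorchServe: the class name, the `model.pkl` filename, and the NumPy/JSON preprocessing are assumptions made for illustration. It shows which abstract methods a subclass has to provide:
+
+```python
+import json
+import os
+import pickle
+
+import numpy as np
+
+from ts.torch_handler.abstract_handler import AbstractHandler
+
+
+class SklearnHandler(AbstractHandler):
+    """Hypothetical handler serving a pickled scikit-learn model without torch."""
+
+    def initialize(self, context):
+        # Assumes a pickled model named model.pkl is packaged in the model archive.
+        model_dir = context.system_properties.get("model_dir")
+        with open(os.path.join(model_dir, "model.pkl"), "rb") as f:
+            self.model = pickle.load(f)
+
+    def as_tensor(self, data):
+        # "Tensor" here is whatever array type the model consumes; this sketch
+        # assumes each request body is a JSON list of numeric features.
+        rows = []
+        for row in data:
+            payload = row.get("data") or row.get("body")
+            if isinstance(payload, (bytes, bytearray)):
+                payload = json.loads(payload)
+            rows.append(payload)
+        return np.asarray(rows, dtype=np.float32)
+
+    def inference(self, data, *args, **kwargs):
+        return self.model.predict(data)
+
+    def get_insights(self, tensor_data, _, target=0):
+        raise NotImplementedError("Explanations are not supported by this handler")
+
+    def describe_handle(self):
+        return {"model": type(self.model).__name__}
+
+    def profiler_enabled(self):
+        return False
+
+    def infer_with_profiler(self, data, context):
+        raise NotImplementedError("Profiling requires torch")
+```
+
+`preprocess` and `postprocess` are inherited: the default `postprocess` calls `.tolist()` on the prediction, which NumPy arrays support, so only the methods above need to be implemented for this kind of model.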
### Custom handler with `module` level entry point diff --git a/docs/default_handlers.md b/docs/default_handlers.md index cc365c74cc..3fc9366902 100644 --- a/docs/default_handlers.md +++ b/docs/default_handlers.md @@ -48,7 +48,7 @@ For a more comprehensive list of available handlers make sure to check out the [ - [object_detector](https://github.com/pytorch/serve/tree/master/examples/object_detector/index_to_name.json) ### Contributing -We welcome new contributed handlers, if your usecase isn't covered by one of the existing default handlers please follow the below steps to contribute it +We welcome new contributed handlers, if your use-case isn't covered by one of the existing default handlers please follow the below steps to contribute it 1. Write a new class derived from [BaseHandler](https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py). Add it as a separate file in `ts/torch_handler/` 2. Update `model-archiver/model_packaging.py` to add in your classes name 3. Run and update the unit tests in [unit_tests](https://github.com/pytorch/serve/tree/master/ts/torch_handler/unit_tests). As always, make sure to run [torchserve_sanity.py](https://github.com/pytorch/serve/tree/master/torchserve_sanity.py) before submitting. diff --git a/examples/LLM/llama/chat_app/docker/llama_cpp_handler.py b/examples/LLM/llama/chat_app/docker/llama_cpp_handler.py index 09de9b503d..a4f4d292ea 100644 --- a/examples/LLM/llama/chat_app/docker/llama_cpp_handler.py +++ b/examples/LLM/llama/chat_app/docker/llama_cpp_handler.py @@ -5,7 +5,7 @@ import torch from llama_cpp import Llama -from ts.protocol.otf_message_handler import send_intermediate_predict_response +from ts.protocol.otf_torch_message_handler import send_intermediate_predict_response from ts.torch_handler.base_handler import BaseHandler logger = logging.getLogger(__name__) diff --git a/ts/handler_utils/utils.py b/ts/handler_utils/utils.py index 3361513c40..be989cd0f2 100644 --- a/ts/handler_utils/utils.py +++ b/ts/handler_utils/utils.py @@ -2,7 +2,7 @@ import os from ts.context import Context -from ts.protocol.otf_message_handler import create_predict_response +from ts.protocol.otf_torch_message_handler import create_predict_response def import_class(class_name: str, module_prefix=None): diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 819a473add..518f00f1d5 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -1,5 +1,6 @@ """ -ModelServiceWorker is the worker that is started by the MMS front-end. +class TorchModelServiceWorker(object): + is the worker that is started by the MMS front-end. Communication message format: binary encoding """ diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index 29de350e15..91da0f475c 100644 --- a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -2,18 +2,13 @@ OTF Codec """ -import io import json import logging import os import struct import sys import time -from builtins import bytearray, bytes - -import torch - -from ts.utils.util import deprecated +from builtins import bytearray bool_size = 1 int_size = 4 @@ -53,108 +48,6 @@ def encode_response_headers(resp_hdr_map): return msg -def create_predict_response( - ret, req_id_map, message, code, context=None, ts_stream_next=False -): - """ - Create inference response. 
- - :param context: - :param ret: - :param req_id_map: - :param message: - :param code: - :return: - """ - if str(os.getenv("LOCAL_RANK", 0)) != "0": - return None - - msg = bytearray() - msg += struct.pack("!i", code) - - buf = message.encode("utf-8") - msg += struct.pack("!i", len(buf)) - msg += buf - - for idx in req_id_map: - req_id = req_id_map.get(idx).encode("utf-8") - msg += struct.pack("!i", len(req_id)) - msg += req_id - - if context is None: - # Encoding Content-Type - msg += struct.pack("!i", 0) # content_type - - # Encoding the per prediction HTTP response code - # status code and reason phrase set to none - msg += struct.pack("!i", code) - msg += struct.pack("!i", 0) # No code phrase is returned - # Response headers none - msg += struct.pack("!i", 0) - else: - if ts_stream_next is True: - context.set_response_header(idx, "ts_stream_next", "true") - elif context.stopping_criteria: - is_stop = context.stopping_criteria[idx](ret[idx]) - if is_stop is not None: - ts_stream_next = "false" if is_stop else "true" - context.set_response_header(idx, "ts_stream_next", ts_stream_next) - elif "true" == context.get_response_headers(idx).get("ts_stream_next"): - context.set_response_header(idx, "ts_stream_next", "false") - - content_type = context.get_response_content_type(idx) - if content_type is None or len(content_type) == 0: - msg += struct.pack("!i", 0) # content_type - else: - msg += struct.pack("!i", len(content_type)) - msg += content_type.encode("utf-8") - - sc, phrase = context.get_response_status(idx) - http_code = sc if sc is not None else 200 - http_phrase = phrase if phrase is not None else "" - - msg += struct.pack("!i", http_code) - msg += struct.pack("!i", len(http_phrase)) - msg += http_phrase.encode("utf-8") - # Response headers - msg += encode_response_headers(context.get_response_headers(idx)) - - if ret is None: - buf = b"error" - msg += struct.pack("!i", len(buf)) - msg += buf - else: - val = ret[idx] - # NOTE: Process bytes/bytearray case before processing the string case. - if isinstance(val, (bytes, bytearray)): - msg += struct.pack("!i", len(val)) - msg += val - elif isinstance(val, str): - buf = val.encode("utf-8") - msg += struct.pack("!i", len(buf)) - msg += buf - elif isinstance(val, torch.Tensor): - buff = io.BytesIO() - torch.save(val, buff) - buff.seek(0) - val_bytes = buff.read() - msg += struct.pack("!i", len(val_bytes)) - msg += val_bytes - else: - try: - json_value = json.dumps(val, indent=2).encode("utf-8") - msg += struct.pack("!i", len(json_value)) - msg += json_value - except TypeError: - logging.warning("Unable to serialize model output.", exc_info=True) - return create_predict_response( - None, req_id_map, "Unsupported model output data type.", 503 - ) - - msg += struct.pack("!i", -1) # End of list - return msg - - def create_load_model_response(code, message): """ Create load model response. 
@@ -359,14 +252,3 @@ def _retrieve_input_data(conn): model_input["value"] = value return model_input - - -@deprecated( - version=1.0, - replacement="ts.handler_utils.utils.send_intermediate_predict_response", -) -def send_intermediate_predict_response(ret, req_id_map, message, code, context=None): - if str(os.getenv("LOCAL_RANK", 0)) != "0": - return None - msg = create_predict_response(ret, req_id_map, message, code, context, True) - context.cl_socket.sendall(msg) diff --git a/ts/protocol/otf_torch_message_handler.py b/ts/protocol/otf_torch_message_handler.py new file mode 100644 index 0000000000..125f0e254b --- /dev/null +++ b/ts/protocol/otf_torch_message_handler.py @@ -0,0 +1,128 @@ +""" +OTF Codec for functionality requiring importing torch +""" + +import io +import json +import logging +import os +import struct +from builtins import bytearray, bytes + +import torch + +from ts.protocol.otf_message_handler import encode_response_headers +from ts.utils.util import deprecated + + +def create_predict_response( + ret, req_id_map, message, code, context=None, ts_stream_next=False +): + """ + Create inference response. + + :param context: + :param ret: + :param req_id_map: + :param message: + :param code: + :return: + """ + if str(os.getenv("LOCAL_RANK", 0)) != "0": + return None + + msg = bytearray() + msg += struct.pack("!i", code) + + buf = message.encode("utf-8") + msg += struct.pack("!i", len(buf)) + msg += buf + + for idx in req_id_map: + req_id = req_id_map.get(idx).encode("utf-8") + msg += struct.pack("!i", len(req_id)) + msg += req_id + + if context is None: + # Encoding Content-Type + msg += struct.pack("!i", 0) # content_type + + # Encoding the per prediction HTTP response code + # status code and reason phrase set to none + msg += struct.pack("!i", code) + msg += struct.pack("!i", 0) # No code phrase is returned + # Response headers none + msg += struct.pack("!i", 0) + else: + if ts_stream_next is True: + context.set_response_header(idx, "ts_stream_next", "true") + elif context.stopping_criteria: + is_stop = context.stopping_criteria[idx](ret[idx]) + if is_stop is not None: + ts_stream_next = "false" if is_stop else "true" + context.set_response_header(idx, "ts_stream_next", ts_stream_next) + elif "true" == context.get_response_headers(idx).get("ts_stream_next"): + context.set_response_header(idx, "ts_stream_next", "false") + + content_type = context.get_response_content_type(idx) + if content_type is None or len(content_type) == 0: + msg += struct.pack("!i", 0) # content_type + else: + msg += struct.pack("!i", len(content_type)) + msg += content_type.encode("utf-8") + + sc, phrase = context.get_response_status(idx) + http_code = sc if sc is not None else 200 + http_phrase = phrase if phrase is not None else "" + + msg += struct.pack("!i", http_code) + msg += struct.pack("!i", len(http_phrase)) + msg += http_phrase.encode("utf-8") + # Response headers + msg += encode_response_headers(context.get_response_headers(idx)) + + if ret is None: + buf = b"error" + msg += struct.pack("!i", len(buf)) + msg += buf + else: + val = ret[idx] + # NOTE: Process bytes/bytearray case before processing the string case. 
+ if isinstance(val, (bytes, bytearray)): + msg += struct.pack("!i", len(val)) + msg += val + elif isinstance(val, str): + buf = val.encode("utf-8") + msg += struct.pack("!i", len(buf)) + msg += buf + elif isinstance(val, torch.Tensor): + buff = io.BytesIO() + torch.save(val, buff) + buff.seek(0) + val_bytes = buff.read() + msg += struct.pack("!i", len(val_bytes)) + msg += val_bytes + else: + try: + json_value = json.dumps(val, indent=2).encode("utf-8") + msg += struct.pack("!i", len(json_value)) + msg += json_value + except TypeError: + logging.warning("Unable to serialize model output.", exc_info=True) + return create_predict_response( + None, req_id_map, "Unsupported model output data type.", 503 + ) + + msg += struct.pack("!i", -1) # End of list + return msg + + +@deprecated( + version=1.0, + replacement="ts.handler_utils.utils.send_intermediate_predict_response", +) +def send_intermediate_predict_response(ret, req_id_map, message, code, context=None): + if str(os.getenv("LOCAL_RANK", 0)) != "0": + return None + msg = create_predict_response(ret, req_id_map, message, code, context, True) + context.cl_socket.sendall(msg) diff --git a/ts/service.py b/ts/service.py index 1a4bcbbc9c..b86a21a01a 100644 --- a/ts/service.py +++ b/ts/service.py @@ -8,7 +8,7 @@ import ts from ts.context import Context, RequestProcessor -from ts.protocol.otf_message_handler import create_predict_response +from ts.protocol.otf_torch_message_handler import create_predict_response from ts.utils.util import PredictionException, get_yaml_config PREDICTION_METRIC = "PredictionTime" diff --git a/ts/tests/unit_tests/test_otf_codec_protocol.py b/ts/tests/unit_tests/test_otf_codec_protocol.py index 5df1c4644a..c32a7a5367 100644 --- a/ts/tests/unit_tests/test_otf_codec_protocol.py +++ b/ts/tests/unit_tests/test_otf_codec_protocol.py @@ -12,6 +12,7 @@ import pytest import ts.protocol.otf_message_handler as codec +import ts.protocol.otf_torch_message_handler as codec_torch @pytest.fixture() @@ -191,7 +192,7 @@ def test_create_load_model_response(self): assert msg == b"\x00\x00\x00\xc8\x00\x00\x00\x0cmodel_loaded\xff\xff\xff\xff" def test_create_predict_response(self): - msg = codec.create_predict_response(["OK"], {0: "request_id"}, "success", 200) + msg = codec_torch.create_predict_response(["OK"], {0: "request_id"}, "success", 200) assert ( msg == b"\x00\x00\x00\xc8\x00\x00\x00\x07success\x00\x00\x00\nrequest_id\x00\x00\x00\x00\x00\x00" @@ -199,7 +200,7 @@ def test_create_predict_response(self): ) def test_create_predict_response_with_error(self): - msg = codec.create_predict_response(None, {0: "request_id"}, "failed", 200) + msg = codec_torch.create_predict_response(None, {0: "request_id"}, "failed", 200) assert ( msg @@ -225,7 +226,7 @@ def test_create_predict_response_with_context(self): ctx.stopping_criteria = {0: lambda _: True, 1: lambda _: False} ctx.request_processor = {0: RequestProcessor({}), 1: RequestProcessor({})} - msg = codec.create_predict_response( + msg = codec_torch.create_predict_response( ["OK", "NOT OK"], {0: "request_0", 1: "request_1"}, "success", diff --git a/ts/torch_handler/abstract_handler.py b/ts/torch_handler/abstract_handler.py new file mode 100644 index 0000000000..978f11203f --- /dev/null +++ b/ts/torch_handler/abstract_handler.py @@ -0,0 +1,211 @@ +import abc +import logging +import time + +from ts.handler_utils.timer import timed + +logger = logging.getLogger(__name__) + + +class AbstractHandler(abc.ABC): + """ + Base default handler to load the model + Also, provides handle method per 
torch serve custom model specification + """ + + def __init__(self): + self.device = None + self.context = None + self.explain = False + + @abc.abstractmethod + def initialize(self, context): + """ + Initialize function loads the model and initializes the model object. + + Args: + context (context): It is a JSON Object containing information + pertaining to the model artifacts parameters. + + Raises: + RuntimeError: Raises the Runtime error when the model.py is missing + + """ + pass + + @abc.abstractmethod + def as_tensor(self, data): + """ + Convert data to tensor consumable by the underlying model. + Used for preprocessing the request. + + Args : + data (list): List of the data from the request input. + + Returns: + tensor: Returns the tensor data of the input + """ + pass + + @timed + def preprocess(self, data): + """ + Preprocess function to convert the request input to a tensor(Torchserve supported format). + The user needs to override to customize the pre-processing + + Args : + data (list): List of the data from the request input. + + Returns: + tensor: Returns the tensor data of the input + """ + return self.as_tensor(data) + + @abc.abstractmethod + @timed + def inference(self, data, *args, **kwargs): + pass + + @timed + def postprocess(self, data): + """ + The post process function makes use of the output from the inference and converts into a + Torchserve supported response output. + + Args: + data (numpy array-like structure): The tensor received from the prediction output of the model. + + Returns: + List: The post process function returns a list of the predicted output. + """ + + return data.tolist() + + def handle(self, data, context): + """ + Entry point for default handler. It takes the data from the input request and returns + the predicted outcome for the input. + + Args: + data (list): The input data that needs to be made a prediction request on. + context (Context): It is a JSON Object containing information pertaining to + the model artifacts parameters. + + Returns: + list: Returns a list of dictionary with the predicted response. + """ + + # It can be used for pre or post processing if needed as additional request + # information is available in context + start_time = time.time() + + self.context = context + metrics = self.context.metrics + + is_profiler_enabled = self.profiler_enabled() + if is_profiler_enabled: + output, _ = self.infer_with_profiler(data=data, context=context) + else: + if self._is_describe(): + output = [self.describe_handle()] + else: + data_preprocess = self.preprocess(data) + + if not self._is_explain(): + output = self.inference(data_preprocess) + output = self.postprocess(output) + else: + output = self.explain_handle(data_preprocess, data) + + stop_time = time.time() + metrics.add_time( + "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" + ) + return output + + def explain_handle(self, data_preprocess, raw_data): + """ + Captum explanations handler + + Args: + data_preprocess (numpy array-like structure): Preprocessed data to be used for captum + raw_data (list): The unprocessed data to get target from the request + + Returns: + dict : A dictionary response with the explanations response. 
+ """ + output_explain = None + inputs = None + target = 0 + + logger.info("Calculating Explanations") + row = raw_data[0] + if isinstance(row, dict): + logger.info("Getting data and target") + inputs = row.get("data") or row.get("body") + target = row.get("target") + if not target: + target = 0 + + output_explain = self.get_insights(data_preprocess, inputs, target) + return output_explain + + @abc.abstractmethod + def get_insights(self, tensor_data, _, target=0): + pass + + def _is_explain(self): + if self.context and self.context.get_request_header(0, "explain"): + if self.context.get_request_header(0, "explain") == "True": + self.explain = True + return True + return False + + def _is_describe(self): + if self.context and self.context.get_request_header(0, "describe"): + if self.context.get_request_header(0, "describe") == "True": + return True + return False + + @abc.abstractmethod + def describe_handle(self): + """Customized describe handler + + Returns: + dict : A dictionary response. + """ + pass + + def get_device(self): + """Get device + + Returns: + string : self device + """ + return self.device + + @abc.abstractmethod + def profiler_enabled(self): + """ + Return true if profiler is enabled + + Returns: + bool: true if profiler is enabled + """ + pass + + @abc.abstractmethod + def infer_with_profiler(self, data, context): + """ + Custom method to for handling the inference with profiler + + Args: + data (list): The input data that needs to be made a prediction request on. + context (Context): It is a JSON Object containing information pertaining to + the model artifacts parameters. + + Returns: + output : Returns a list of dictionary with the predicted response. + prof: profiler object + """ + pass diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index fa4be5841c..8b5a070b93 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -13,6 +13,7 @@ import torch from ts.handler_utils.timer import timed +from .abstract_handler import AbstractHandler from ..utils.util import ( check_valid_pt2_backend, @@ -116,22 +117,20 @@ def setup_ort_session(model_pt_path, map_location): return ort_session -class BaseHandler(abc.ABC): +class BaseHandler(AbstractHandler, abc.ABC): """ Base default handler to load torchscript or eager mode [state_dict] models Also, provides handle method per torch serve custom model specification """ def __init__(self): + super().__init__() self.model = None self.mapping = None - self.device = None self.initialized = False - self.context = None self.model_pt_path = None self.manifest = None self.map_location = None - self.explain = False self.target = 0 self.profiler_args = {} @@ -368,8 +367,7 @@ def _use_torch_export_aot_compile(self): ) return torch_export_aot_compile - @timed - def preprocess(self, data): + def as_tensor(self, data): """ Preprocess function to convert the request input to a tensor(Torchserve supported format). The user needs to override to customize the pre-processing @@ -401,70 +399,17 @@ def inference(self, data, *args, **kwargs): results = self.model(marshalled_data, *args, **kwargs) return results - @timed - def postprocess(self, data): - """ - The post process function makes use of the output from the inference and converts into a - Torchserve supported response output. - - Args: - data (Torch Tensor): The torch tensor received from the prediction output of the model. - - Returns: - List: The post process function returns a list of the predicted output. 
- """ - - return data.tolist() - - def handle(self, data, context): - """Entry point for default handler. It takes the data from the input request and returns - the predicted outcome for the input. - - Args: - data (list): The input data that needs to be made a prediction request on. - context (Context): It is a JSON Object containing information pertaining to - the model artifacts parameters. - - Returns: - list : Returns a list of dictionary with the predicted response. - """ - - # It can be used for pre or post processing if needed as additional request - # information is available in context - start_time = time.time() - - self.context = context - metrics = self.context.metrics - - is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) - if is_profiler_enabled: - if PROFILER_AVAILABLE: - if self.manifest is None: - # profiler will use to get the model name - self.manifest = context.manifest - output, _ = self._infer_with_profiler(data=data) - else: - raise RuntimeError( - "Profiler is enabled but current version of torch does not support." - "Install torch>=1.8.1 to use profiler." - ) + def infer_with_profiler(self, data, context): + if PROFILER_AVAILABLE: + if self.manifest is None: + # profiler will use to get the model name + self.manifest = context.manifest + output, _ = self._infer_with_profiler(data=data) else: - if self._is_describe(): - output = [self.describe_handle()] - else: - data_preprocess = self.preprocess(data) - - if not self._is_explain(): - output = self.inference(data_preprocess) - output = self.postprocess(output) - else: - output = self.explain_handle(data_preprocess, data) - - stop_time = time.time() - metrics.add_time( - "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" - ) - return output + raise RuntimeError( + "Profiler is enabled but current version of torch does not support." + "Install torch>=1.8.1 to use profiler." + ) def _infer_with_profiler(self, data): """Custom method to generate pytorch profiler traces for preprocess/inference/postprocess @@ -514,60 +459,3 @@ def _infer_with_profiler(self, data): logger.info(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) return output, prof - - def explain_handle(self, data_preprocess, raw_data): - """Captum explanations handler - - Args: - data_preprocess (Torch Tensor): Preprocessed data to be used for captum - raw_data (list): The unprocessed data to get target from the request - - Returns: - dict : A dictionary response with the explanations response. - """ - output_explain = None - inputs = None - target = 0 - - logger.info("Calculating Explanations") - row = raw_data[0] - if isinstance(row, dict): - logger.info("Getting data and target") - inputs = row.get("data") or row.get("body") - target = row.get("target") - if not target: - target = 0 - - output_explain = self.get_insights(data_preprocess, inputs, target) - return output_explain - - def _is_explain(self): - if self.context and self.context.get_request_header(0, "explain"): - if self.context.get_request_header(0, "explain") == "True": - self.explain = True - return True - return False - - def _is_describe(self): - if self.context and self.context.get_request_header(0, "describe"): - if self.context.get_request_header(0, "describe") == "True": - return True - return False - - def describe_handle(self): - """Customized describe handler - - Returns: - dict : A dictionary response. 
- """ - # pylint: disable=unnecessary-pass - pass - # pylint: enable=unnecessary-pass - - def get_device(self): - """Get device - - Returns: - string : self device - """ - return self.device