diff --git a/docker/dockerfiles/Dockerfile.onnx.cpu b/docker/dockerfiles/Dockerfile.onnx.cpu
index 23a88e297..901e94157 100644
--- a/docker/dockerfiles/Dockerfile.onnx.cpu
+++ b/docker/dockerfiles/Dockerfile.onnx.cpu
@@ -40,6 +40,7 @@
 RUN pip3 install --upgrade pip && pip3 install \
     -r requirements.transformers.txt \
     jupyterlab \
    wheel>=0.38.0 \
+    setuptools>=65.5.1 \
     --upgrade \
     && rm -rf ~/.cache/pip
@@ -74,4 +75,5 @@
 ENV API_LOGGING_ENABLED=True
 ENV CORE_MODEL_SAM2_ENABLED=True
 ENV CORE_MODEL_OWLV2_ENABLED=True
-ENTRYPOINT uvicorn cpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
\ No newline at end of file
+RUN pip install "watchdog[watchmedo]"
+ENTRYPOINT watchmedo auto-restart --directory=/app/inference --pattern="*.py" --recursive -- uvicorn cpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
\ No newline at end of file
diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py
new file mode 100644
index 000000000..b78364382
--- /dev/null
+++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py
@@ -0,0 +1,31 @@
+from abc import ABC, abstractmethod
+from typing import Callable, List, Optional
+
+from inference.core.managers.base import ModelManager
+from inference.core.workflows.core_steps.common.entities import (
+    StepExecutionMode,
+)
+from inference.core.workflows.execution_engine.entities.base import (
+    Batch,
+    WorkflowImageData,
+)
+from inference.core.workflows.prototypes.block import BlockResult
+
+
+class BaseOCRModel(ABC):
+    """Common interface implemented by every OCR backend of the OCR block."""
+
+    def __init__(self, model_manager: ModelManager, api_key: Optional[str]):
+        self.model_manager = model_manager
+        self.api_key = api_key
+
+    @abstractmethod
+    def run(
+        self,
+        images: Batch[WorkflowImageData],
+        step_execution_mode: StepExecutionMode,
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        pass
diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py
new file mode 100644
index 000000000..573cb69c0
--- /dev/null
+++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py
@@ -0,0 +1,69 @@
+from typing import Callable, List
+
+from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest
+from inference.core.workflows.core_steps.common.entities import (
+    StepExecutionMode,
+)
+from inference.core.workflows.core_steps.common.utils import load_core_model
+from inference.core.workflows.execution_engine.entities.base import (
+    Batch,
+    WorkflowImageData,
+)
+from inference.core.workflows.prototypes.block import BlockResult
+
+from .base import BaseOCRModel
+
+
+class DoctrOCRModel(BaseOCRModel):
+
+    def run(
+        self,
+        images: Batch[WorkflowImageData],
+        step_execution_mode: StepExecutionMode,
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        if step_execution_mode is StepExecutionMode.LOCAL:
+            return self.run_locally(images, post_process_result)
+        elif step_execution_mode is StepExecutionMode.REMOTE:
+            return self.run_remotely(images, post_process_result)
+        else:
+            raise ValueError(
+                f"Unknown step execution mode: {step_execution_mode}"
+            )
+
+    def run_locally(
+        self,
+        images: Batch[WorkflowImageData],
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        predictions = []
+        for single_image in images:
+            inference_request = DoctrOCRInferenceRequest(
+                image=single_image.to_inference_format(numpy_preferred=True),
+                api_key=self.api_key,
+            )
+            doctr_model_id = load_core_model(
+                model_manager=self.model_manager,
+                inference_request=inference_request,
+                core_model="doctr",
+            )
+            result = self.model_manager.infer_from_request_sync(
+                doctr_model_id, inference_request
+            )
+            predictions.append(result.model_dump())
+        return post_process_result(images, predictions)
+
+    def run_remotely(
+        self,
+        images: Batch[WorkflowImageData],
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        raise NotImplementedError(
+            "Remote execution is not implemented for DoctrOCRModel."
+        )
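`BaseOCRModel` is the contract each backend below implements. As a reference for reviewers, here is a minimal sketch of what one more backend would look like; the `EchoOCRModel` name and its no-op behavior are hypothetical, not part of this PR:

```python
from typing import Callable, List

from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.core_steps.models.foundation.ocr.models.base import (
    BaseOCRModel,
)
from inference.core.workflows.execution_engine.entities.base import (
    Batch,
    WorkflowImageData,
)
from inference.core.workflows.prototypes.block import BlockResult


class EchoOCRModel(BaseOCRModel):
    """Hypothetical backend that returns an empty prediction per image."""

    def run(
        self,
        images: Batch[WorkflowImageData],
        step_execution_mode: StepExecutionMode,
        post_process_result: Callable[
            [Batch[WorkflowImageData], List[dict]], BlockResult
        ],
    ) -> BlockResult:
        # One prediction dict per input image, in order; the block's
        # post-processing callback attaches parent/prediction-type metadata.
        predictions = [{"result": ""} for _ in images]
        return post_process_result(images, predictions)
```

Keeping `post_process_result` as an injected callback means backends only produce raw `{"result": ...}` dicts, while the block attaches workflow metadata in one place.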
diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py
new file mode 100644
index 000000000..f249b8cd5
--- /dev/null
+++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py
@@ -0,0 +1,70 @@
+from typing import Callable, List, Optional
+
+import requests
+
+from inference.core.workflows.core_steps.common.entities import (
+    StepExecutionMode,
+)
+from inference.core.workflows.execution_engine.entities.base import (
+    Batch,
+    WorkflowImageData,
+)
+from inference.core.workflows.prototypes.block import BlockResult
+
+from .base import BaseOCRModel
+
+
+class GoogleCloudVisionOCRModel(BaseOCRModel):
+    def __init__(
+        self, model_manager, api_key: Optional[str], google_cloud_api_key: str
+    ):
+        super().__init__(model_manager, api_key)
+        self.google_cloud_api_key = google_cloud_api_key
+
+    def run(
+        self,
+        images: Batch[WorkflowImageData],
+        step_execution_mode: StepExecutionMode,
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        # Google Cloud Vision is an external API, so the request is the same
+        # for local and remote step execution modes.
+        predictions = []
+        for image_data in images:
+            encoded_image = image_data.base64_image
+            url = (
+                "https://vision.googleapis.com/v1/images:annotate"
+                f"?key={self.google_cloud_api_key}"
+            )
+
+            payload = {
+                "requests": [
+                    {
+                        "image": {"content": encoded_image},
+                        "features": [{"type": "TEXT_DETECTION"}],
+                    }
+                ]
+            }
+            # Send the request
+            response = requests.post(url, json=payload)
+            if response.status_code == 200:
+                result = response.json()
+                text_annotations = result["responses"][0].get(
+                    "textAnnotations",
+                    [],
+                )
+                if text_annotations:
+                    text = text_annotations[0]["description"]
+                else:
+                    text = ""
+            else:
+                error_info = response.json().get("error", {})
+                message = error_info.get("message", response.text)
+                raise Exception(
+                    f"Google Cloud Vision API request failed: {message}",
+                )
+            prediction = {"result": text}
+            predictions.append(prediction)
+        return post_process_result(images, predictions)
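To sanity-check the Google request shape outside the workflow runtime, a minimal standalone sketch follows; the image path and API key are placeholders, and `raise_for_status` stands in for the backend's custom error handling:

```python
import base64

import requests

API_KEY = "YOUR_GOOGLE_CLOUD_API_KEY"  # placeholder

with open("sample.jpg", "rb") as f:  # placeholder image
    content = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "requests": [
        {
            "image": {"content": content},
            "features": [{"type": "TEXT_DETECTION"}],
        }
    ]
}
response = requests.post(
    f"https://vision.googleapis.com/v1/images:annotate?key={API_KEY}",
    json=payload,
)
response.raise_for_status()
annotations = response.json()["responses"][0].get("textAnnotations", [])
# The first annotation aggregates the full detected text; later entries
# are per-word boxes.
print(annotations[0]["description"] if annotations else "")
```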
diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py
new file mode 100644
index 000000000..257d24983
--- /dev/null
+++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py
@@ -0,0 +1,87 @@
+import base64
+import json
+from typing import Callable, List, Optional
+
+import requests
+
+from inference.core.workflows.core_steps.common.entities import (
+    StepExecutionMode,
+)
+from inference.core.workflows.execution_engine.entities.base import (
+    Batch,
+    WorkflowImageData,
+)
+from inference.core.workflows.prototypes.block import BlockResult
+
+from .base import BaseOCRModel
+
+
+class MathpixOCRModel(BaseOCRModel):
+    def __init__(
+        self,
+        model_manager,
+        api_key: Optional[str],
+        mathpix_app_id: str,
+        mathpix_app_key: str,
+    ):
+        super().__init__(model_manager, api_key)
+        self.mathpix_app_id = mathpix_app_id
+        self.mathpix_app_key = mathpix_app_key
+
+    def run(
+        self,
+        images: Batch[WorkflowImageData],
+        step_execution_mode: StepExecutionMode,
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        # Mathpix is an external API, so the request is the same for local
+        # and remote step execution modes.
+        predictions = []
+        for image_data in images:
+            # Decode base64 image to bytes
+            image_bytes = base64.b64decode(image_data.base64_image)
+
+            # Prepare the request
+            url = "https://api.mathpix.com/v3/text"
+            headers = {
+                "app_id": self.mathpix_app_id,
+                "app_key": self.mathpix_app_key,
+            }
+            data = {
+                "options_json": json.dumps(
+                    {
+                        "math_inline_delimiters": ["$", "$"],
+                        "rm_spaces": True,
+                    }
+                )
+            }
+            files = {"file": ("image.jpg", image_bytes, "image/jpeg")}
+
+            # Send the request
+            response = requests.post(
+                url,
+                headers=headers,
+                data=data,
+                files=files,
+            )
+
+            if response.status_code == 200:
+                result = response.json()
+                # Extract the text result
+                text = result.get("text", "")
+            else:
+                error_info = response.json().get("error", {})
+                message = error_info.get("message", response.text)
+                detailed_message = error_info.get("detail", "")
+
+                raise Exception(
+                    f"Mathpix API request failed: {message}\n\n"
+                    f"Detailed: {detailed_message}"
+                )
+
+            prediction = {"result": text}
+            predictions.append(prediction)
+
+        return post_process_result(images, predictions)
diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py
new file mode 100644
index 000000000..19d98fb32
--- /dev/null
+++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py
@@ -0,0 +1,69 @@
+from typing import Callable, List
+
+from inference.core.entities.requests.trocr import TrOCRInferenceRequest
+from inference.core.workflows.core_steps.common.entities import (
+    StepExecutionMode,
+)
+from inference.core.workflows.core_steps.common.utils import load_core_model
+from inference.core.workflows.execution_engine.entities.base import (
+    Batch,
+    WorkflowImageData,
+)
+from inference.core.workflows.prototypes.block import BlockResult
+
+from .base import BaseOCRModel
+
+
+class TrOCRModel(BaseOCRModel):
+
+    def run(
+        self,
+        images: Batch[WorkflowImageData],
+        step_execution_mode: StepExecutionMode,
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        if step_execution_mode is StepExecutionMode.LOCAL:
+            return self.run_locally(images, post_process_result)
+        elif step_execution_mode is StepExecutionMode.REMOTE:
+            return self.run_remotely(images, post_process_result)
+        else:
+            raise ValueError(
+                f"Unknown step execution mode: {step_execution_mode}"
+            )
+
+    def run_locally(
+        self,
+        images: Batch[WorkflowImageData],
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        predictions = []
+        for single_image in images:
+            inference_request = TrOCRInferenceRequest(
+                image=single_image.to_inference_format(numpy_preferred=True),
+                api_key=self.api_key,
+            )
+            trocr_model_id = load_core_model(
+                model_manager=self.model_manager,
+                inference_request=inference_request,
+                core_model="trocr",
+            )
+            result = self.model_manager.infer_from_request_sync(
+                trocr_model_id, inference_request
+            )
+            predictions.append(result.model_dump())
+        return post_process_result(images, predictions)
+
+    def run_remotely(
+        self,
+        images: Batch[WorkflowImageData],
+        post_process_result: Callable[
+            [Batch[WorkflowImageData], List[dict]], BlockResult
+        ],
+    ) -> BlockResult:
+        raise NotImplementedError(
+            "Remote execution is not implemented for TrOCRModel.",
+        )
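The Mathpix backend above can likewise be exercised outside the workflow runtime. A minimal sketch mirroring its request shape, with placeholder credentials and image path:

```python
import json

import requests

headers = {
    "app_id": "YOUR_MATHPIX_APP_ID",  # placeholder
    "app_key": "YOUR_MATHPIX_APP_KEY",  # placeholder
}
data = {
    "options_json": json.dumps(
        {"math_inline_delimiters": ["$", "$"], "rm_spaces": True}
    )
}
with open("equation.jpg", "rb") as f:  # placeholder image
    response = requests.post(
        "https://api.mathpix.com/v3/text",
        headers=headers,
        data=data,
        files={"file": ("equation.jpg", f, "image/jpeg")},
    )
response.raise_for_status()
print(response.json().get("text", ""))
```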
diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py
index 0b98c263d..26f77e452 100644
--- a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py
+++ b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py
@@ -2,18 +2,11 @@
 
 from pydantic import ConfigDict, Field
 
-from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest
-from inference.core.env import (
-    HOSTED_CORE_MODEL_URL,
-    LOCAL_INFERENCE_API_URL,
-    WORKFLOWS_REMOTE_API_TARGET,
-    WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
-    WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
-)
 from inference.core.managers.base import ModelManager
-from inference.core.workflows.core_steps.common.entities import StepExecutionMode
+from inference.core.workflows.core_steps.common.entities import (
+    StepExecutionMode,
+)
 from inference.core.workflows.core_steps.common.utils import (
-    load_core_model,
     remove_unexpected_keys_from_dictionary,
 )
 from inference.core.workflows.execution_engine.constants import (
@@ -39,31 +32,77 @@
     WorkflowBlock,
     WorkflowBlockManifest,
 )
-from inference_sdk import InferenceConfiguration, InferenceHTTPClient
+
+from .models.base import BaseOCRModel
+from .models.doctr import DoctrOCRModel
+from .models.trocr import TrOCRModel
+from .models.google_cloud_vision import GoogleCloudVisionOCRModel
+from .models.mathpix import MathpixOCRModel
+
+SHORT_DESCRIPTION = (
+    "Extract text from an image using optical character recognition (OCR)."
+)
 
 LONG_DESCRIPTION = """
-    Retrieve the characters in an image using Optical Character Recognition (OCR).
+Retrieve the characters in an image using Optical Character Recognition (OCR).
 
 This block returns the text within an image.
 
-You may want to use this block in combination with a detections-based block (i.e.
-ObjectDetectionBlock). An object detection model could isolate specific regions from an
-image (i.e. a shipping container ID in a logistics use case) for further processing.
-You can then use a DynamicCropBlock to crop the region of interest before running OCR.
+You may want to use this block in combination with a detections-based block
+(i.e. ObjectDetectionBlock). An object detection model could isolate specific
+regions from an image (i.e. a shipping container ID in a logistics use case)
+for further processing. You can then use a DynamicCropBlock to crop the region
+of interest before running OCR.
 
-Using a detections model then cropping detections allows you to isolate your analysis
-on particular regions of an image.
+Using a detections model then cropping detections allows you to isolate your
+analysis on particular regions of an image.
 """
 
-EXPECTED_OUTPUT_KEYS = {"result", "parent_id", "root_parent_id", "prediction_type"}
+EXPECTED_OUTPUT_KEYS = {
+    "result",
+    "parent_id",
+    "root_parent_id",
+    "prediction_type",
+}
+
+# Registry of available models
+MODEL_REGISTRY = {
+    "doctr": {
+        "class": DoctrOCRModel,
+        "description": "DocTR",
+        "required_fields": [],
+    },
+    "trocr": {
+        "class": TrOCRModel,
+        "description": "TrOCR",
+        "required_fields": [],
+    },
+    "google-cloud-vision": {
+        "class": GoogleCloudVisionOCRModel,
+        "description": "Google Cloud Vision OCR",
+        "required_fields": ["google_cloud_api_key"],
+    },
+    "mathpix": {
+        "class": MathpixOCRModel,
+        "description": "Mathpix Convert API",
+        "required_fields": ["mathpix_app_id", "mathpix_app_key"],
+    },
+}
+
+ModelLiteral = Literal[
+    "doctr",
+    "trocr",
+    "google-cloud-vision",
+    "mathpix",
+]
 
 
 class BlockManifest(WorkflowBlockManifest):
     model_config = ConfigDict(
         json_schema_extra={
-            "name": "OCR Model",
+            "name": "Text Recognition (OCR)",
             "version": "v1",
-            "short_description": "Extract text from an image using optical character recognition.",
+            "short_description": SHORT_DESCRIPTION,
             "long_description": LONG_DESCRIPTION,
             "license": "Apache-2.0",
             "block_type": "model",
@@ -71,7 +110,26 @@ class BlockManifest(WorkflowBlockManifest):
     )
     type: Literal["roboflow_core/ocr_model@v1", "OCRModel"]
     name: str = Field(description="Unique name of step in workflows")
-    images: Union[WorkflowImageSelector, StepOutputImageSelector] = ImageInputField
+    images: Union[
+        WorkflowImageSelector,
+        StepOutputImageSelector,
+    ] = ImageInputField
+    model: ModelLiteral = Field(
+        default="doctr",
+        description="The OCR model to use.",
+    )
+    google_cloud_api_key: Optional[str] = Field(
+        default=None,
+        description="API key for Google Cloud Vision.",
+    )
+    mathpix_app_id: Optional[str] = Field(
+        default=None,
+        description="App ID for the Mathpix API.",
+    )
+    mathpix_app_key: Optional[str] = Field(
+        default=None,
+        description="App key for the Mathpix API.",
+    )
 
     @classmethod
     def accepts_batch_input(cls) -> bool:
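With the manifest changes above, a workflow definition selects a backend via the new `model` field. Two illustrative step definitions (step names and input selectors are examples, not fixtures from this PR):

```python
doctr_step = {
    "type": "roboflow_core/ocr_model@v1",
    "name": "ocr",
    "images": "$inputs.image",
    "model": "doctr",  # default backend; needs no extra credentials
}

google_vision_step = {
    "type": "roboflow_core/ocr_model@v1",
    "name": "ocr_gcv",
    "images": "$inputs.image",
    "model": "google-cloud-vision",
    # Credential fields are plain manifest fields; consider wiring them to
    # workflow inputs rather than hard-coding secrets in the definition.
    "google_cloud_api_key": "$inputs.google_cloud_api_key",
}
```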
""" -EXPECTED_OUTPUT_KEYS = {"result", "parent_id", "root_parent_id", "prediction_type"} +EXPECTED_OUTPUT_KEYS = { + "result", + "parent_id", + "root_parent_id", + "prediction_type", +} + +# Registry of available models +MODEL_REGISTRY = { + "doctr": { + "class": DoctrOCRModel, + "description": "DocTR", + "required_fields": [], + }, + "trocr": { + "class": TrOCRModel, + "description": "TrOCR", + "required_fields": [], + }, + "google-cloud-vision": { + "class": GoogleCloudVisionOCRModel, + "description": "Google Cloud Vision OCR", + "required_fields": ["google_cloud_api_key"], + }, + "mathpix": { + "class": MathpixOCRModel, + "description": "Mathpix Convert API", + "required_fields": ["mathpix_app_id", "mathpix_app_key"], + }, +} + +ModelLiteral = Literal[ + "doctr", + "trocr", + "google-cloud-vision", + "mathpix", +] class BlockManifest(WorkflowBlockManifest): model_config = ConfigDict( json_schema_extra={ - "name": "OCR Model", + "name": "Text Recognition (OCR)", "version": "v1", - "short_description": "Extract text from an image using optical character recognition.", + "short_description": SHORT_DESCRIPTION, "long_description": LONG_DESCRIPTION, "license": "Apache-2.0", "block_type": "model", @@ -71,7 +110,26 @@ class BlockManifest(WorkflowBlockManifest): ) type: Literal["roboflow_core/ocr_model@v1", "OCRModel"] name: str = Field(description="Unique name of step in workflows") - images: Union[WorkflowImageSelector, StepOutputImageSelector] = ImageInputField + images: Union[ + WorkflowImageSelector, + StepOutputImageSelector, + ] = ImageInputField + model: ModelLiteral = Field( + default="doctr", + description="The OCR model to use.", + ) + google_cloud_api_key: Optional[str] = Field( + default=None, + description="API key for Google Cloud Vision.", + ) + mathpix_app_id: Optional[str] = Field( + default=None, + description="App ID for Mathpix API", + ) + mathpix_app_key: Optional[str] = Field( + default=None, + description="App Key for Mathpix API", + ) @classmethod def accepts_batch_input(cls) -> bool: @@ -83,7 +141,10 @@ def describe_outputs(cls) -> List[OutputDefinition]: OutputDefinition(name="result", kind=[STRING_KIND]), OutputDefinition(name="parent_id", kind=[PARENT_ID_KIND]), OutputDefinition(name="root_parent_id", kind=[PARENT_ID_KIND]), - OutputDefinition(name="prediction_type", kind=[PREDICTION_TYPE_KIND]), + OutputDefinition( + name="prediction_type", + kind=[PREDICTION_TYPE_KIND], + ), ] @classmethod @@ -92,8 +153,6 @@ def get_execution_engine_compatibility(cls) -> Optional[str]: class OCRModelBlockV1(WorkflowBlock): - # TODO: we need data model for OCR predictions - def __init__( self, model_manager: ModelManager, @@ -115,69 +174,43 @@ def get_manifest(cls) -> Type[WorkflowBlockManifest]: def run( self, images: Batch[WorkflowImageData], + model: str, + google_cloud_api_key: Optional[str] = None, + mathpix_app_id: Optional[str] = None, + mathpix_app_key: Optional[str] = None, ) -> BlockResult: - if self._step_execution_mode is StepExecutionMode.LOCAL: - return self.run_locally(images=images) - elif self._step_execution_mode is StepExecutionMode.REMOTE: - return self.run_remotely(images=images) - else: - raise ValueError( - f"Unknown step execution mode: {self._step_execution_mode}" - ) - - def run_locally( - self, - images: Batch[WorkflowImageData], - ) -> BlockResult: - predictions = [] - for single_image in images: - inference_request = DoctrOCRInferenceRequest( - image=single_image.to_inference_format(numpy_preferred=True), - api_key=self._api_key, - ) - doctr_model_id = 
@@ -115,69 +174,43 @@ def get_manifest(cls) -> Type[WorkflowBlockManifest]:
     def run(
         self,
         images: Batch[WorkflowImageData],
+        model: str,
+        google_cloud_api_key: Optional[str] = None,
+        mathpix_app_id: Optional[str] = None,
+        mathpix_app_key: Optional[str] = None,
     ) -> BlockResult:
-        if self._step_execution_mode is StepExecutionMode.LOCAL:
-            return self.run_locally(images=images)
-        elif self._step_execution_mode is StepExecutionMode.REMOTE:
-            return self.run_remotely(images=images)
-        else:
-            raise ValueError(
-                f"Unknown step execution mode: {self._step_execution_mode}"
-            )
-
-    def run_locally(
-        self,
-        images: Batch[WorkflowImageData],
-    ) -> BlockResult:
-        predictions = []
-        for single_image in images:
-            inference_request = DoctrOCRInferenceRequest(
-                image=single_image.to_inference_format(numpy_preferred=True),
-                api_key=self._api_key,
-            )
-            doctr_model_id = load_core_model(
-                model_manager=self._model_manager,
-                inference_request=inference_request,
-                core_model="doctr",
-            )
-            result = self._model_manager.infer_from_request_sync(
-                doctr_model_id, inference_request
-            )
-            predictions.append(result.model_dump())
-        return self._post_process_result(
-            predictions=predictions,
+        ocr_model = self._get_model_instance(
+            model=model,
+            google_cloud_api_key=google_cloud_api_key,
+            mathpix_app_id=mathpix_app_id,
+            mathpix_app_key=mathpix_app_key,
+        )
+        return ocr_model.run(
             images=images,
+            step_execution_mode=self._step_execution_mode,
+            post_process_result=self._post_process_result,
         )
 
-    def run_remotely(
+    def _get_model_instance(
         self,
-        images: Batch[WorkflowImageData],
-    ) -> BlockResult:
-        api_url = (
-            LOCAL_INFERENCE_API_URL
-            if WORKFLOWS_REMOTE_API_TARGET != "hosted"
-            else HOSTED_CORE_MODEL_URL
-        )
-        client = InferenceHTTPClient(
-            api_url=api_url,
+        model: str,
+        **kwargs,
+    ) -> BaseOCRModel:
+        model_info = MODEL_REGISTRY.get(model)
+        if not model_info:
+            raise ValueError(f"Unknown model: {model}")
+        model_class = model_info["class"]
+        required_fields = {
+            field: kwargs.get(field)
+            for field in model_info.get(
+                "required_fields",
+                [],
+            )
+        }
+        return model_class(
+            model_manager=self._model_manager,
             api_key=self._api_key,
-        )
-        if WORKFLOWS_REMOTE_API_TARGET == "hosted":
-            client.select_api_v0()
-        configuration = InferenceConfiguration(
-            max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
-            max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
-        )
-        client.configure(configuration)
-        non_empty_inference_images = [i.numpy_image for i in images]
-        predictions = client.ocr_image(
-            inference_input=non_empty_inference_images,
-        )
-        if len(images) == 1:
-            predictions = [predictions]
-        return self._post_process_result(
-            predictions=predictions,
-            images=images,
+            **required_fields,
         )
 
     def _post_process_result(
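`_get_model_instance` passes `kwargs.get(field)` through for each required field, so a missing credential currently surfaces only as a downstream API error. A self-contained re-statement of the same registry lookup with an optional fail-fast check (a sketch, not code from this PR; the registry is trimmed to what the sketch needs):

```python
from typing import Any, Dict, List

# Mirrors the MODEL_REGISTRY structure from v1.py, reduced to required fields.
REGISTRY: Dict[str, Dict[str, List[str]]] = {
    "doctr": {"required_fields": []},
    "google-cloud-vision": {"required_fields": ["google_cloud_api_key"]},
    "mathpix": {"required_fields": ["mathpix_app_id", "mathpix_app_key"]},
}


def resolve_required_fields(model: str, **kwargs: Any) -> Dict[str, Any]:
    model_info = REGISTRY.get(model)
    if not model_info:
        raise ValueError(f"Unknown model: {model}")
    required_fields = {
        field: kwargs.get(field) for field in model_info["required_fields"]
    }
    # Optional fail-fast check this PR could adopt in _get_model_instance.
    missing = [name for name, value in required_fields.items() if value is None]
    if missing:
        raise ValueError(f"Model '{model}' is missing required fields: {missing}")
    return required_fields


print(resolve_required_fields("google-cloud-vision", google_cloud_api_key="k"))
# resolve_required_fields("mathpix")  # would raise, listing both app fields
```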