From f659d10a8ea14eeca0764cfc0cbd3e1784b005bb Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Tue, 1 Oct 2024 00:27:04 +0900 Subject: [PATCH 1/8] OCR model options --- .../models/foundation/ocr/models/base.py | 27 ++++ .../models/foundation/ocr/models/doctr.py | 64 ++++++++++ .../core_steps/models/foundation/ocr/v1.py | 120 ++++++++---------- 3 files changed, 146 insertions(+), 65 deletions(-) create mode 100644 inference/core/workflows/core_steps/models/foundation/ocr/models/base.py create mode 100644 inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py new file mode 100644 index 000000000..e13db5153 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List + +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + WorkflowImageData, +) +from inference.core.workflows.prototypes.block import BlockResult + + +class BaseOCRModel(ABC): + + def __init__(self, model_manager, api_key): + self.model_manager = model_manager + self.api_key = api_key + + @abstractmethod + def run( + self, + images: Batch[WorkflowImageData], + step_execution_mode: StepExecutionMode, + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + pass diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py new file mode 100644 index 000000000..7e3bea591 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py @@ -0,0 +1,64 @@ +from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.utils import load_core_model +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + WorkflowImageData, +) +from inference.core.workflows.prototypes.block import BlockResult +from typing import Callable, List + +from .base import BaseOCRModel + + +class DoctrOCRModel(BaseOCRModel): + + def run( + self, + images: Batch[WorkflowImageData], + step_execution_mode: StepExecutionMode, + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + if step_execution_mode is StepExecutionMode.LOCAL: + return self.run_locally(images, post_process_result) + elif step_execution_mode is StepExecutionMode.REMOTE: + return self.run_remotely(images, post_process_result) + else: + raise ValueError(f"Unknown step execution mode: {step_execution_mode}") + + def run_locally( + self, + images: Batch[WorkflowImageData], + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + predictions = [] + for single_image in images: + inference_request = DoctrOCRInferenceRequest( + image=single_image.to_inference_format(numpy_preferred=True), + api_key=self.api_key, + ) + doctr_model_id = load_core_model( + model_manager=self.model_manager, + inference_request=inference_request, + core_model="doctr", + ) + result = self.model_manager.infer_from_request_sync( + doctr_model_id, inference_request + ) + predictions.append(result.model_dump()) + return post_process_result(images, predictions) + + def run_remotely( + self, + images: Batch[WorkflowImageData], + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + raise NotImplementedError( + "Remote execution is not implemented for DoctrOCRModel." + ) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py index 0b98c263d..acd31449a 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py @@ -1,8 +1,8 @@ -from typing import List, Literal, Optional, Type, Union +from typing import Callable, Dict, List, Literal, Optional, Type, Union -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator -from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest +from inference.core.entities.requests.inference import LMMInferenceRequest from inference.core.env import ( HOSTED_CORE_MODEL_URL, LOCAL_INFERENCE_API_URL, @@ -13,7 +13,6 @@ from inference.core.managers.base import ModelManager from inference.core.workflows.core_steps.common.entities import StepExecutionMode from inference.core.workflows.core_steps.common.utils import ( - load_core_model, remove_unexpected_keys_from_dictionary, ) from inference.core.workflows.execution_engine.constants import ( @@ -39,7 +38,10 @@ WorkflowBlock, WorkflowBlockManifest, ) -from inference_sdk import InferenceConfiguration, InferenceHTTPClient + +from .models.base import BaseOCRModel +from .models.doctr import DoctrOCRModel +from .models.trocr import TrOCRModel # Added import for TrOCRModel LONG_DESCRIPTION = """ Retrieve the characters in an image using Optical Character Recognition (OCR). @@ -57,11 +59,27 @@ EXPECTED_OUTPUT_KEYS = {"result", "parent_id", "root_parent_id", "prediction_type"} +# Registry of available models +MODEL_REGISTRY = { + "doctr": { + "class": DoctrOCRModel, + "description": "Doctr OCR Model", + "required_fields": [], + }, + "trocr": { + "class": TrOCRModel, + "description": "TrOCR Model", + "required_fields": [], + }, +} + +ModelLiteral = Literal["doctr", "trocr"] # Updated to include 'trocr' + class BlockManifest(WorkflowBlockManifest): model_config = ConfigDict( json_schema_extra={ - "name": "OCR Model", + "name": "Text Recognition (OCR)", "version": "v1", "short_description": "Extract text from an image using optical character recognition.", "long_description": LONG_DESCRIPTION, @@ -72,6 +90,14 @@ class BlockManifest(WorkflowBlockManifest): type: Literal["roboflow_core/ocr_model@v1", "OCRModel"] name: str = Field(description="Unique name of step in workflows") images: Union[WorkflowImageSelector, StepOutputImageSelector] = ImageInputField + model: ModelLiteral = Field( + default="doctr", + description="The OCR model to use.", + ) + google_cloud_api_key: Optional[str] = Field( + default=None, + description="API key for Google Cloud Vision, required if model is 'google-cloud-vision'.", + ) @classmethod def accepts_batch_input(cls) -> bool: @@ -92,7 +118,6 @@ def get_execution_engine_compatibility(cls) -> Optional[str]: class OCRModelBlockV1(WorkflowBlock): - # TODO: we need data model for OCR predictions def __init__( self, @@ -115,69 +140,34 @@ def get_manifest(cls) -> Type[WorkflowBlockManifest]: def run( self, images: Batch[WorkflowImageData], + model: str, + google_cloud_api_key: Optional[str] = None, ) -> BlockResult: - if self._step_execution_mode is StepExecutionMode.LOCAL: - return self.run_locally(images=images) - elif self._step_execution_mode is StepExecutionMode.REMOTE: - return self.run_remotely(images=images) - else: - raise ValueError( - f"Unknown step execution mode: {self._step_execution_mode}" - ) - - def run_locally( - self, - images: Batch[WorkflowImageData], - ) -> BlockResult: - predictions = [] - for single_image in images: - inference_request = DoctrOCRInferenceRequest( - image=single_image.to_inference_format(numpy_preferred=True), - api_key=self._api_key, - ) - doctr_model_id = load_core_model( - model_manager=self._model_manager, - inference_request=inference_request, - core_model="doctr", - ) - result = self._model_manager.infer_from_request_sync( - doctr_model_id, inference_request - ) - predictions.append(result.model_dump()) - return self._post_process_result( - predictions=predictions, + ocr_model = self._get_model_instance( + model=model, + google_cloud_api_key=google_cloud_api_key, + ) + return ocr_model.run( images=images, + step_execution_mode=self._step_execution_mode, + post_process_result=self._post_process_result, ) - def run_remotely( + def _get_model_instance( self, - images: Batch[WorkflowImageData], - ) -> BlockResult: - api_url = ( - LOCAL_INFERENCE_API_URL - if WORKFLOWS_REMOTE_API_TARGET != "hosted" - else HOSTED_CORE_MODEL_URL - ) - client = InferenceHTTPClient( - api_url=api_url, - api_key=self._api_key, - ) - if WORKFLOWS_REMOTE_API_TARGET == "hosted": - client.select_api_v0() - configuration = InferenceConfiguration( - max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, - max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, - ) - client.configure(configuration) - non_empty_inference_images = [i.numpy_image for i in images] - predictions = client.ocr_image( - inference_input=non_empty_inference_images, - ) - if len(images) == 1: - predictions = [predictions] - return self._post_process_result( - predictions=predictions, - images=images, + model: str, + **kwargs, + ) -> BaseOCRModel: + model_info = MODEL_REGISTRY.get(model) + if not model_info: + raise ValueError(f"Unknown model: {model}") + model_class = model_info["class"] + # Collect required fields for the model + required_fields = { + field: kwargs.get(field) for field in model_info.get("required_fields", []) + } + return model_class( + model_manager=self._model_manager, api_key=self._api_key, **required_fields ) def _post_process_result( From 21ce0c3e7ca82f9418737de9d19f7ae277ff1d73 Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Tue, 1 Oct 2024 15:02:29 +0900 Subject: [PATCH 2/8] Add TrOCR --- .../models/foundation/ocr/models/trocr.py | 63 +++++++++++++++++++ .../core_steps/models/foundation/ocr/v1.py | 5 +- 2 files changed, 65 insertions(+), 3 deletions(-) create mode 100644 inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py new file mode 100644 index 000000000..317ef1146 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py @@ -0,0 +1,63 @@ +from typing import Callable, List + +from inference.core.entities.requests.trocr import TrOCRInferenceRequest +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.utils import load_core_model +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + WorkflowImageData, +) +from inference.core.workflows.prototypes.block import BlockResult + +from .base import BaseOCRModel + + +class TrOCRModel(BaseOCRModel): + + def run( + self, + images: Batch[WorkflowImageData], + step_execution_mode: StepExecutionMode, + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + if step_execution_mode is StepExecutionMode.LOCAL: + return self.run_locally(images, post_process_result) + elif step_execution_mode is StepExecutionMode.REMOTE: + return self.run_remotely(images, post_process_result) + else: + raise ValueError(f"Unknown step execution mode: {step_execution_mode}") + + def run_locally( + self, + images: Batch[WorkflowImageData], + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + predictions = [] + for single_image in images: + inference_request = TrOCRInferenceRequest( + image=single_image.to_inference_format(numpy_preferred=True), + api_key=self.api_key, + ) + trocr_model_id = load_core_model( + model_manager=self.model_manager, + inference_request=inference_request, + core_model="trocr", + ) + result = self.model_manager.infer_from_request_sync( + trocr_model_id, inference_request + ) + predictions.append(result.model_dump()) + return post_process_result(images, predictions) + + def run_remotely( + self, + images: Batch[WorkflowImageData], + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + raise NotImplementedError("Remote execution is not implemented for TrOCRModel.") diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py index acd31449a..53ee9263b 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py @@ -41,7 +41,7 @@ from .models.base import BaseOCRModel from .models.doctr import DoctrOCRModel -from .models.trocr import TrOCRModel # Added import for TrOCRModel +from .models.trocr import TrOCRModel LONG_DESCRIPTION = """ Retrieve the characters in an image using Optical Character Recognition (OCR). @@ -73,7 +73,7 @@ }, } -ModelLiteral = Literal["doctr", "trocr"] # Updated to include 'trocr' +ModelLiteral = Literal["doctr", "trocr"] class BlockManifest(WorkflowBlockManifest): @@ -162,7 +162,6 @@ def _get_model_instance( if not model_info: raise ValueError(f"Unknown model: {model}") model_class = model_info["class"] - # Collect required fields for the model required_fields = { field: kwargs.get(field) for field in model_info.get("required_fields", []) } From fc33aab76d7f24993aad1acc5f4d3e5993e3b500 Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Tue, 1 Oct 2024 15:03:10 +0900 Subject: [PATCH 3/8] Update Dockerfile.onnx.cpu --- docker/dockerfiles/Dockerfile.onnx.cpu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docker/dockerfiles/Dockerfile.onnx.cpu b/docker/dockerfiles/Dockerfile.onnx.cpu index 23a88e297..901e94157 100644 --- a/docker/dockerfiles/Dockerfile.onnx.cpu +++ b/docker/dockerfiles/Dockerfile.onnx.cpu @@ -40,6 +40,7 @@ RUN pip3 install --upgrade pip && pip3 install \ -r requirements.transformers.txt \ jupyterlab \ wheel>=0.38.0 \ + setuptools>=65.5.1 \ --upgrade \ && rm -rf ~/.cache/pip @@ -74,4 +75,5 @@ ENV API_LOGGING_ENABLED=True ENV CORE_MODEL_SAM2_ENABLED=True ENV CORE_MODEL_OWLV2_ENABLED=True -ENTRYPOINT uvicorn cpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT \ No newline at end of file +RUN pip install watchdog[watchmedo] +ENTRYPOINT watchmedo auto-restart --directory=/app/inference --pattern=*.py --recursive -- uvicorn cpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT \ No newline at end of file From f61305d8393dd8c032c9f3da08ab4cb12a1ccafb Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Tue, 1 Oct 2024 19:22:28 +0900 Subject: [PATCH 4/8] Added Google Cloud Vision --- .../ocr/models/google_cloud_vision.py | 56 +++++++++++++++++++ .../core_steps/models/foundation/ocr/v1.py | 12 +++- 2 files changed, 65 insertions(+), 3 deletions(-) create mode 100644 inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py new file mode 100644 index 000000000..9e2b2b2f4 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py @@ -0,0 +1,56 @@ +# models/google_cloud_vision.py + +from .base import BaseOCRModel +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + WorkflowImageData, +) +from typing import Optional, List +import requests +import base64 + + +class GoogleCloudVisionOCRModel(BaseOCRModel): + def __init__( + self, model_manager, api_key: Optional[str], google_cloud_api_key: str + ): + super().__init__(model_manager, api_key) + self.google_cloud_api_key = google_cloud_api_key + + def run( + self, + images: Batch[WorkflowImageData], + step_execution_mode: StepExecutionMode, + post_process_result, + ): + predictions = [] + for image_data in images: + # Use base64_image directly + encoded_image = image_data.base64_image + url = f"https://vision.googleapis.com/v1/images:annotate?key={self.google_cloud_api_key}" + + payload = { + "requests": [ + { + "image": {"content": encoded_image}, + "features": [{"type": "TEXT_DETECTION"}], + } + ] + } + # Send the request + response = requests.post(url, json=payload) + if response.status_code == 200: + result = response.json() + text_annotations = result["responses"][0].get("textAnnotations", []) + if text_annotations: + text = text_annotations[0]["description"] + else: + text = "" + else: + error_info = response.json().get("error", {}) + message = error_info.get("message", response.text) + raise Exception(f"Google Cloud Vision API request failed: {message}") + prediction = {"result": text} + predictions.append(prediction) + return post_process_result(images, predictions) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py index 53ee9263b..7bf1ae942 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py @@ -42,6 +42,7 @@ from .models.base import BaseOCRModel from .models.doctr import DoctrOCRModel from .models.trocr import TrOCRModel +from .models.google_cloud_vision import GoogleCloudVisionOCRModel LONG_DESCRIPTION = """ Retrieve the characters in an image using Optical Character Recognition (OCR). @@ -63,17 +64,22 @@ MODEL_REGISTRY = { "doctr": { "class": DoctrOCRModel, - "description": "Doctr OCR Model", + "description": "DocTR", "required_fields": [], }, "trocr": { "class": TrOCRModel, - "description": "TrOCR Model", + "description": "TrOCR", "required_fields": [], }, + "google-cloud-vision": { + "class": GoogleCloudVisionOCRModel, + "description": "Google Cloud Vision OCR", + "required_fields": ["google_cloud_api_key"], + }, } -ModelLiteral = Literal["doctr", "trocr"] +ModelLiteral = Literal["doctr", "trocr", "google-cloud-vision"] class BlockManifest(WorkflowBlockManifest): From 3f48e7b198c0bf3014a259dc59cce427a0d0f8fd Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Tue, 1 Oct 2024 21:39:51 +0900 Subject: [PATCH 5/8] Add EasyOCR and Mathpix --- .../models/foundation/ocr/models/easyocr.py | 76 +++++++++++++ .../models/foundation/ocr/models/mathpix.py | 83 ++++++++++++++ .../models/foundation/ocr/models/trocr.py | 10 +- .../core_steps/models/foundation/ocr/v1.py | 105 +++++++++++++----- 4 files changed, 243 insertions(+), 31 deletions(-) create mode 100644 inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py create mode 100644 inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py new file mode 100644 index 000000000..3c599b9de --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py @@ -0,0 +1,76 @@ +from .base import BaseOCRModel +from inference.core.workflows.core_steps.common.entities import ( + StepExecutionMode, +) +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + WorkflowImageData, +) +from inference.core.workflows.prototypes.block import BlockResult +from typing import Callable, List, Optional +import easyocr +import cv2 + + +class EasyOCRModel(BaseOCRModel): + def __init__( + self, + model_manager, + api_key: Optional[str], + easyocr_languages: List[str] = ["en"], + ): + super().__init__(model_manager, api_key) + self.reader = easyocr.Reader(easyocr_languages) + + def run( + self, + images: Batch[WorkflowImageData], + step_execution_mode: StepExecutionMode, + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + if step_execution_mode is StepExecutionMode.LOCAL: + return self.run_locally(images, post_process_result) + elif step_execution_mode is StepExecutionMode.REMOTE: + return self.run_remotely(images, post_process_result) + + def run_locally( + self, + images: Batch[WorkflowImageData], + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + predictions = [] + for image_data in images: + # Convert image_data to numpy array + inference_image = image_data.to_inference_format( + numpy_preferred=True, + ) + img = inference_image["value"] + # Ensure image is in RGB format + if len(img.shape) == 3 and img.shape[2] == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + elif len(img.shape) == 2: + pass + else: + # Unsupported image format + raise ValueError("Unsupported image format") + # Run OCR + result = self.reader.readtext(img, detail=0) + text = " ".join(result) + prediction = {"result": text} + predictions.append(prediction) + return post_process_result(images, predictions) + + def run_remotely( + self, + images: Batch[WorkflowImageData], + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + raise NotImplementedError( + "Remote execution is not implemented for EasyOCRModel." + ) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py new file mode 100644 index 000000000..257d24983 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py @@ -0,0 +1,83 @@ +from .base import BaseOCRModel +from inference.core.workflows.core_steps.common.entities import ( + StepExecutionMode, +) +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + WorkflowImageData, +) +from typing import Optional, List, Callable +from inference.core.workflows.prototypes.block import BlockResult + +import requests +import json +import base64 + + +class MathpixOCRModel(BaseOCRModel): + def __init__( + self, + model_manager, + api_key: Optional[str], + mathpix_app_id: str, + mathpix_app_key: str, + ): + super().__init__(model_manager, api_key) + self.mathpix_app_id = mathpix_app_id + self.mathpix_app_key = mathpix_app_key + + def run( + self, + images: Batch[WorkflowImageData], + step_execution_mode: StepExecutionMode, + post_process_result: Callable[ + [Batch[WorkflowImageData], List[dict]], BlockResult + ], + ) -> BlockResult: + predictions = [] + for image_data in images: + # Decode base64 image to bytes + image_bytes = base64.b64decode(image_data.base64_image) + + # Prepare the request + url = "https://api.mathpix.com/v3/text" + headers = { + "app_id": self.mathpix_app_id, + "app_key": self.mathpix_app_key, + } + data = { + "options_json": json.dumps( + { + "math_inline_delimiters": ["$", "$"], + "rm_spaces": True, + } + ) + } + files = {"file": ("image.jpg", image_bytes, "image/jpeg")} + + # Send the request + response = requests.post( + url, + headers=headers, + data=data, + files=files, + ) + + if response.status_code == 200: + result = response.json() + # Extract the text result + text = result.get("text", "") + else: + error_info = response.json().get("error", {}) + message = error_info.get("message", response.text) + detailed_message = error_info.get("detail", "") + + raise Exception( + f"Mathpix API request failed: {message} \n\n" + f"Detailed: {detailed_message}" + ) + + prediction = {"result": text} + predictions.append(prediction) + + return post_process_result(images, predictions) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py index 317ef1146..19d98fb32 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py @@ -1,7 +1,9 @@ from typing import Callable, List from inference.core.entities.requests.trocr import TrOCRInferenceRequest -from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.entities import ( + StepExecutionMode, +) from inference.core.workflows.core_steps.common.utils import load_core_model from inference.core.workflows.execution_engine.entities.base import ( Batch, @@ -26,8 +28,6 @@ def run( return self.run_locally(images, post_process_result) elif step_execution_mode is StepExecutionMode.REMOTE: return self.run_remotely(images, post_process_result) - else: - raise ValueError(f"Unknown step execution mode: {step_execution_mode}") def run_locally( self, @@ -60,4 +60,6 @@ def run_remotely( [Batch[WorkflowImageData], List[dict]], BlockResult ], ) -> BlockResult: - raise NotImplementedError("Remote execution is not implemented for TrOCRModel.") + raise NotImplementedError( + "Remote execution is not implemented for TrOCRModel.", + ) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py index 7bf1ae942..b2dddd4df 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py @@ -1,17 +1,11 @@ -from typing import Callable, Dict, List, Literal, Optional, Type, Union +from typing import List, Literal, Optional, Type, Union -from pydantic import ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field -from inference.core.entities.requests.inference import LMMInferenceRequest -from inference.core.env import ( - HOSTED_CORE_MODEL_URL, - LOCAL_INFERENCE_API_URL, - WORKFLOWS_REMOTE_API_TARGET, - WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, - WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, -) from inference.core.managers.base import ModelManager -from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.entities import ( + StepExecutionMode, +) from inference.core.workflows.core_steps.common.utils import ( remove_unexpected_keys_from_dictionary, ) @@ -43,22 +37,34 @@ from .models.doctr import DoctrOCRModel from .models.trocr import TrOCRModel from .models.google_cloud_vision import GoogleCloudVisionOCRModel +from .models.mathpix import MathpixOCRModel +from .models.easyocr import EasyOCRModel + +SHORT_DESCRIPTION = ( + "Extract text from an image using optical character recognition (OCR)." +) LONG_DESCRIPTION = """ - Retrieve the characters in an image using Optical Character Recognition (OCR). +Retrieve the characters in an image using Optical Character Recognition (OCR). This block returns the text within an image. -You may want to use this block in combination with a detections-based block (i.e. -ObjectDetectionBlock). An object detection model could isolate specific regions from an -image (i.e. a shipping container ID in a logistics use case) for further processing. -You can then use a DynamicCropBlock to crop the region of interest before running OCR. +You may want to use this block in combination with a detections-based block +(i.e. ObjectDetectionBlock). An object detection model could isolate specific +regions from an image (i.e. a shipping container ID in a logistics use case) +for further processing. You can then use a DynamicCropBlock to crop the region +of interest before running OCR. -Using a detections model then cropping detections allows you to isolate your analysis -on particular regions of an image. +Using a detections model then cropping detections allows you to isolate your +analysis on particular regions of an image. """ -EXPECTED_OUTPUT_KEYS = {"result", "parent_id", "root_parent_id", "prediction_type"} +EXPECTED_OUTPUT_KEYS = { + "result", + "parent_id", + "root_parent_id", + "prediction_type", +} # Registry of available models MODEL_REGISTRY = { @@ -77,9 +83,25 @@ "description": "Google Cloud Vision OCR", "required_fields": ["google_cloud_api_key"], }, + "mathpix": { + "class": MathpixOCRModel, + "description": "Mathpix Convert API", + "required_fields": ["mathpix_app_id", "mathpix_app_key"], + }, + "easyocr": { + "class": EasyOCRModel, + "description": "EasyOCR", + "required_fields": ["easyocr_languages"], + }, } -ModelLiteral = Literal["doctr", "trocr", "google-cloud-vision"] +ModelLiteral = Literal[ + "doctr", + "trocr", + "google-cloud-vision", + "mathpix", + "easyocr", +] class BlockManifest(WorkflowBlockManifest): @@ -87,7 +109,7 @@ class BlockManifest(WorkflowBlockManifest): json_schema_extra={ "name": "Text Recognition (OCR)", "version": "v1", - "short_description": "Extract text from an image using optical character recognition.", + "short_description": SHORT_DESCRIPTION, "long_description": LONG_DESCRIPTION, "license": "Apache-2.0", "block_type": "model", @@ -95,14 +117,29 @@ class BlockManifest(WorkflowBlockManifest): ) type: Literal["roboflow_core/ocr_model@v1", "OCRModel"] name: str = Field(description="Unique name of step in workflows") - images: Union[WorkflowImageSelector, StepOutputImageSelector] = ImageInputField + images: Union[ + WorkflowImageSelector, + StepOutputImageSelector, + ] = ImageInputField model: ModelLiteral = Field( default="doctr", description="The OCR model to use.", ) google_cloud_api_key: Optional[str] = Field( default=None, - description="API key for Google Cloud Vision, required if model is 'google-cloud-vision'.", + description="API key for Google Cloud Vision.", + ) + mathpix_app_id: Optional[str] = Field( + default=None, + description="App ID for Mathpix API", + ) + mathpix_app_key: Optional[str] = Field( + default=None, + description="App Key for Mathpix API", + ) + easyocr_languages: Optional[List[str]] = Field( + default_factory=lambda: ["en"], + description="List of EasyOCR model languages.", ) @classmethod @@ -115,7 +152,10 @@ def describe_outputs(cls) -> List[OutputDefinition]: OutputDefinition(name="result", kind=[STRING_KIND]), OutputDefinition(name="parent_id", kind=[PARENT_ID_KIND]), OutputDefinition(name="root_parent_id", kind=[PARENT_ID_KIND]), - OutputDefinition(name="prediction_type", kind=[PREDICTION_TYPE_KIND]), + OutputDefinition( + name="prediction_type", + kind=[PREDICTION_TYPE_KIND], + ), ] @classmethod @@ -124,7 +164,6 @@ def get_execution_engine_compatibility(cls) -> Optional[str]: class OCRModelBlockV1(WorkflowBlock): - def __init__( self, model_manager: ModelManager, @@ -148,10 +187,16 @@ def run( images: Batch[WorkflowImageData], model: str, google_cloud_api_key: Optional[str] = None, + mathpix_app_id: Optional[str] = None, + mathpix_app_key: Optional[str] = None, + easyocr_languages: Optional[List[str]] = None, ) -> BlockResult: ocr_model = self._get_model_instance( model=model, google_cloud_api_key=google_cloud_api_key, + mathpix_app_id=mathpix_app_id, + mathpix_app_key=mathpix_app_key, + easyocr_languages=easyocr_languages, ) return ocr_model.run( images=images, @@ -169,10 +214,16 @@ def _get_model_instance( raise ValueError(f"Unknown model: {model}") model_class = model_info["class"] required_fields = { - field: kwargs.get(field) for field in model_info.get("required_fields", []) + field: kwargs.get(field) + for field in model_info.get( + "required_fields", + [], + ) } return model_class( - model_manager=self._model_manager, api_key=self._api_key, **required_fields + model_manager=self._model_manager, + api_key=self._api_key, + **required_fields, ) def _post_process_result( From 8d3d754d1ec492c91804be43bd3f28b71bdf6d05 Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Tue, 1 Oct 2024 21:40:07 +0900 Subject: [PATCH 6/8] Formatting --- .../models/foundation/ocr/models/doctr.py | 6 ++--- .../ocr/models/google_cloud_vision.py | 22 +++++++++++++------ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py index 7e3bea591..573cb69c0 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py @@ -1,5 +1,7 @@ from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest -from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.entities import ( + StepExecutionMode, +) from inference.core.workflows.core_steps.common.utils import load_core_model from inference.core.workflows.execution_engine.entities.base import ( Batch, @@ -25,8 +27,6 @@ def run( return self.run_locally(images, post_process_result) elif step_execution_mode is StepExecutionMode.REMOTE: return self.run_remotely(images, post_process_result) - else: - raise ValueError(f"Unknown step execution mode: {step_execution_mode}") def run_locally( self, diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py index 9e2b2b2f4..f249b8cd5 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py @@ -1,14 +1,15 @@ # models/google_cloud_vision.py from .base import BaseOCRModel -from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.entities import ( + StepExecutionMode, +) from inference.core.workflows.execution_engine.entities.base import ( Batch, WorkflowImageData, ) -from typing import Optional, List +from typing import Optional import requests -import base64 class GoogleCloudVisionOCRModel(BaseOCRModel): @@ -26,9 +27,11 @@ def run( ): predictions = [] for image_data in images: - # Use base64_image directly encoded_image = image_data.base64_image - url = f"https://vision.googleapis.com/v1/images:annotate?key={self.google_cloud_api_key}" + url = ( + f"https://vision.googleapis.com/v1/images:annotate" + f"?key={self.google_cloud_api_key}" + ) payload = { "requests": [ @@ -42,7 +45,10 @@ def run( response = requests.post(url, json=payload) if response.status_code == 200: result = response.json() - text_annotations = result["responses"][0].get("textAnnotations", []) + text_annotations = result["responses"][0].get( + "textAnnotations", + [], + ) if text_annotations: text = text_annotations[0]["description"] else: @@ -50,7 +56,9 @@ def run( else: error_info = response.json().get("error", {}) message = error_info.get("message", response.text) - raise Exception(f"Google Cloud Vision API request failed: {message}") + raise Exception( + f"Google Cloud Vision API request failed: {message}", + ) prediction = {"result": text} predictions.append(prediction) return post_process_result(images, predictions) From 5411c3f991963cabde3fc24d9f150c20bb2070ed Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Tue, 1 Oct 2024 21:40:22 +0900 Subject: [PATCH 7/8] More formatting --- .../core_steps/models/foundation/ocr/models/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py index e13db5153..b78364382 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/models/base.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import Callable, List -from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.entities import ( + StepExecutionMode, +) from inference.core.workflows.execution_engine.entities.base import ( Batch, WorkflowImageData, From 89482b8f6c73ec10f1ea003a03eb721dfb60eb7f Mon Sep 17 00:00:00 2001 From: Leo Ueno Date: Wed, 2 Oct 2024 00:38:13 +0900 Subject: [PATCH 8/8] Remove EasyOCR This partially reverts commit 3f48e7b198c0bf3014a259dc59cce427a0d0f8fd. --- .../models/foundation/ocr/models/easyocr.py | 76 ------------------- .../core_steps/models/foundation/ocr/v1.py | 13 ---- 2 files changed, 89 deletions(-) delete mode 100644 inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py b/inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py deleted file mode 100644 index 3c599b9de..000000000 --- a/inference/core/workflows/core_steps/models/foundation/ocr/models/easyocr.py +++ /dev/null @@ -1,76 +0,0 @@ -from .base import BaseOCRModel -from inference.core.workflows.core_steps.common.entities import ( - StepExecutionMode, -) -from inference.core.workflows.execution_engine.entities.base import ( - Batch, - WorkflowImageData, -) -from inference.core.workflows.prototypes.block import BlockResult -from typing import Callable, List, Optional -import easyocr -import cv2 - - -class EasyOCRModel(BaseOCRModel): - def __init__( - self, - model_manager, - api_key: Optional[str], - easyocr_languages: List[str] = ["en"], - ): - super().__init__(model_manager, api_key) - self.reader = easyocr.Reader(easyocr_languages) - - def run( - self, - images: Batch[WorkflowImageData], - step_execution_mode: StepExecutionMode, - post_process_result: Callable[ - [Batch[WorkflowImageData], List[dict]], BlockResult - ], - ) -> BlockResult: - if step_execution_mode is StepExecutionMode.LOCAL: - return self.run_locally(images, post_process_result) - elif step_execution_mode is StepExecutionMode.REMOTE: - return self.run_remotely(images, post_process_result) - - def run_locally( - self, - images: Batch[WorkflowImageData], - post_process_result: Callable[ - [Batch[WorkflowImageData], List[dict]], BlockResult - ], - ) -> BlockResult: - predictions = [] - for image_data in images: - # Convert image_data to numpy array - inference_image = image_data.to_inference_format( - numpy_preferred=True, - ) - img = inference_image["value"] - # Ensure image is in RGB format - if len(img.shape) == 3 and img.shape[2] == 3: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - elif len(img.shape) == 2: - pass - else: - # Unsupported image format - raise ValueError("Unsupported image format") - # Run OCR - result = self.reader.readtext(img, detail=0) - text = " ".join(result) - prediction = {"result": text} - predictions.append(prediction) - return post_process_result(images, predictions) - - def run_remotely( - self, - images: Batch[WorkflowImageData], - post_process_result: Callable[ - [Batch[WorkflowImageData], List[dict]], BlockResult - ], - ) -> BlockResult: - raise NotImplementedError( - "Remote execution is not implemented for EasyOCRModel." - ) diff --git a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py index b2dddd4df..26f77e452 100644 --- a/inference/core/workflows/core_steps/models/foundation/ocr/v1.py +++ b/inference/core/workflows/core_steps/models/foundation/ocr/v1.py @@ -38,7 +38,6 @@ from .models.trocr import TrOCRModel from .models.google_cloud_vision import GoogleCloudVisionOCRModel from .models.mathpix import MathpixOCRModel -from .models.easyocr import EasyOCRModel SHORT_DESCRIPTION = ( "Extract text from an image using optical character recognition (OCR)." @@ -88,11 +87,6 @@ "description": "Mathpix Convert API", "required_fields": ["mathpix_app_id", "mathpix_app_key"], }, - "easyocr": { - "class": EasyOCRModel, - "description": "EasyOCR", - "required_fields": ["easyocr_languages"], - }, } ModelLiteral = Literal[ @@ -100,7 +94,6 @@ "trocr", "google-cloud-vision", "mathpix", - "easyocr", ] @@ -137,10 +130,6 @@ class BlockManifest(WorkflowBlockManifest): default=None, description="App Key for Mathpix API", ) - easyocr_languages: Optional[List[str]] = Field( - default_factory=lambda: ["en"], - description="List of EasyOCR model languages.", - ) @classmethod def accepts_batch_input(cls) -> bool: @@ -189,14 +178,12 @@ def run( google_cloud_api_key: Optional[str] = None, mathpix_app_id: Optional[str] = None, mathpix_app_key: Optional[str] = None, - easyocr_languages: Optional[List[str]] = None, ) -> BlockResult: ocr_model = self._get_model_instance( model=model, google_cloud_api_key=google_cloud_api_key, mathpix_app_id=mathpix_app_id, mathpix_app_key=mathpix_app_key, - easyocr_languages=easyocr_languages, ) return ocr_model.run( images=images,