Skip to content

Commit

Permalink
Added support for Custom Queries
Browse files Browse the repository at this point in the history
  • Loading branch information
kmascar committed Oct 13, 2023
1 parent dba7a1f commit 871ef3f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 4 deletions.
29 changes: 28 additions & 1 deletion caller/tests/test_caller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from textractcaller import call_textract, call_textract_analyzeid, QueriesConfig, Query, get_full_json_from_output_config, get_full_json, call_textract_lending, get_full_json_lending
from textractcaller import call_textract, call_textract_analyzeid, QueriesConfig, Query, AdaptersConfig, Adapter, get_full_json_from_output_config, get_full_json, call_textract_lending, get_full_json_lending
from textractcaller.t_call import OutputConfig, Textract_Features, call_textract_expense, remove_none
from trp import Document
import trp.trp2 as t2
Expand Down Expand Up @@ -177,6 +177,33 @@ def test_queries(caplog):
query_answers = tdoc.get_query_answers(page=page)
assert len(query_answers) == 3

def test_custom_queries(caplog):
caplog.set_level(logging.DEBUG, logger="textractcaller")
queries_config = QueriesConfig(queries=[])
assert not queries_config.get_dict()
query1 = Query(text="What is the applicant full name?")
query2 = Query(text="What is the applicant phone number?", alias="PHONE_NUMBER")
query3 = Query(text="What is the applicant home address?", alias="HOME_ADDRESS", pages=["1"])
queries_config = QueriesConfig(queries=[query1, query2, query3])
adapters_config = AdaptersConfig(adapters=[])
assert not adapters_config.get_dict()
adapter1 = Adapter(adapter_id="2e9bf1c4aa31", version=1, pages=["1"])
adapters_config = AdaptersConfig(adapters=[adapter1])

textract_client = boto3.client("textract", region_name="us-east-2")
j = call_textract(
input_document="s3://amazon-textract-public-content/blogs/employeeapp20210510.png",
boto3_textract_client=textract_client,
features=[Textract_Features.QUERIES],
queries_config=queries_config,
adapters_config=adapters_config
)
assert j
tdoc: t2.TDocument = t2.TDocumentSchema().load(j) # type: ignore
assert tdoc
page = tdoc.pages[0]
query_answers = tdoc.get_query_answers(page=page)
assert len(query_answers) == 3

def test_empty_features_and_queries(caplog):
caplog.set_level(logging.DEBUG, logger="textractcaller")
Expand Down
2 changes: 1 addition & 1 deletion caller/textractcaller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._version import __version__
from .t_call import NotificationChannel, OutputConfig, DocumentLocation, Document, get_job_response, get_full_json_from_output_config, get_full_json, call_textract, Textract_Features, call_textract_analyzeid, DocumentPage, QueriesConfig, Query, call_textract_expense, Textract_Call_Mode, Textract_API, Textract_Types, call_textract_lending, get_full_json_lending, get_full_json_lending_from_output_config, get_s3_output_config_keys
from .t_call import NotificationChannel, OutputConfig, DocumentLocation, Document, get_job_response, get_full_json_from_output_config, get_full_json, call_textract, Textract_Features, call_textract_analyzeid, DocumentPage, QueriesConfig, Query, AdaptersConfig, Adapter, call_textract_expense, Textract_Call_Mode, Textract_API, Textract_Types, call_textract_lending, get_full_json_lending, get_full_json_lending_from_output_config, get_s3_output_config_keys

import logging
from logging import NullHandler
Expand Down
35 changes: 33 additions & 2 deletions caller/textractcaller/t_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,31 @@ def get_dict(self):
else:
return {}

@dataclass
class Adapter():
adapter_id: str
version: int
pages: List[str] = field(default_factory=list)

def get_dict(self):
return_dict: dict = {"AdapterId": self.adapter_id}
if self.alias:
return_dict["Version"] = self.version
if self.pages:
return_dict["Pages"] = self.pages # type: ignore
return return_dict


@dataclass
class AdaptersConfig():
adapters: List[Adapter]

def get_dict(self):
if self.adapters:
return {"Adapters": [x.get_dict() for x in self.adapters]}
else:
return {}


@dataclass
class Document():
Expand Down Expand Up @@ -143,6 +168,7 @@ def generate_request_params(document_location: Optional[DocumentLocation] = None
document: Optional[Document] = None,
features: Optional[List[Textract_Features]] = None,
queries_config: Optional[QueriesConfig] = None,
adapters_config: Optional[AdaptersConfig] = None,
client_request_token: str = "",
job_tag: str = "",
notification_channel: Optional[NotificationChannel] = None,
Expand Down Expand Up @@ -432,6 +458,7 @@ def call_textract(input_document: Union[str, bytes],
features: Optional[List[Textract_Features]] = None,
queries_config: Optional[QueriesConfig] = None,
output_config: Optional[OutputConfig] = None,
adapters_config: Optional[AdaptersConfig] = None,
kms_key_id: str = "",
job_tag: str = "",
notification_channel: Optional[NotificationChannel] = None,
Expand Down Expand Up @@ -498,6 +525,7 @@ def call_textract(input_document: Union[str, bytes],
document_location=DocumentLocation(s3_bucket=s3_bucket, s3_prefix=s3_key),
features=features,
queries_config=queries_config,
adapters_config=adapters_config,
output_config=output_config,
notification_channel=notification_channel,
kms_key_id=kms_key_id,
Expand Down Expand Up @@ -530,6 +558,7 @@ def call_textract(input_document: Union[str, bytes],
params = generate_request_params(document=Document(s3_bucket=s3_bucket, s3_prefix=s3_key),
features=features,
queries_config=queries_config,
adapters_config=adapters_config,
output_config=output_config,
kms_key_id=kms_key_id,
notification_channel=notification_channel)
Expand All @@ -543,7 +572,8 @@ def call_textract(input_document: Union[str, bytes],
doc_bytes: bytearray = bytearray(input_file.read())
params = generate_request_params(document=Document(byte_data=doc_bytes),
features=features,
queries_config=queries_config)
queries_config=queries_config,
adapters_config=adapters_config)

if features:
result_value = textract.analyze_document(**params)
Expand All @@ -557,7 +587,8 @@ def call_textract(input_document: Union[str, bytes],
raise Exception("cannot run async for bytearray")
params = generate_request_params(document=Document(byte_data=input_document),
features=features,
queries_config=queries_config)
queries_config=queries_config,
adapters_config=adapters_config)
if features:
result_value = textract.analyze_document(**params)
else:
Expand Down

0 comments on commit 871ef3f

Please sign in to comment.