Skip to content

Commit

Permalink
Merge pull request #14 from climatepolicyradar/feature/rnd-1153-refor…
Browse files Browse the repository at this point in the history
…mat-huggingface-dataset

Adds a method to the ParserOutput object to get row-wise (passage-level) JSON.
  • Loading branch information
THOR300 authored Apr 11, 2024
2 parents 5e21d2b + f6395fd commit 6cc1d9b
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 7 deletions.
54 changes: 53 additions & 1 deletion src/cpr_sdk/parser_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from collections import Counter
from datetime import date
from enum import Enum
from typing import List, Optional, Sequence, Tuple, TypeVar, Union
import json
from typing import List, Optional, Sequence, Tuple, TypeVar, Union, Any

from cpr_sdk.pipeline_general_models import (
CONTENT_TYPE_HTML,
Expand Down Expand Up @@ -373,3 +374,54 @@ def from_flat_json(data: dict):
unflattened = remove_key_if_all_nested_vals_none(unflattened, "pdf_data")

return ParserOutput.model_validate(unflattened)

def to_passage_level_json(self) -> list[dict[str, Any]]:
    """
    Convert the parser output to a passage-level JSON format.

    In passage-level format we have a row for every text block in the
    document, as for natural language processing tasks we often want to
    work with text at the passage level.

    HTML data won't contain PDF fields and vice versa, so every HTML- and
    PDF-specific text block field is defaulted to None and then overlaid
    with whatever the block actually provides. This keeps the key set
    identical for every row, which is more explicit than relying on the
    hugging face dataset transformation to fill in missing fields.

    We serialise with model_dump_json and reload with json.loads (rather
    than using model_dump directly) because model_dump leaves Enums and
    child pydantic objects in place, which we don't want when we push to
    huggingface.

    :return: one dict per text block, or an empty list when the document
        has no text blocks.
    """
    if self.text_blocks is None:
        return []

    # Document-level fields shared by every passage row. Text blocks and
    # page metadata are excluded here because they are emitted per row.
    common_fields_dict = json.loads(
        self.model_dump_json(
            exclude={
                "pdf_data": {"text_blocks", "page_metadata"},
                "html_data": {"text_blocks"},
            }
        )
    )

    # Null defaults for every HTML- and PDF-specific text block field so
    # each row has the same schema regardless of the document's source
    # content type. Later dict-union operands override these defaults.
    null_text_block_fields: dict[str, Any] = {
        key: None
        for key in (*HTMLTextBlock.model_fields, *PDFTextBlock.model_fields)
    }

    return [
        null_text_block_fields
        | common_fields_dict
        | json.loads(block.model_dump_json(exclude={"text"}))
        | {"text": block.to_string(), "block_index": idx}
        for idx, block in enumerate(self.text_blocks)
    ]
4 changes: 2 additions & 2 deletions src/cpr_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_MAJOR = "1"
_MINOR = "0"
_PATCH = "2"
_MINOR = "1"
_PATCH = "0"
_SUFFIX = ""

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
Expand Down
52 changes: 48 additions & 4 deletions tests/test_parser_models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pydantic
import pytest

from cpr_sdk.parser_models import (
ParserInput,
ParserOutput,
PDFTextBlock,
VerticalFlipError,
HTMLTextBlock,
TextBlock,
)
from cpr_sdk.pipeline_general_models import (
CONTENT_TYPE_HTML,
CONTENT_TYPE_PDF,
)
from cpr_sdk.pipeline_general_models import CONTENT_TYPE_HTML, CONTENT_TYPE_PDF


def test_parser_input_object(parser_output_json_pdf) -> None:
Expand Down Expand Up @@ -150,3 +150,47 @@ def test_parser_output_object(
with pytest.raises(pydantic.ValidationError) as context:
ParserOutput.model_validate(parser_output_json_flat)
parser_output = ParserOutput.from_flat_json(parser_output_json_flat)


def test_to_passage_level_json_method(
    parser_output_json_pdf: dict,
    parser_output_json_html: dict,
) -> None:
    """Test that we can successfully create a passage level array from the text blocks."""
    outputs = {
        "pdf": ParserOutput.model_validate(parser_output_json_pdf),
        "html": ParserOutput.model_validate(parser_output_json_html),
    }
    arrays = {fmt: output.to_passage_level_json() for fmt, output in outputs.items()}

    # One passage row per text block in the source document.
    assert len(arrays["pdf"]) == len(outputs["pdf"].text_blocks)
    assert len(arrays["html"]) == len(outputs["html"].text_blocks)

    # Every row must carry the union of document- and block-level fields,
    # plus the synthetic block_index added during conversion.
    expected_model_fields = set(
        list(TextBlock.model_fields.keys())
        + list(HTMLTextBlock.model_fields.keys())
        + list(PDFTextBlock.model_fields.keys())
        + list(ParserOutput.model_fields.keys())
        + ["block_index"]
    )

    for passage_level_array in arrays.values():
        assert all(isinstance(passage, dict) for passage in passage_level_array)

        # All rows within a document share an identical key set.
        first_doc_keys = set(passage_level_array[0].keys())
        assert all(
            set(passage.keys()) == first_doc_keys
            for passage in passage_level_array
        )

        assert all(
            set(passage.keys()) == expected_model_fields
            for passage in passage_level_array
        )

    # PDF and HTML documents must produce the same passage-level schema.
    assert arrays["pdf"][0].keys() == arrays["html"][0].keys()

0 comments on commit 6cc1d9b

Please sign in to comment.