Skip to content

Commit

Permalink
add TextSpanTrimmer (taken from https://github.com/ArneBinder/pie-utils)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 3, 2023
1 parent 4d3f839 commit 57ca8d3
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/pie_datasets/document/processing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .text_span_trimmer import TextSpanTrimmer
116 changes: 116 additions & 0 deletions src/pie_datasets/document/processing/text_span_trimmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

import logging
from typing import TypeVar

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, Document

logger = logging.getLogger(__name__)


D = TypeVar("D", bound=Document)


def trim_text_spans(
document: D,
layer: str,
skip_empty: bool = True,
verbose: bool = True,
) -> D:
"""Remove the whitespace at the beginning and end of span annotations that target a text field.
Args:
document: The document to trim its span annotations.
layer: The name of the span layer to trim.
skip_empty: If True, empty spans will be skipped. Otherwise, an error will be raised.
verbose: If True, log warnings for trimmed spans.
Returns:
The document with trimmed spans.
"""
annotation_layer_names = {f.name for f in document.annotation_fields()}
result = type(document).fromdict(
{k: v for k, v in document.asdict().items() if k not in annotation_layer_names}
)

spans: AnnotationList[LabeledSpan] = document[layer]

old2new_spans = {}
removed_span_ids = []

text = spans.target

for span in spans:
span_text = text[span.start : span.end]
new_start = span.start + len(span_text) - len(span_text.lstrip())
new_end = span.end - len(span_text) + len(span_text.rstrip())

if new_end <= new_start:
if skip_empty:
if verbose:
logger.warning(
f'Span "{span}" is empty after trimming. Skipping it. (disable this warning with verbose=False)'
)
removed_span_ids.append(span._id)
continue
else:
if verbose:
logger.warning(
f'Span "{span}" is empty after trimming. Keep it. (disable this warning with verbose=False)'
)
# if there was only whitespace, we create a span with length 0 at the start of the original span
if new_end < new_start:
new_start = span.start
new_end = span.start

new_span = LabeledSpan(
start=new_start,
end=new_end,
label=span.label,
score=span.score,
)
if (span.start != new_span.start or span.end != new_span.end) and verbose:
logger.debug(
f'Trimmed span "{span}" to "{new_span}" (disable this warning with verbose=False)'
)
old2new_spans[span._id] = new_span

result[layer].extend(old2new_spans.values())
result.add_all_annotations_from_other(
document,
override_annotations={layer: old2new_spans},
removed_annotations={layer: set(removed_span_ids)},
verbose=verbose,
strict=True,
)

return result


class TextSpanTrimmer:
"""Remove the whitespace at the beginning and end of span annotations that target a text field.
Args:
layer: The name of the text span layer to trim.
skip_empty: If True, empty spans will be skipped. Otherwise, an error will be raised.
verbose: If True, log warnings for trimmed spans.
"""

def __init__(
self,
layer: str,
skip_empty: bool = True,
verbose: bool = True,
):
self.layer = layer
self.skip_empty = skip_empty
self.verbose = verbose

def __call__(self, document: D) -> D:
return trim_text_spans(
document=document,
layer=self.layer,
skip_empty=self.skip_empty,
verbose=self.verbose,
)
119 changes: 119 additions & 0 deletions tests/unit/document/processing/test_text_span_trimmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import dataclasses

import pytest
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument

from pie_datasets.document.processing import TextSpanTrimmer


@dataclasses.dataclass
class DocumentWithEntitiesRelationsAndPartitions(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
partitions: AnnotationList[LabeledSpan] = annotation_field(target="text")


@pytest.fixture
def document1() -> DocumentWithEntitiesRelationsAndPartitions:
TEXT1 = "Jane lives in Berlin. this is a truncated sentence about Karl\n "
ENTITY_JANE_TEXT1 = LabeledSpan(start=0, end=4, label="person")
ENTITY_BERLIN_TEXT1 = LabeledSpan(start=13, end=20, label="city")
ENTITY_KARL_TEXT1 = LabeledSpan(start=57, end=61, label="person")
ENTITY_EMPTY_TEXT1 = LabeledSpan(start=62, end=65, label="other")
SENTENCE1_TEXT1 = LabeledSpan(start=0, end=21, label="sentence")
SENTENCE2_TEXT1 = LabeledSpan(start=22, end=65, label="sentence")
REL_JANE_LIVES_IN_BERLIN = BinaryRelation(
head=ENTITY_JANE_TEXT1, tail=ENTITY_BERLIN_TEXT1, label="lives_in"
)
REL_KARL_HAS_NOTHING = BinaryRelation(
head=ENTITY_KARL_TEXT1, tail=ENTITY_EMPTY_TEXT1, label="has_nothing"
)

document = DocumentWithEntitiesRelationsAndPartitions(text=TEXT1)
document.entities.extend(
[ENTITY_JANE_TEXT1, ENTITY_BERLIN_TEXT1, ENTITY_KARL_TEXT1, ENTITY_EMPTY_TEXT1]
)
document.partitions.extend([SENTENCE1_TEXT1, SENTENCE2_TEXT1])
document.relations.extend([REL_JANE_LIVES_IN_BERLIN, REL_KARL_HAS_NOTHING])

assert str(document.entities[0]) == "Jane"
assert str(document.entities[1]) == " Berlin"
assert str(document.entities[2]) == "Karl"
assert str(document.entities[3]) == " "
assert str(document.partitions[0]) == "Jane lives in Berlin."
assert str(document.partitions[1]) == "this is a truncated sentence about Karl\n "

assert str(document.relations[0].tail) == " Berlin"
assert str(document.relations[0].head) == "Jane"
assert str(document.relations[0].label) == "lives_in"
assert str(document.relations[1].tail) == " "
assert str(document.relations[1].head) == "Karl"
assert str(document.relations[1].label) == "has_nothing"

return document


@pytest.mark.parametrize(
"layer,skip_empty",
[
("entities", False),
("partitions", False),
("partitions", True),
],
)
def test_text_span_trimmer(document1, layer, skip_empty):
trimmer = TextSpanTrimmer(layer=layer, skip_empty=skip_empty)
processed_document = trimmer(document1)

assert len(document1.entities) == 4
assert len(document1.relations) == 2
assert len(processed_document.partitions) == len(document1.partitions) == 2

if layer == "entities" and not skip_empty:
assert len(processed_document.entities) == 4
assert len(processed_document.relations) == 2
assert str(processed_document.entities[0]) == "Jane"
assert str(processed_document.entities[1]) == "Berlin"
assert str(processed_document.entities[2]) == "Karl"
assert str(processed_document.entities[3]) == ""
assert str(processed_document.partitions[0]) == "Jane lives in Berlin."
assert (
str(processed_document.partitions[1]) == "this is a truncated sentence about Karl\n "
)
assert str(processed_document.relations[0].tail) == "Berlin"
assert str(processed_document.relations[0].head) == "Jane"
assert str(processed_document.relations[0].label) == "lives_in"
assert str(processed_document.relations[1].tail) == ""
assert str(processed_document.relations[1].head) == "Karl"
assert str(processed_document.relations[1].label) == "has_nothing"
elif layer == "partitions":
assert len(processed_document.entities) == 4
assert str(processed_document.entities[0]) == "Jane"
assert str(processed_document.entities[1]) == " Berlin"
assert str(processed_document.entities[2]) == "Karl"
assert str(processed_document.entities[3]) == " "
assert str(processed_document.partitions[0]) == "Jane lives in Berlin."
assert str(processed_document.partitions[1]) == "this is a truncated sentence about Karl"
assert str(processed_document.relations[0].tail) == " Berlin"
assert str(processed_document.relations[0].head) == "Jane"
assert str(processed_document.relations[0].label) == "lives_in"
assert str(processed_document.relations[1].tail) == " "
assert str(processed_document.relations[1].head) == "Karl"
assert str(processed_document.relations[1].label) == "has_nothing"
else:
raise ValueError(f"Unknown parameter combination: layer={layer}, skip_empty={skip_empty}")


def test_text_span_trimmer_remove_entity_of_relations(document1):
trimmer = TextSpanTrimmer(layer="entities", skip_empty=True)
with pytest.raises(ValueError) as excinfo:
processed_document = trimmer(document1)
assert (
str(excinfo.value)
== "Could not add annotation BinaryRelation(head=LabeledSpan(start=57, end=61, label='person', score=1.0), "
"tail=LabeledSpan(start=62, end=65, label='other', score=1.0), label='has_nothing', score=1.0) "
"to DocumentWithEntitiesRelationsAndPartitions because it depends on annotations that are not present "
"in the document."
)

0 comments on commit 57ca8d3

Please sign in to comment.