Skip to content

Commit

Permalink
Merge pull request #60 from ArneBinder/add_relation_argument_sorter
Browse files Browse the repository at this point in the history
add document processor: `RelationArgumentSorter`
  • Loading branch information
ArneBinder authored Nov 22, 2023
2 parents 0757685 + 87627ce commit bed7a4a
Show file tree
Hide file tree
Showing 3 changed files with 364 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/pie_datasets/document/processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .generic import Caster, Converter, Pipeline
from .regex_partitioner import RegexPartitioner
from .relation_argument_sorter import RelationArgumentSorter
from .text_span_trimmer import TextSpanTrimmer
from .tokenization import (
text_based_document_to_token_based,
Expand Down
107 changes: 107 additions & 0 deletions src/pie_datasets/document/processing/relation_argument_sorter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from __future__ import annotations

import logging
from typing import TypeVar

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, Document

logger = logging.getLogger(__name__)


D = TypeVar("D", bound=Document)


def get_relation_args(relation: Annotation) -> tuple[Annotation, ...]:
if isinstance(relation, BinaryRelation):
return relation.head, relation.tail
else:
raise TypeError(
f"relation {relation} has unknown type [{type(relation)}], cannot get arguments from it"
)


def construct_relation_with_new_args(
relation: Annotation, new_args: tuple[Annotation, ...]
) -> BinaryRelation:
if isinstance(relation, BinaryRelation):
return BinaryRelation(
head=new_args[0],
tail=new_args[1],
label=relation.label,
score=relation.score,
)
else:
raise TypeError(
f"original relation {relation} has unknown type [{type(relation)}], "
f"cannot reconstruct it with new arguments"
)


def has_dependent_layers(document: D, layer: str) -> bool:
return layer not in document._annotation_graph["_artificial_root"]


class RelationArgumentSorter:
"""Sorts the arguments of the relations in the given relation layer. The sorting is done by the
start and end positions of the arguments. The relations with the same sorted arguments are
merged into one relation.
Args:
relation_layer: the name of the relation layer
label_whitelist: if not None, only the relations with the label in the whitelist are sorted
inplace: if True, the sorting is done in place, otherwise the document is copied and the sorting is done
on the copy
"""

def __init__(
self, relation_layer: str, label_whitelist: list[str] | None = None, inplace: bool = True
):
self.relation_layer = relation_layer
self.label_whitelist = label_whitelist
self.inplace = inplace

def __call__(self, doc: D) -> D:
if not self.inplace:
doc = doc.copy()

rel_layer: AnnotationList[BinaryRelation] = doc[self.relation_layer]
args2relations: dict[tuple[LabeledSpan, ...], BinaryRelation] = {
get_relation_args(rel): rel for rel in rel_layer
}

# assert that no other layers depend on the relation layer
if has_dependent_layers(document=doc, layer=self.relation_layer):
raise ValueError(
f"the relation layer {self.relation_layer} has dependent layers, "
f"cannot sort the arguments of the relations"
)

rel_layer.clear()
for args, rel in args2relations.items():
if self.label_whitelist is not None and rel.label not in self.label_whitelist:
# just add the relations whose label is not in the label whitelist (if a whitelist is present)
rel_layer.append(rel)
else:
args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end)))
if args == args_sorted:
# if the relation args are already sorted, just add the relation
rel_layer.append(rel)
else:
if args_sorted not in args2relations:
new_rel = construct_relation_with_new_args(rel, args_sorted)
rel_layer.append(new_rel)
else:
prev_rel = args2relations[args_sorted]
if prev_rel.label != rel.label:
raise ValueError(
f"there is already a relation with sorted args {args_sorted} "
f"but with a different label: {prev_rel.label} != {rel.label}"
)
else:
logger.warning(
f"do not add the new relation with sorted arguments, because it is already there: "
f"{prev_rel}"
)

return doc
256 changes: 256 additions & 0 deletions tests/unit/document/processing/test_relation_argument_sorter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import dataclasses
import logging

import pytest
from pytorch_ie import Annotation, AnnotationLayer, annotation_field
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, NaryRelation
from pytorch_ie.documents import (
TextBasedDocument,
TextDocumentWithLabeledSpans,
TextDocumentWithLabeledSpansAndBinaryRelations,
)

from pie_datasets.document.processing import RelationArgumentSorter
from pie_datasets.document.processing.relation_argument_sorter import (
construct_relation_with_new_args,
get_relation_args,
)


@pytest.fixture
def document():
doc = TextDocumentWithLabeledSpansAndBinaryRelations(
text="Entity G works at H. And founded I."
)
doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER"))
assert str(doc.labeled_spans[0]) == "Entity G"
doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG"))
assert str(doc.labeled_spans[1]) == "H"
doc.labeled_spans.append(LabeledSpan(start=33, end=34, label="ORG"))
assert str(doc.labeled_spans[2]) == "I"

return doc


@pytest.mark.parametrize("inplace", [True, False])
def test_relation_argument_sorter(document, inplace):
# these arguments are not sorted
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt"
)
)
# these arguments are sorted
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[0], tail=document.labeled_spans[2], label="founded"
)
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=inplace)
doc_sorted_args = arg_sorter(document)

assert document.text == doc_sorted_args.text
assert document.labeled_spans == doc_sorted_args.labeled_spans
assert len(doc_sorted_args.binary_relations) == len(document.binary_relations)

# this relation should be sorted
assert str(doc_sorted_args.binary_relations[0].head) == "Entity G"
assert str(doc_sorted_args.binary_relations[0].tail) == "H"
assert doc_sorted_args.binary_relations[0].label == "worksAt"

# this relation should be the same as before
assert str(doc_sorted_args.binary_relations[1].head) == "Entity G"
assert str(doc_sorted_args.binary_relations[1].tail) == "I"
assert doc_sorted_args.binary_relations[1].label == "founded"

if inplace:
assert document == doc_sorted_args
else:
assert document != doc_sorted_args


@pytest.fixture
def document_with_nary_relation():
@dataclasses.dataclass
class TextDocumentWithLabeledSpansAndNaryRelations(TextDocumentWithLabeledSpans):
nary_relations: AnnotationLayer[NaryRelation] = annotation_field(target="labeled_spans")

doc = TextDocumentWithLabeledSpansAndNaryRelations(text="Entity G works at H. And founded I.")
doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER"))
assert str(doc.labeled_spans[0]) == "Entity G"
doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG"))
assert str(doc.labeled_spans[1]) == "H"
doc.labeled_spans.append(LabeledSpan(start=33, end=34, label="ORG"))
assert str(doc.labeled_spans[2]) == "I"

doc.nary_relations.append(
NaryRelation(
arguments=(doc.labeled_spans[0], doc.labeled_spans[1], doc.labeled_spans[2]),
roles=("person", "worksAt", "founded"),
label="event",
)
)

return doc


def test_get_args_wrong_type(document_with_nary_relation):
with pytest.raises(TypeError) as excinfo:
get_relation_args(document_with_nary_relation.nary_relations[0])
assert (
str(excinfo.value)
== "relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), "
"LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, "
"label='ORG', score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) "
"has unknown type [<class 'pytorch_ie.annotations.NaryRelation'>], cannot get arguments from it"
)


def test_construct_relation_with_new_args_wrong_type(document_with_nary_relation):
with pytest.raises(TypeError) as excinfo:
construct_relation_with_new_args(
document_with_nary_relation.nary_relations[0],
(
document_with_nary_relation.labeled_spans[0],
document_with_nary_relation.labeled_spans[1],
),
)
assert (
str(excinfo.value)
== "original relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), "
"LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, label='ORG', "
"score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) has unknown type "
"[<class 'pytorch_ie.annotations.NaryRelation'>], cannot reconstruct it with new arguments"
)


def test_relation_argument_sorter_with_label_whitelist(document):
# argument of both relations are not sorted
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt"
)
)
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[2], tail=document.labeled_spans[0], label="founded"
)
)

# we only want to sort the relations with the label "founded"
arg_sorter = RelationArgumentSorter(
relation_layer="binary_relations", label_whitelist=["founded"], inplace=False
)
doc_sorted_args = arg_sorter(document)

assert document.text == doc_sorted_args.text
assert document.labeled_spans == doc_sorted_args.labeled_spans

# this relation should be the same as before
assert doc_sorted_args.binary_relations[0] == document.binary_relations[0]

# this relation should be sorted
assert doc_sorted_args.binary_relations[1] != document.binary_relations[1]
assert str(doc_sorted_args.binary_relations[1].head) == "Entity G"
assert str(doc_sorted_args.binary_relations[1].tail) == "I"
assert doc_sorted_args.binary_relations[1].label == "founded"


def test_relation_argument_sorter_sorted_rel_already_exists_with_same_label(document, caplog):
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt"
)
)
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[0], tail=document.labeled_spans[1], label="worksAt"
)
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False)

caplog.clear()
with caplog.at_level(logging.WARNING):
doc_sorted_args = arg_sorter(document)

assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"
assert (
caplog.records[0].message
== "do not add the new relation with sorted arguments, because it is already there: "
"BinaryRelation(head=LabeledSpan(start=0, end=8, label='PER', score=1.0), "
"tail=LabeledSpan(start=18, end=19, label='ORG', score=1.0), label='worksAt', score=1.0)"
)

assert document.text == doc_sorted_args.text
assert document.labeled_spans == doc_sorted_args.labeled_spans

# since there is already a relation with the same label and sorted arguments,
# there should be just one relation in the end
assert len(doc_sorted_args.binary_relations) == 1
assert str(doc_sorted_args.binary_relations[0].head) == "Entity G"
assert str(doc_sorted_args.binary_relations[0].tail) == "H"


def test_relation_argument_sorter_sorted_rel_already_exists_with_different_label(document):
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt"
)
)
document.binary_relations.append(
BinaryRelation(
head=document.labeled_spans[0], tail=document.labeled_spans[1], label="founded"
)
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False)

with pytest.raises(ValueError) as excinfo:
arg_sorter(document)
assert (
str(excinfo.value)
== "there is already a relation with sorted args (LabeledSpan(start=0, end=8, label='PER', score=1.0), "
"LabeledSpan(start=18, end=19, label='ORG', score=1.0)) but with a different label: founded != worksAt"
)


def test_relation_argument_sorter_with_dependent_layers():
@dataclasses.dataclass(frozen=True)
class Attribute(Annotation):
annotation: Annotation
label: str

@dataclasses.dataclass
class ExampleDocument(TextBasedDocument):
labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(
target="labeled_spans"
)
relation_attributes: AnnotationLayer[Attribute] = annotation_field(
target="binary_relations"
)

doc = ExampleDocument(text="Entity G works at H. And founded I.")
doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER"))
assert str(doc.labeled_spans[0]) == "Entity G"
doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG"))
assert str(doc.labeled_spans[1]) == "H"
doc.binary_relations.append(
BinaryRelation(head=doc.labeled_spans[1], tail=doc.labeled_spans[0], label="worksAt")
)
doc.relation_attributes.append(
Attribute(annotation=doc.binary_relations[0], label="some_attribute")
)

arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False)

with pytest.raises(ValueError) as excinfo:
arg_sorter(doc)

assert (
str(excinfo.value)
== "the relation layer binary_relations has dependent layers, cannot sort the arguments of the relations"
)

0 comments on commit bed7a4a

Please sign in to comment.