-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #60 from ArneBinder/add_relation_argument_sorter
add document processor: `RelationArgumentSorter`
- Loading branch information
Showing
3 changed files
with
364 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 107 additions & 0 deletions
107
src/pie_datasets/document/processing/relation_argument_sorter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import TypeVar | ||
|
||
from pytorch_ie.annotations import BinaryRelation, LabeledSpan | ||
from pytorch_ie.core import Annotation, AnnotationList, Document | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
D = TypeVar("D", bound=Document) | ||
|
||
|
||
def get_relation_args(relation: Annotation) -> tuple[Annotation, ...]: | ||
if isinstance(relation, BinaryRelation): | ||
return relation.head, relation.tail | ||
else: | ||
raise TypeError( | ||
f"relation {relation} has unknown type [{type(relation)}], cannot get arguments from it" | ||
) | ||
|
||
|
||
def construct_relation_with_new_args( | ||
relation: Annotation, new_args: tuple[Annotation, ...] | ||
) -> BinaryRelation: | ||
if isinstance(relation, BinaryRelation): | ||
return BinaryRelation( | ||
head=new_args[0], | ||
tail=new_args[1], | ||
label=relation.label, | ||
score=relation.score, | ||
) | ||
else: | ||
raise TypeError( | ||
f"original relation {relation} has unknown type [{type(relation)}], " | ||
f"cannot reconstruct it with new arguments" | ||
) | ||
|
||
|
||
def has_dependent_layers(document: D, layer: str) -> bool: | ||
return layer not in document._annotation_graph["_artificial_root"] | ||
|
||
|
||
class RelationArgumentSorter: | ||
"""Sorts the arguments of the relations in the given relation layer. The sorting is done by the | ||
start and end positions of the arguments. The relations with the same sorted arguments are | ||
merged into one relation. | ||
Args: | ||
relation_layer: the name of the relation layer | ||
label_whitelist: if not None, only the relations with the label in the whitelist are sorted | ||
inplace: if True, the sorting is done in place, otherwise the document is copied and the sorting is done | ||
on the copy | ||
""" | ||
|
||
def __init__( | ||
self, relation_layer: str, label_whitelist: list[str] | None = None, inplace: bool = True | ||
): | ||
self.relation_layer = relation_layer | ||
self.label_whitelist = label_whitelist | ||
self.inplace = inplace | ||
|
||
def __call__(self, doc: D) -> D: | ||
if not self.inplace: | ||
doc = doc.copy() | ||
|
||
rel_layer: AnnotationList[BinaryRelation] = doc[self.relation_layer] | ||
args2relations: dict[tuple[LabeledSpan, ...], BinaryRelation] = { | ||
get_relation_args(rel): rel for rel in rel_layer | ||
} | ||
|
||
# assert that no other layers depend on the relation layer | ||
if has_dependent_layers(document=doc, layer=self.relation_layer): | ||
raise ValueError( | ||
f"the relation layer {self.relation_layer} has dependent layers, " | ||
f"cannot sort the arguments of the relations" | ||
) | ||
|
||
rel_layer.clear() | ||
for args, rel in args2relations.items(): | ||
if self.label_whitelist is not None and rel.label not in self.label_whitelist: | ||
# just add the relations whose label is not in the label whitelist (if a whitelist is present) | ||
rel_layer.append(rel) | ||
else: | ||
args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) | ||
if args == args_sorted: | ||
# if the relation args are already sorted, just add the relation | ||
rel_layer.append(rel) | ||
else: | ||
if args_sorted not in args2relations: | ||
new_rel = construct_relation_with_new_args(rel, args_sorted) | ||
rel_layer.append(new_rel) | ||
else: | ||
prev_rel = args2relations[args_sorted] | ||
if prev_rel.label != rel.label: | ||
raise ValueError( | ||
f"there is already a relation with sorted args {args_sorted} " | ||
f"but with a different label: {prev_rel.label} != {rel.label}" | ||
) | ||
else: | ||
logger.warning( | ||
f"do not add the new relation with sorted arguments, because it is already there: " | ||
f"{prev_rel}" | ||
) | ||
|
||
return doc |
256 changes: 256 additions & 0 deletions
256
tests/unit/document/processing/test_relation_argument_sorter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
import dataclasses | ||
import logging | ||
|
||
import pytest | ||
from pytorch_ie import Annotation, AnnotationLayer, annotation_field | ||
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, NaryRelation | ||
from pytorch_ie.documents import ( | ||
TextBasedDocument, | ||
TextDocumentWithLabeledSpans, | ||
TextDocumentWithLabeledSpansAndBinaryRelations, | ||
) | ||
|
||
from pie_datasets.document.processing import RelationArgumentSorter | ||
from pie_datasets.document.processing.relation_argument_sorter import ( | ||
construct_relation_with_new_args, | ||
get_relation_args, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def document(): | ||
doc = TextDocumentWithLabeledSpansAndBinaryRelations( | ||
text="Entity G works at H. And founded I." | ||
) | ||
doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER")) | ||
assert str(doc.labeled_spans[0]) == "Entity G" | ||
doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG")) | ||
assert str(doc.labeled_spans[1]) == "H" | ||
doc.labeled_spans.append(LabeledSpan(start=33, end=34, label="ORG")) | ||
assert str(doc.labeled_spans[2]) == "I" | ||
|
||
return doc | ||
|
||
|
||
@pytest.mark.parametrize("inplace", [True, False]) | ||
def test_relation_argument_sorter(document, inplace): | ||
# these arguments are not sorted | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt" | ||
) | ||
) | ||
# these arguments are sorted | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[0], tail=document.labeled_spans[2], label="founded" | ||
) | ||
) | ||
|
||
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=inplace) | ||
doc_sorted_args = arg_sorter(document) | ||
|
||
assert document.text == doc_sorted_args.text | ||
assert document.labeled_spans == doc_sorted_args.labeled_spans | ||
assert len(doc_sorted_args.binary_relations) == len(document.binary_relations) | ||
|
||
# this relation should be sorted | ||
assert str(doc_sorted_args.binary_relations[0].head) == "Entity G" | ||
assert str(doc_sorted_args.binary_relations[0].tail) == "H" | ||
assert doc_sorted_args.binary_relations[0].label == "worksAt" | ||
|
||
# this relation should be the same as before | ||
assert str(doc_sorted_args.binary_relations[1].head) == "Entity G" | ||
assert str(doc_sorted_args.binary_relations[1].tail) == "I" | ||
assert doc_sorted_args.binary_relations[1].label == "founded" | ||
|
||
if inplace: | ||
assert document == doc_sorted_args | ||
else: | ||
assert document != doc_sorted_args | ||
|
||
|
||
@pytest.fixture | ||
def document_with_nary_relation(): | ||
@dataclasses.dataclass | ||
class TextDocumentWithLabeledSpansAndNaryRelations(TextDocumentWithLabeledSpans): | ||
nary_relations: AnnotationLayer[NaryRelation] = annotation_field(target="labeled_spans") | ||
|
||
doc = TextDocumentWithLabeledSpansAndNaryRelations(text="Entity G works at H. And founded I.") | ||
doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER")) | ||
assert str(doc.labeled_spans[0]) == "Entity G" | ||
doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG")) | ||
assert str(doc.labeled_spans[1]) == "H" | ||
doc.labeled_spans.append(LabeledSpan(start=33, end=34, label="ORG")) | ||
assert str(doc.labeled_spans[2]) == "I" | ||
|
||
doc.nary_relations.append( | ||
NaryRelation( | ||
arguments=(doc.labeled_spans[0], doc.labeled_spans[1], doc.labeled_spans[2]), | ||
roles=("person", "worksAt", "founded"), | ||
label="event", | ||
) | ||
) | ||
|
||
return doc | ||
|
||
|
||
def test_get_args_wrong_type(document_with_nary_relation): | ||
with pytest.raises(TypeError) as excinfo: | ||
get_relation_args(document_with_nary_relation.nary_relations[0]) | ||
assert ( | ||
str(excinfo.value) | ||
== "relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), " | ||
"LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, " | ||
"label='ORG', score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) " | ||
"has unknown type [<class 'pytorch_ie.annotations.NaryRelation'>], cannot get arguments from it" | ||
) | ||
|
||
|
||
def test_construct_relation_with_new_args_wrong_type(document_with_nary_relation): | ||
with pytest.raises(TypeError) as excinfo: | ||
construct_relation_with_new_args( | ||
document_with_nary_relation.nary_relations[0], | ||
( | ||
document_with_nary_relation.labeled_spans[0], | ||
document_with_nary_relation.labeled_spans[1], | ||
), | ||
) | ||
assert ( | ||
str(excinfo.value) | ||
== "original relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), " | ||
"LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, label='ORG', " | ||
"score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) has unknown type " | ||
"[<class 'pytorch_ie.annotations.NaryRelation'>], cannot reconstruct it with new arguments" | ||
) | ||
|
||
|
||
def test_relation_argument_sorter_with_label_whitelist(document): | ||
# argument of both relations are not sorted | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt" | ||
) | ||
) | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[2], tail=document.labeled_spans[0], label="founded" | ||
) | ||
) | ||
|
||
# we only want to sort the relations with the label "founded" | ||
arg_sorter = RelationArgumentSorter( | ||
relation_layer="binary_relations", label_whitelist=["founded"], inplace=False | ||
) | ||
doc_sorted_args = arg_sorter(document) | ||
|
||
assert document.text == doc_sorted_args.text | ||
assert document.labeled_spans == doc_sorted_args.labeled_spans | ||
|
||
# this relation should be the same as before | ||
assert doc_sorted_args.binary_relations[0] == document.binary_relations[0] | ||
|
||
# this relation should be sorted | ||
assert doc_sorted_args.binary_relations[1] != document.binary_relations[1] | ||
assert str(doc_sorted_args.binary_relations[1].head) == "Entity G" | ||
assert str(doc_sorted_args.binary_relations[1].tail) == "I" | ||
assert doc_sorted_args.binary_relations[1].label == "founded" | ||
|
||
|
||
def test_relation_argument_sorter_sorted_rel_already_exists_with_same_label(document, caplog): | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt" | ||
) | ||
) | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[0], tail=document.labeled_spans[1], label="worksAt" | ||
) | ||
) | ||
|
||
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) | ||
|
||
caplog.clear() | ||
with caplog.at_level(logging.WARNING): | ||
doc_sorted_args = arg_sorter(document) | ||
|
||
assert len(caplog.records) == 1 | ||
assert caplog.records[0].levelname == "WARNING" | ||
assert ( | ||
caplog.records[0].message | ||
== "do not add the new relation with sorted arguments, because it is already there: " | ||
"BinaryRelation(head=LabeledSpan(start=0, end=8, label='PER', score=1.0), " | ||
"tail=LabeledSpan(start=18, end=19, label='ORG', score=1.0), label='worksAt', score=1.0)" | ||
) | ||
|
||
assert document.text == doc_sorted_args.text | ||
assert document.labeled_spans == doc_sorted_args.labeled_spans | ||
|
||
# since there is already a relation with the same label and sorted arguments, | ||
# there should be just one relation in the end | ||
assert len(doc_sorted_args.binary_relations) == 1 | ||
assert str(doc_sorted_args.binary_relations[0].head) == "Entity G" | ||
assert str(doc_sorted_args.binary_relations[0].tail) == "H" | ||
|
||
|
||
def test_relation_argument_sorter_sorted_rel_already_exists_with_different_label(document): | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt" | ||
) | ||
) | ||
document.binary_relations.append( | ||
BinaryRelation( | ||
head=document.labeled_spans[0], tail=document.labeled_spans[1], label="founded" | ||
) | ||
) | ||
|
||
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) | ||
|
||
with pytest.raises(ValueError) as excinfo: | ||
arg_sorter(document) | ||
assert ( | ||
str(excinfo.value) | ||
== "there is already a relation with sorted args (LabeledSpan(start=0, end=8, label='PER', score=1.0), " | ||
"LabeledSpan(start=18, end=19, label='ORG', score=1.0)) but with a different label: founded != worksAt" | ||
) | ||
|
||
|
||
def test_relation_argument_sorter_with_dependent_layers(): | ||
@dataclasses.dataclass(frozen=True) | ||
class Attribute(Annotation): | ||
annotation: Annotation | ||
label: str | ||
|
||
@dataclasses.dataclass | ||
class ExampleDocument(TextBasedDocument): | ||
labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text") | ||
binary_relations: AnnotationLayer[BinaryRelation] = annotation_field( | ||
target="labeled_spans" | ||
) | ||
relation_attributes: AnnotationLayer[Attribute] = annotation_field( | ||
target="binary_relations" | ||
) | ||
|
||
doc = ExampleDocument(text="Entity G works at H. And founded I.") | ||
doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER")) | ||
assert str(doc.labeled_spans[0]) == "Entity G" | ||
doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG")) | ||
assert str(doc.labeled_spans[1]) == "H" | ||
doc.binary_relations.append( | ||
BinaryRelation(head=doc.labeled_spans[1], tail=doc.labeled_spans[0], label="worksAt") | ||
) | ||
doc.relation_attributes.append( | ||
Attribute(annotation=doc.binary_relations[0], label="some_attribute") | ||
) | ||
|
||
arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) | ||
|
||
with pytest.raises(ValueError) as excinfo: | ||
arg_sorter(doc) | ||
|
||
assert ( | ||
str(excinfo.value) | ||
== "the relation layer binary_relations has dependent layers, cannot sort the arguments of the relations" | ||
) |