Skip to content

Commit

Permalink
use label_whitelist instead of label_blacklist
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 22, 2023
1 parent 26c37d4 commit 87627ce
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
10 changes: 5 additions & 5 deletions src/pie_datasets/document/processing/relation_argument_sorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ class RelationArgumentSorter:
Args:
relation_layer: the name of the relation layer
label_blacklist: if not None, the relations with the labels in the blacklist are not sorted
label_whitelist: if not None, only the relations with the label in the whitelist are sorted
inplace: if True, the sorting is done in place, otherwise the document is copied and the sorting is done
on the copy
"""

def __init__(
self, relation_layer: str, label_blacklist: list[str] | None = None, inplace: bool = True
self, relation_layer: str, label_whitelist: list[str] | None = None, inplace: bool = True
):
self.relation_layer = relation_layer
self.label_blacklist = label_blacklist
self.label_whitelist = label_whitelist
self.inplace = inplace

def __call__(self, doc: D) -> D:
Expand All @@ -79,8 +79,8 @@ def __call__(self, doc: D) -> D:

rel_layer.clear()
for args, rel in args2relations.items():
if self.label_blacklist is not None and rel.label in self.label_blacklist:
# just add the relations whose label is not in the label blacklist (if a blacklist is present)
if self.label_whitelist is not None and rel.label not in self.label_whitelist:
# just add the relations whose label is not in the label whitelist (if a whitelist is present)
rel_layer.append(rel)
else:
args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_construct_relation_with_new_args_wrong_type(document_with_nary_relation
)


def test_relation_argument_sorter_with_label_blacklist(document):
def test_relation_argument_sorter_with_label_whitelist(document):
# argument of both relations are not sorted
document.binary_relations.append(
BinaryRelation(
Expand All @@ -138,9 +138,9 @@ def test_relation_argument_sorter_with_label_blacklist(document):
)
)

# we do not want to sort the arguments of the "worksAt" relation
# we only want to sort the relations with the label "founded"
arg_sorter = RelationArgumentSorter(
relation_layer="binary_relations", label_blacklist=["worksAt"], inplace=False
relation_layer="binary_relations", label_whitelist=["founded"], inplace=False
)
doc_sorted_args = arg_sorter(document)

Expand Down

0 comments on commit 87627ce

Please sign in to comment.