Merge pull request #35 from ArneBinder/remove_some_statistics
remove some statistics
ArneBinder authored Nov 8, 2023
2 parents ae3dbec + b508ee9 commit f2925d1
Showing 3 changed files with 2 additions and 263 deletions.
1 change: 1 addition & 0 deletions src/pie_datasets/statistics/__init__.py
@@ -0,0 +1 @@
from .span_length_collector import SpanLengthCollector
125 changes: 0 additions & 125 deletions src/pie_datasets/statistics/span_length_collector.py
@@ -13,70 +13,6 @@
logger = logging.getLogger(__name__)


class TokenCountCollector(DocumentStatistic):
"""Collects the token count of a field when tokenizing its content with a Huggingface
tokenizer.
The content of the field should be a string.
"""

def __init__(
self,
tokenizer: Union[str, PreTrainedTokenizer],
text_field: str = "text",
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
document_type: Optional[Type[Document]] = None,
**kwargs,
):
if document_type is None and text_field == "text":
document_type = TextBasedDocument
super().__init__(document_type=document_type, **kwargs)
self.tokenizer = (
AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
)
self.tokenizer_kwargs = tokenizer_kwargs or {}
self.text_field = text_field

def _collect(self, doc: Document) -> int:
text = getattr(doc, self.text_field)
encodings = self.tokenizer(text, **self.tokenizer_kwargs)
tokens = encodings.tokens()
return len(tokens)
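
For reference, the removed test below exercised this collector as follows. A minimal sketch, assuming a loaded PIE DatasetDict named `dataset`; the tokenizer name and kwargs are the ones used in that test:

# Before this commit, the collector was importable like this (see the removed test imports):
from pie_datasets.statistics import TokenCountCollector

# Tokenize the "text" field with a Hugging Face tokenizer and count the tokens,
# here without special tokens such as [CLS] and [SEP].
statistic = TokenCountCollector(
    text_field="text",
    tokenizer="bert-base-uncased",
    tokenizer_kwargs=dict(add_special_tokens=False),
)
values = statistic(dataset)  # per split: {"max": ..., "mean": ..., "min": ..., "std": ...}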


class FieldLengthCollector(DocumentStatistic):
"""Collects the length of a field, e.g. to collect the number the characters in the input text.
The field should be a list of sized elements.
"""

def __init__(self, field: str, **kwargs):
super().__init__(**kwargs)
self.field = field

def _collect(self, doc: Document) -> int:
field_obj = getattr(doc, self.field)
return len(field_obj)
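
As in the removed test below, collecting the number of characters of the input texts is a one-liner (same assumed `dataset` as in the sketch above):

statistic = FieldLengthCollector(field="text")
values = statistic(dataset)  # per split: {"max": ..., "mean": ..., "min": ..., "std": ...}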


class SubFieldLengthCollector(DocumentStatistic):
"""Collects the length of a subfield in a field, e.g. to collect the number of arguments of
N-ary relations."""

def __init__(self, field: str, subfield: str, **kwargs):
super().__init__(**kwargs)
self.field = field
self.subfield = subfield

def _collect(self, doc: Document) -> List[int]:
field_obj = getattr(doc, self.field)
lengths = []
for entry in field_obj:
subfield_obj = getattr(entry, self.subfield)
lengths.append(len(subfield_obj))
return lengths
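
The removed test used this collector to gather the lengths of the label strings of the "entities" layer; a minimal sketch with the same assumed `dataset`:

statistic = SubFieldLengthCollector(field="entities", subfield="label")
values = statistic(dataset)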


class SpanLengthCollector(DocumentStatistic):
"""Collects the lengths of Span annotations. If labels are provided, the lengths collected per
label.
@@ -184,64 +120,3 @@ def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]:
values[label].append(length)

return values if self.labels is not None else values["ALL"]
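
SpanLengthCollector is the only collector this commit keeps; the test file continues to exercise it like this (same assumed `dataset`):

statistic = SpanLengthCollector(layer="entities")
values = statistic(dataset)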


class DummyCollector(DocumentStatistic):
"""A dummy collector that always returns 1, e.g. to count the number of documents.
Can be used to count the number of documents.
"""

DEFAULT_AGGREGATION_FUNCTIONS = ["sum"]

def _collect(self, doc: Document) -> int:
return 1
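
With the default "sum" aggregation this yields the document count per split, as the removed test below showed (same assumed `dataset`, with three documents per split):

statistic = DummyCollector()
values = statistic(dataset)  # e.g. {"train": {"sum": 3}, "test": {"sum": 3}, "validation": {"sum": 3}}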


class LabelCountCollector(DocumentStatistic):
"""Collects the number of field entries per label, e.g. to collect the number of entities per
type.
The field should be a list of elements with a label attribute.
Important: To make correct use of the result data, missing values need to be filled with 0, e.g.:
{("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
"""

DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len", "sum"]

def __init__(
self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
):
super().__init__(**kwargs)
self.field = field
self.label_attribute = label_attribute
if not (isinstance(labels, list) or labels == "INFERRED"):
raise ValueError("labels must be a list of strings or 'INFERRED'")
if labels == "INFERRED":
logger.warning(
f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
f"are not included in the calculation. We remove these aggregation functions from "
f"this collector, but be aware that the results may be wrong for your own aggregation "
f"functions that rely on zero values."
)
self.aggregation_functions: Dict[str, Callable[[List], Any]] = {
name: func
for name, func in self.aggregation_functions.items()
if name not in ["mean", "std", "min"]
}

self.labels = labels

def _collect(self, doc: Document) -> Dict[str, int]:
field_obj = getattr(doc, self.field)
counts: Dict[str, int]
if self.labels == "INFERRED":
counts = defaultdict(int)
else:
counts = {label: 0 for label in self.labels}
for elem in field_obj:
label = getattr(elem, self.label_attribute)
counts[label] += 1
return dict(counts)
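
The removed test covered both modes, an explicit label list and labels="INFERRED" (same assumed `dataset`):

# Explicit labels: counts for missing labels are initialized with 0, so all default
# aggregations ("mean", "std", "min", "max", "len", "sum") are meaningful.
statistic = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG", "MISC"])
values = statistic(dataset)

# Inferred labels: zero counts are never observed, so only "max", "len" and "sum" remain
# (see the warning in __init__ above).
statistic = LabelCountCollector(field="entities", labels="INFERRED")
values = statistic(dataset)
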
139 changes: 1 addition & 138 deletions tests/unit/test_statistics.py
@@ -6,14 +6,7 @@
from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument

from pie_datasets import DatasetDict
from pie_datasets.statistics import (
DummyCollector,
FieldLengthCollector,
LabelCountCollector,
SpanLengthCollector,
SubFieldLengthCollector,
TokenCountCollector,
)
from pie_datasets.statistics import SpanLengthCollector
from tests import FIXTURES_ROOT


@@ -30,115 +23,6 @@ class Conll2003Document(TextBasedDocument):


def test_statistics(dataset):
statistic = DummyCollector()
values = statistic(dataset)
assert values == {"train": {"sum": 3}, "test": {"sum": 3}, "validation": {"sum": 3}}

statistic = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG", "MISC"])
values = statistic(dataset)
assert values == {
"train": {
"LOC": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"PER": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"ORG": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"MISC": {
"mean": 0.6666666666666666,
"std": 0.9428090415820634,
"min": 0,
"max": 2,
"len": 3,
"sum": 2,
},
},
"validation": {
"LOC": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"PER": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"ORG": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3},
"MISC": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
},
"test": {
"LOC": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3},
"PER": {
"mean": 0.6666666666666666,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 2,
},
"ORG": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0},
"MISC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0},
},
}

statistic = LabelCountCollector(field="entities", labels="INFERRED")
values = statistic(dataset)
assert values == {
"train": {
"ORG": {"max": 1, "len": 1, "sum": 1},
"MISC": {"max": 2, "len": 1, "sum": 2},
"PER": {"max": 1, "len": 1, "sum": 1},
"LOC": {"max": 1, "len": 1, "sum": 1},
},
"validation": {
"ORG": {"max": 2, "len": 2, "sum": 3},
"LOC": {"max": 1, "len": 1, "sum": 1},
"MISC": {"max": 1, "len": 1, "sum": 1},
"PER": {"max": 1, "len": 1, "sum": 1},
},
"test": {"LOC": {"max": 2, "len": 2, "sum": 3}, "PER": {"max": 1, "len": 2, "sum": 2}},
}

statistic = FieldLengthCollector(field="text")
values = statistic(dataset)
assert values == {
"test": {"max": 57, "mean": 36.0, "min": 11, "std": 18.991226044325487},
"train": {"max": 48, "mean": 27.333333333333332, "min": 15, "std": 14.70449666674185},
"validation": {"max": 187, "mean": 89.66666666666667, "min": 17, "std": 71.5603863103665},
}

statistic = SpanLengthCollector(layer="entities")
values = statistic(dataset)
assert values == {
@@ -177,29 +61,8 @@ def test_statistics(dataset):
},
}

    # Not particularly useful (we just collect the lengths of the label strings), but enough to exercise the code.
statistic = SubFieldLengthCollector(field="entities", subfield="label")
values = statistic(dataset)
assert values == {
"test": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
"train": {"max": 4, "mean": 3.4, "min": 3, "std": 0.4898979485566356},
"validation": {"max": 4, "mean": 3.1666666666666665, "min": 3, "std": 0.3726779962499649},
}


def test_statistics_with_tokenize(dataset):
statistic = TokenCountCollector(
text_field="text",
tokenizer="bert-base-uncased",
tokenizer_kwargs=dict(add_special_tokens=False),
)
values = statistic(dataset)
assert values == {
"test": {"max": 12, "mean": 9.333333333333334, "min": 4, "std": 3.7712361663282534},
"train": {"max": 9, "mean": 5.666666666666667, "min": 2, "std": 2.8674417556808756},
"validation": {"max": 38, "mean": 18.333333333333332, "min": 6, "std": 14.055445761538678},
}

@dataclasses.dataclass
class TokenDocumentWithLabeledEntities(TokenBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens")
