Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes a crash you can run into with postgres array columns. #7

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions piicatcher_spacy/detectors/spacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,35 @@
@register_detector
class SpacyDetector(DatumDetector):
pii_cls_map = {
'FAC': Address, # Buildings, airports, highways, bridges, etc.
'GPE': Address, # Countries, cities, states.
'LOC': Address, # Non-GPE locations, mountain ranges, bodies of water.
'PERSON': Person, # People, including fictional.
'PER': Person, # Bug in french model
'DATE': BirthDate, # Dates within the period 18 to 100 years ago.
"FAC": Address, # Buildings, airports, highways, bridges, etc.
"GPE": Address, # Countries, cities, states.
"LOC": Address, # Non-GPE locations, mountain ranges, bodies of water.
"PERSON": Person, # People, including fictional.
"PER": Person, # Bug in french model
"DATE": BirthDate, # Dates within the period 18 to 100 years ago.
}
name = 'DatumSpacyDetector'
name = "DatumSpacyDetector"

def __init__(self, model: str = "en_core_web_md"):
super(SpacyDetector, self).__init__()

# Fixes a warning message from transformers that is pulled in via spacy
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.check_spacy_version()

if not self.check_spacy_model(model):
raise ValueError("Unable to find spacy model '{}'. Is your language supported? "
"Check the list of models available here: "
"https://github.com/explosion/spacy-models ".format(self.model))
raise ValueError(
"Unable to find spacy model '{}'. Is your language supported? "
"Check the list of models available here: "
"https://github.com/explosion/spacy-models ".format(self.model)
)

self.nlp = spacy.load(model)

# If the model doesn't support named entity recognition
if 'ner' not in [step[0] for step in self.nlp.pipeline]:
if "ner" not in [step[0] for step in self.nlp.pipeline]:
raise ValueError(
"The spacy model '{}' doesn't support named entity recognition, "
"please choose another model.".format(self.model)
Expand All @@ -49,16 +52,28 @@ def __init__(self, model: str = "en_core_web_md"):
@staticmethod
def check_spacy_version() -> bool:
    """Ensure that the installed version of spaCy is v3.

    Returns:
        True when the detected spaCy major version is 3.

    Raises:
        ImportError: If the spaCy version cannot be detected, cannot be
            parsed, or its major component is not 3.
    """
    # getattr keeps the "unable to detect" branch reachable: a bare
    # ``spacy.__version__`` would raise AttributeError instead of giving None.
    spacy_version = getattr(spacy, "__version__", None)

    if spacy_version is None:
        raise ImportError(
            "Spacy v3 needs to be installed. Unable to detect spacy version."
        )

    try:
        spacy_major = int(spacy_version.split(".")[0])
    except (AttributeError, TypeError, ValueError) as err:
        # AttributeError/TypeError: the version is not a plain string;
        # ValueError: the leading component is not an integer.
        # Chain the cause so the original parsing failure stays visible.
        raise ImportError(
            "Spacy v3 needs to be installed. Spacy version {} is unknown.".format(
                spacy_version
            )
        ) from err

    if spacy_major != 3:
        raise ImportError(
            "Spacy v3 needs to be installed. Detected version {}.".format(
                spacy_version
            )
        )

    return True

Expand All @@ -67,10 +82,12 @@ def check_spacy_model(model) -> bool:
"""Ensure that the spaCy model is installed."""
spacy_info = spacy.info()
if isinstance(spacy_info, str):
raise ValueError('Unable to detect spacy models.')
models = list(spacy_info.get('pipelines', spacy_info.get('models', None)).keys())
raise ValueError("Unable to detect spacy models.")
models = list(
spacy_info.get("pipelines", spacy_info.get("models", None)).keys()
)
if models is None:
raise ValueError('Unable to detect spacy models.')
raise ValueError("Unable to detect spacy models.")

if model not in models:
LOGGER.info("Downloading spacy model {}".format(model))
Expand All @@ -83,6 +100,12 @@ def check_spacy_model(model) -> bool:
return model in models

def detect(self, column: CatColumn, datum: str) -> Optional[PiiType]:
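# datum may be a list (e.g. values from Postgres array columns). Recurse into each element, because passing a list to spaCy crashes inside its language.py.
# NOTE(review): when no element yields a result, control falls through to self.nlp(datum) below with the list itself — confirm this is intended.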
if isinstance(datum, list):
for d in datum:
result = self.detect(column, d)
if result:
return result
doc = self.nlp(datum)
for ent in doc.ents:
LOGGER.debug("Found %s", ent.label_)
Expand Down