Commit d90a5d1

Document embedding script and widget
djukicn committed Mar 19, 2020
1 parent ecb7162 commit d90a5d1
Showing 2 changed files with 404 additions and 0 deletions.
170 changes: 170 additions & 0 deletions orangecontrib/text/vectorization/document_embedder.py
@@ -0,0 +1,170 @@
"""This module contains classes used for embedding documents
into a vector space.
"""
import zlib
import base64
import json
import sys
import warnings
from typing import Tuple, Any, Optional
import numpy as np

from Orange.misc.server_embedder import ServerEmbedderCommunicator
from orangecontrib.text import Corpus


AGGREGATORS = ['mean', 'sum', 'max', 'min']
LANGS_TO_ISO = {'English': 'en', 'Slovenian': 'sl', 'German': 'de'}
LANGUAGES = list(LANGS_TO_ISO.values())
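
# LANGS_TO_ISO lets a caller (e.g. the accompanying widget) turn a
# human-readable language name into the ISO code PretrainedEmbedder
# expects; an illustrative sketch:
#
#   PretrainedEmbedder(language=LANGS_TO_ISO['Slovenian'])  # same as 'sl'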

class PretrainedEmbedder:
    """This class obtains dense embeddings of the documents in a
    corpus, using the fastText pretrained models from:

    E. Grave, P. Bojanowski, P. Gupta, A. Joulin, T. Mikolov,
    Learning Word Vectors for 157 Languages.
    Proceedings of the International Conference on Language
    Resources and Evaluation, 2018.

    Embedding is performed on a server, so an internet connection
    is required to use this class. Currently supported languages are:

    - English (en)
    - Slovenian (sl)
    - German (de)

    Attributes
    ----------
    language : str
        ISO 639-1 (two-letter) code of the desired language.
    aggregator : str
        Aggregator that builds the document embedding (a single
        vector) from word embeddings (multiple vectors).
        Allowed values are mean, sum, max and min.
    """

    def __init__(self, language: str = 'en',
                 aggregator: str = 'mean') -> None:
        lang_error = '{} is not a valid language. Allowed values: {}'
        agg_error = '{} is not a valid aggregator. Allowed values: {}'
        if language.lower() not in LANGUAGES:
            raise ValueError(lang_error.format(language, ', '.join(LANGUAGES)))
        self.language = language.lower()
        if aggregator.lower() not in AGGREGATORS:
            raise ValueError(agg_error.format(aggregator, ', '.join(AGGREGATORS)))
        self.aggregator = aggregator.lower()

        self._dim = 300
        self._embedder = _ServerEmbedder(self.aggregator,
                                         model_name=self.language,
                                         max_parallel_requests=1,
                                         server_url='',
                                         # TODO set proper url
                                         embedder_type='text')

    def __call__(self, corpus: Corpus, copy: bool = True,
                 processed_callback=None) -> Corpus:
        """Adds a matrix of document embeddings to a corpus.

        Parameters
        ----------
        corpus : Corpus
            Corpus on which the transform is performed.
        copy : bool
            If set to True, a copy of the corpus is made.
        processed_callback : callable, optional
            Callback invoked while documents are being embedded,
            e.g. to report progress.

        Returns
        -------
        Corpus
            Corpus (the original or a copy) with the new features added.

        Raises
        ------
        ValueError
            If corpus is not an instance of Corpus.
        RuntimeError
            If a document in the corpus is larger than
            50 KB after compression.
        """
        if not isinstance(corpus, Corpus):
            raise ValueError("Input should be an instance of Corpus.")
        corpus = corpus.copy() if copy else corpus
        embs = self._embedder.embedd_data(
            corpus.tokens, processed_callback=processed_callback)

        # Check whether some documents in the corpus were not embedded
        # for some reason. This is a very rare case.
        warnings.simplefilter('always', RuntimeWarning)
        for i, emb in enumerate(embs):
            if emb is None:
                embs[i] = np.full(self._dim, np.nan)
                warnings.warn("Some documents were not embedded for an "
                              "unknown reason. Those documents are "
                              "represented as vectors of NaNs.",
                              RuntimeWarning)

        variable_attrs = {
            'hidden': True,
            'skip-normalization': True,
            'dense-embedding-feature': True
        }

        corpus.extend_attributes(np.array(embs),
                                 ['Dim{}'.format(i) for i in range(self._dim)],
                                 var_attrs=variable_attrs)
        return corpus
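
    # Conceptually, the server reduces a (num_tokens, 300) matrix of
    # fastText word vectors to a single 300-dimensional document vector
    # using the chosen aggregator. A local numpy sketch of the same idea
    # (illustrative only; the real aggregation happens server-side):
    #
    #   word_vecs = np.random.rand(7, 300)  # 7 tokens, 300 dims each
    #   doc_vec = word_vecs.mean(axis=0)    # 'sum'/'max'/'min' analogous
    #   assert doc_vec.shape == (300,)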

    def report(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
        """Reports the current parameters of PretrainedEmbedder.

        Returns
        -------
        tuple
            Tuple of parameters.
        """
        return (('Language', self.language),
                ('Aggregator', self.aggregator))

    def set_cancelled(self):
        """Cancels the current embedding process."""
        if self._embedder:
            self._embedder.set_cancelled()

    def clear_cache(self):
        """Clears the embedder cache."""
        if self._embedder:
            self._embedder.clear_cache()

    def __enter__(self):
        return self

    def __exit__(self, ex_type, value, traceback):
        self.set_cancelled()

    def __del__(self):
        self.__exit__(None, None, None)


class _ServerEmbedder(ServerEmbedderCommunicator):
    def __init__(self, aggregator: str, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.content_type = 'application/json'
        self.aggregator = aggregator

    async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
        data_string = json.dumps(data_instance)
        data = base64.b64encode(zlib.compress(
            data_string.encode('utf-8', 'replace'),
            level=-1)).decode('utf-8', 'replace')

        if sys.getsizeof(data) > 50000:
            raise RuntimeError("Document in corpus is too large. "
                               "Size limit is 50 KB (after compression).")

        data_dict = {
            "data": data,
            "aggregator": self.aggregator
        }

        json_string = json.dumps(data_dict)
        return json_string.encode('utf-8', 'replace')
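
# The encoded payload is a JSON object of the form
#   {"data": "<base64(zlib(json(tokens)))>", "aggregator": "mean"}
# A sketch of how a receiving server might reverse the encoding (an
# assumption for illustration; the server side is not part of this file):
#
#   payload = json.loads(request_body)
#   tokens = json.loads(zlib.decompress(
#       base64.b64decode(payload['data'])).decode('utf-8'))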

if __name__ == '__main__':
    with PretrainedEmbedder(language='en', aggregator='max') as embedder:
        embedder.clear_cache()
        embedder(Corpus.from_file('deerwester'))