[ENH] Document Embedding: add SBERT #839

Merged (4 commits) on Jun 30, 2022

67 changes: 21 additions & 46 deletions orangecontrib/text/tests/test_sbert.py
@@ -3,12 +3,7 @@
from collections.abc import Iterator
import asyncio

from orangecontrib.text.vectorization.sbert import (
SBERT,
MIN_CHUNKS,
MAX_PACKAGE_SIZE,
EMB_DIM
)
from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM
from orangecontrib.text import Corpus

PATCH_METHOD = 'httpx.AsyncClient.post'
@@ -37,47 +32,17 @@ async def dummy_post(url, headers, data):


class TestSBERT(unittest.TestCase):

def setUp(self):
self.sbert = SBERT()
self.sbert.clear_cache()
self.corpus = Corpus.from_file('deerwester')

def tearDown(self):
self.sbert.clear_cache()

def test_make_chunks_small(self):
chunks = self.sbert._make_chunks(
self.corpus.documents, [100] * len(self.corpus.documents)
)
self.assertEqual(len(chunks), min(len(self.corpus.documents), MIN_CHUNKS))

def test_make_chunks_medium(self):
num_docs = len(self.corpus.documents)
documents = self.corpus.documents
if num_docs < MIN_CHUNKS:
documents = [documents[0]] * MIN_CHUNKS
chunks = self.sbert._make_chunks(
documents, [MAX_PACKAGE_SIZE / MIN_CHUNKS - 1] * len(documents)
)
self.assertEqual(len(chunks), MIN_CHUNKS)

def test_make_chunks_large(self):
num_docs = len(self.corpus.documents)
documents = self.corpus.documents
if num_docs < MIN_CHUNKS:
documents = [documents[0]] * MIN_CHUNKS * 100
mps = MAX_PACKAGE_SIZE
chunks = self.sbert._make_chunks(
documents,
[mps / 100] * (len(documents) - 2) + [0.3 * mps, 0.9 * mps, mps]
)
self.assertGreater(len(chunks), MIN_CHUNKS)

@patch(PATCH_METHOD)
def test_empty_corpus(self, mock):
self.assertEqual(
len(self.sbert(self.corpus.documents[:0])), 0
)
self.assertEqual(len(self.sbert(self.corpus.documents[:0])), 0)
mock.request.assert_not_called()
mock.get_response.assert_not_called()
self.assertEqual(
@@ -95,14 +60,24 @@ def test_none_result(self):
result = self.sbert(self.corpus.documents)
self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None])

@patch(PATCH_METHOD, make_dummy_post(RESPONSE[0]))
def test_success_chunks(self):
num_docs = len(self.corpus.documents)
documents = self.corpus.documents
if num_docs < MIN_CHUNKS:
documents = [documents[0]] * MIN_CHUNKS
result = self.sbert(documents)
self.assertEqual(len(result), MIN_CHUNKS)
@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE)))
def test_transform(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertIsNone(skipped)
self.assertEqual(len(self.corpus), len(res))
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
self.assertEqual(384, len(res.domain.attributes))

@patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3)))
def test_transform_skipped(self):
res, skipped = self.sbert.transform(self.corpus)
self.assertEqual(len(self.corpus) - 1, len(res))
self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas)
self.assertEqual(384, len(res.domain.attributes))

self.assertEqual(1, len(skipped))
self.assertTupleEqual(self.corpus.domain.metas, skipped.domain.metas)
self.assertEqual(0, len(skipped.domain.attributes))


if __name__ == "__main__":
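The tests above patch httpx.AsyncClient.post with make_dummy_post, whose implementation is collapsed in this diff. A minimal sketch of the pattern it appears to follow is given below; the sleep parameter and the MagicMock response wrapper are assumptions, not the actual helper.

import asyncio
from collections.abc import Iterator
from unittest.mock import MagicMock

def make_dummy_post(response, sleep=0):
    # staticmethod keeps the patched AsyncClient.post from binding `self`
    @staticmethod
    async def dummy_post(url, headers, data):
        await asyncio.sleep(sleep)
        # an iterator yields a different canned payload on each call;
        # anything else is returned unchanged on every call
        payload = next(response) if isinstance(response, Iterator) else response
        return MagicMock(content=payload)  # stands in for an httpx.Response
    return dummy_post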
9 changes: 6 additions & 3 deletions orangecontrib/text/vectorization/document_embedder.py
@@ -157,16 +157,19 @@ def _transform(

return new_corpus, skipped_corpus

def report(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
def report(self) -> Tuple[Tuple[str, str], ...]:
"""Reports on current parameters of DocumentEmbedder.

Returns
-------
tuple
Tuple of parameters.
"""
return (('Language', self.language),
('Aggregator', self.aggregator))
return (
("Embedder", "fastText"),
("Language", self.language),
("Aggregator", self.aggregator),
)

def clear_cache(self):
"""Clears embedder cache"""
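With the change above, report() now also names the embedder, so callers iterating over the returned pairs get three entries instead of two. A small illustration follows; the constructor arguments are assumptions made for the example.

from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder

embedder = DocumentEmbedder(language="en", aggregator="Mean")
for name, value in embedder.report():
    print(f"{name}: {value}")  # Embedder: fastText, then Language and Aggregator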
164 changes: 93 additions & 71 deletions orangecontrib/text/vectorization/sbert.py
@@ -1,30 +1,29 @@
import json
import base64
import warnings
import zlib
import sys
from typing import Any, List, Optional, Callable
from typing import Any, List, Optional, Callable, Tuple

import numpy as np

from Orange.misc.server_embedder import ServerEmbedderCommunicator
from Orange.util import dummy_callback

from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.base import BaseVectorizer

# maximum document size that we still send to the server
MAX_PACKAGE_SIZE = 3000000
# maximum size of a chunk - when one document is longer, it is sent as a
# chunk with a single document
MAX_CHUNK_SIZE = 50000
MIN_CHUNKS = 20
EMB_DIM = 384


class SBERT:
class SBERT(BaseVectorizer):
def __init__(self) -> None:
self._server_communicator = _ServerCommunicator(
model_name='sbert',
model_name="sbert",
max_parallel_requests=100,
server_url='https://api.garaza.io',
embedder_type='text',
server_url="https://api.garaza.io",
embedder_type="text",
)

def __call__(
@@ -41,78 +40,101 @@ def __call__(
-------
An array of embeddings.
"""

if len(texts) == 0:
return []
# sort texts by length so that longer texts start embedding first; this
# prevents a long text with a long embedding time from starting at the end
# and adding extra time to the total embedding time
sorted_texts = sorted(
enumerate(texts),
key=lambda x: len(x[1][0]) if x[1] is not None else 0,
reverse=True,
)
indices, sorted_texts = zip(*sorted_texts)
# embed - send to the server
results = self._server_communicator.embedd_data(sorted_texts, callback=callback)
# unsort and unpack
return [x[0] if x else None for _, x in sorted(zip(indices, results))]

def _transform(
self, corpus: Corpus, _, callback=dummy_callback
) -> Tuple[Corpus, Optional[Corpus]]:
"""
Computes embeddings for the given corpus and appends the results to the corpus

skipped = list()

encoded_texts = list()
sizes = list()
chunks = list()
for i, text in enumerate(texts):
encoded = base64.b64encode(zlib.compress(
text.encode('utf-8', 'replace'), level=-1)
).decode('utf-8', 'replace')
size = sys.getsizeof(encoded)
if size > MAX_PACKAGE_SIZE:
skipped.append(i)
continue
encoded_texts.append(encoded)
sizes.append(size)

chunks = self._make_chunks(encoded_texts, sizes)

result_ = self._server_communicator.embedd_data(chunks, callback=callback)
if result_ is None:
return [None] * len(texts)

result = list()
assert len(result_) == len(chunks)
for res_chunk, orig_chunk in zip(result_, chunks):
if res_chunk is None:
# when embedder fails (Timeout or other error) result will be None
result.extend([None] * len(orig_chunk))
else:
result.extend(res_chunk)

results = list()
idx = 0
for i in range(len(texts)):
if i in skipped:
results.append(None)
else:
results.append(result[idx])
idx += 1

return results

def _make_chunks(self, encoded_texts, sizes, depth=0):
chunks = np.array_split(encoded_texts, MIN_CHUNKS if depth == 0 else 2)
chunk_sizes = np.array_split(sizes, MIN_CHUNKS if depth == 0 else 2)
result = list()
for i in range(len(chunks)):
# checking that there is more than one text in the chunk prevents infinite
# recursion when one text is bigger than MAX_CHUNK_SIZE
if len(chunks[i]) > 1 and np.sum(chunk_sizes[i]) > MAX_CHUNK_SIZE:
result.extend(self._make_chunks(chunks[i], chunk_sizes[i], depth + 1))
else:
result.append(chunks[i])
return [list(r) for r in result if len(r) > 0]
Parameters
----------
corpus
Corpus on which transform is performed.

Returns
-------
Embeddings
Corpus with new features added.
Skipped documents
Corpus of documents that were not embedded
"""
embs = self(corpus.documents, callback)

# Check if some documents in the corpus weren't embedded
# for some reason. This is a very rare case.
skipped_documents = [emb is None for emb in embs]
embedded_documents = np.logical_not(skipped_documents)

new_corpus = None
if np.any(embedded_documents):
# if at least one embedding is not None, extend attributes
new_corpus = corpus[embedded_documents]
new_corpus = new_corpus.extend_attributes(
np.array(
[e for e in embs if e],
dtype=float,
),
["Dim{}".format(i + 1) for i in range(EMB_DIM)],
var_attrs={
"embedding-feature": True,
"hidden": True,
},
)

skipped_corpus = None
if np.any(skipped_documents):
skipped_corpus = corpus[skipped_documents].copy()
skipped_corpus.name = "Skipped documents"
warnings.warn(
"Some documents were not embedded for unknown reason. Those "
"documents are skipped.",
RuntimeWarning,
)

return new_corpus, skipped_corpus

def report(self) -> Tuple[Tuple[str, str], ...]:
"""Reports on current parameters of DocumentEmbedder.

Returns
-------
tuple
Tuple of parameters.
"""
return (("Embedder", "Multilingual SBERT"),)

def clear_cache(self):
if self._server_communicator:
self._server_communicator.clear_cache()

def __enter__(self):
return self


class _ServerCommunicator(ServerEmbedderCommunicator):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.content_type = 'application/json'
self.content_type = "application/json"

async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
return json.dumps(data_instance).encode('utf-8', 'replace')
data = base64.b64encode(
zlib.compress(data_instance.encode("utf-8", "replace"), level=-1)
).decode("utf-8", "replace")
if sys.getsizeof(data) > 500000:
# Document in the corpus is too large. The size limit is 500 KB
# (after compression), so the document is skipped
return None
return json.dumps([data]).encode("utf-8", "replace")
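
Taken together with the tests above, the new vectorizer is used roughly as sketched below. This is an illustration based on the diff, not documentation; it assumes the embedding server at api.garaza.io is reachable.

from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM

corpus = Corpus.from_file("deerwester")
sbert = SBERT()

# embeddings are appended as hidden Dim1..Dim384 attributes; documents that
# could not be embedded come back in a separate "Skipped documents" corpus
embedded, skipped = sbert.transform(corpus)
if embedded is not None:
    assert len(embedded.domain.attributes) == EMB_DIM  # 384 dimensions
if skipped is not None:
    print(f"{len(skipped)} document(s) were skipped")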