Document embedding script and widget
Showing 12 changed files with 720 additions and 3 deletions.
@@ -0,0 +1,43 @@
Document Embedding
==================

Embeds documents from the input corpus into a vector space using pretrained
[fastText](https://fasttext.cc/docs/en/crawl-vectors.html) models described in
E. Grave, P. Bojanowski, P. Gupta, A. Joulin, T. Mikolov,
Learning Word Vectors for 157 Languages.
Proceedings of the International Conference on Language Resources and Evaluation, 2018.

**Inputs**

- Corpus: A collection of documents.

**Outputs**

- Corpus: Corpus with new features appended.

**Document Embedding** parses the ngrams of each document in the corpus, obtains an embedding for each ngram from the pretrained model for the chosen language, and produces one vector per document by aggregating the ngram embeddings with one of the offered aggregators. The method works on any ngrams, but it gives the best results when the corpus is preprocessed so that the ngrams are words, because the models were trained to embed words. A minimal programmatic sketch of the same procedure follows the parameter list below.

![](images/Document-Embedding-stamped.png)

1. Widget parameters:
   - Language: the widget uses a model trained on documents in the chosen language.
   - Aggregator: the operation used to combine ngram embeddings into a single document vector.
2. Cancel the current execution.
3. If *Apply automatically* is checked, changes in parameters are sent automatically. Alternatively, press *Apply*.

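The embedder behind the widget can also be used programmatically. The sketch below is a minimal example based on the `DocumentEmbedder` class exercised by the tests in this commit; it assumes the bundled *book-excerpts* corpus can be loaded with `Corpus.from_file` and that the remote embedding server is reachable.

```python
from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder

# Load a bundled corpus (assumed to be installed with the add-on).
corpus = Corpus.from_file('book-excerpts')

# Default parameters; the tests also construct the embedder with keyword
# arguments, e.g. DocumentEmbedder(language='sl') or DocumentEmbedder(aggregator='max').
embedder = DocumentEmbedder()

# Embeddings are fetched from a remote server and appended as new features.
embedded = embedder(corpus)
print(embedded.X.shape)
print(len(embedded.domain.attributes) - len(corpus.domain.attributes))
```
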
Examples
--------

In the first example, we inspect how the widget works. Load *book-excerpts.tab* with the [Corpus](corpus-widget.md) widget and connect it to **Document Embedding**. Check the output by connecting **Document Embedding** to a **Data Table**. We can see the 300 additional features that the widget has appended.

![](images/Document-Embedding-Example1.png)

In the second example, we try to predict the document category. We keep working on *book-excerpts.tab*, loaded with the [Corpus](corpus-widget.md) widget and sent through [Preprocess Text](preprocesstext.md) with default parameters. Connect **Preprocess Text** to **Document Embedding** to obtain features for predictive modelling. Here we set the aggregator to Sum.

Connect **Document Embedding** to **Test and Score** and also connect a learner of choice to the left side of **Test and Score**. We chose SVM and changed its kernel to Linear. **Test and Score** now computes the performance of each learner on the input, and the scores show that the embedding features work well for this task.

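Outside Orange, the same kind of evaluation can be approximated directly on the embedded features. The sketch below uses scikit-learn as a stand-in for **Test and Score** with a linear-kernel SVM; it assumes `embedded` is the output of **Document Embedding** (as in the sketch above) and that the document categories are stored as the corpus class variable.

```python
from sklearn.model_selection import cross_val_score
from sklearn.svm import LinearSVC

# X holds the embedding features appended by Document Embedding,
# y holds the document categories (the class variable of book-excerpts).
X, y = embedded.X, embedded.Y

# 10-fold cross-validation of a linear SVM, mirroring Test and Score's default.
scores = cross_val_score(LinearSVC(), X, y, cv=10)
print(scores.mean())
```
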
Let's now inspect the confusion matrix. Connect **Test and Score** to **Confusion Matrix**.
Clicking *Select Misclassified* outputs the misclassified documents. We can further inspect them by connecting [Corpus Viewer](corpusviewer.md) to **Confusion Matrix**.

![](images/Document-Embedding-Example2.png)
@@ -0,0 +1,148 @@
import unittest
from unittest.mock import patch
import asyncio
from numpy.testing import assert_array_equal

from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder
from orangecontrib.text import Corpus

# All tests patch the httpx POST call the embedder uses to reach the embedding server.
PATCH_METHOD = 'httpx.AsyncClient.post'


# Minimal stand-in for an httpx response; only the .content attribute is used.
class DummyResponse:

    def __init__(self, content):
        self.content = content


# Build a replacement for httpx.AsyncClient.post that returns a canned
# response after an optional delay.
def make_dummy_post(response, sleep=0):
    @staticmethod
    async def dummy_post(url, headers, data):
        await asyncio.sleep(sleep)
        return DummyResponse(content=response)
    return dummy_post


class DocumentEmbedderTest(unittest.TestCase):

    def setUp(self):
        self.embedder = DocumentEmbedder()  # default params
        self.corpus = Corpus.from_file('deerwester')

    def tearDown(self):
        self.embedder.clear_cache()

    @patch(PATCH_METHOD)
    def test_with_empty_corpus(self, mock):
        self.assertEqual(len(self.embedder(self.corpus[:0])), 0)
        mock.request.assert_not_called()
        mock.get_response.assert_not_called()
        self.assertEqual(self.embedder._embedder._cache._cache_dict, dict())

    @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
    def test_success_subset(self):
        res = self.embedder(self.corpus[[0]])
        assert_array_equal(res.X, [[0.3, 1]])
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)

    @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
    def test_success_shapes(self):
        res = self.embedder(self.corpus)
        self.assertEqual(res.X.shape, (len(self.corpus), 2))
        self.assertEqual(len(res.domain), len(self.corpus.domain) + 2)

    @patch(PATCH_METHOD, make_dummy_post(b''))
    def test_empty_response(self):
        with self.assertWarns(RuntimeWarning):
            res = self.embedder(self.corpus[[0]])
        self.assertEqual(res.X.shape, (0, 0))
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)

    @patch(PATCH_METHOD, make_dummy_post(b'str'))
    def test_invalid_response(self):
        with self.assertWarns(RuntimeWarning):
            res = self.embedder(self.corpus[[0]])
        self.assertEqual(res.X.shape, (0, 0))
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)

    @patch(PATCH_METHOD, make_dummy_post(b'{"embeddings": [0.3, 1]}'))
    def test_invalid_json_key(self):
        with self.assertWarns(RuntimeWarning):
            res = self.embedder(self.corpus[[0]])
        self.assertEqual(res.X.shape, (0, 0))
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)

    @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
    def test_persistent_caching(self):
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
        self.embedder(self.corpus[[0]])
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
        self.embedder._embedder._cache.persist_cache()

        self.embedder = DocumentEmbedder()
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)

        self.embedder.clear_cache()
        self.embedder = DocumentEmbedder()
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)

    @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
    def test_cache_for_different_languages(self):
        embedder = DocumentEmbedder(language='sl')
        embedder.clear_cache()
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
        embedder(self.corpus[[0]])
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder._embedder._cache.persist_cache()

        self.embedder = DocumentEmbedder()
        self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
        self.embedder._embedder._cache.persist_cache()

        embedder = DocumentEmbedder(language='sl')
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder.clear_cache()
        self.embedder.clear_cache()

    @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
    def test_cache_for_different_aggregators(self):
        embedder = DocumentEmbedder(aggregator='max')
        embedder.clear_cache()
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
        embedder(self.corpus[[0]])
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder._embedder._cache.persist_cache()

        embedder = DocumentEmbedder(aggregator='min')
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder(self.corpus[[0]])
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 2)

    @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
    def test_with_statement(self):
        with self.embedder as embedder:
            res = embedder(self.corpus[[0]])
        assert_array_equal(res.X, [[0.3, 1]])

    @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
    def test_cancel(self):
        self.assertFalse(self.embedder._embedder._cancelled)
        self.embedder._embedder._cancelled = True
        with self.assertRaises(Exception):
            self.embedder(self.corpus[[0]])

    @patch(PATCH_METHOD, side_effect=OSError)
    def test_connection_error(self, _):
        embedder = DocumentEmbedder()
        with self.assertRaises(ConnectionError):
            embedder(self.corpus[[0]])

    def test_invalid_parameters(self):
        with self.assertRaises(ValueError):
            self.embedder = DocumentEmbedder(language='eng')
        with self.assertRaises(ValueError):
            self.embedder = DocumentEmbedder(aggregator='average')

    def test_invalid_corpus_type(self):
        with self.assertRaises(ValueError):
            self.embedder(self.corpus[0])