Merge pull request #17 from DoodleBears/I16-fix-possible-detection-list
DoodleBears authored Jul 13, 2024
2 parents 05e358a + 4fee976 commit f429ed7
Showing 6 changed files with 133 additions and 87 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -24,10 +24,10 @@ def read(*relpath):
packages=find_packages(),
install_requires=[
"fast_langdetect",
"lingua-language-detector",
"pydantic",
"budoux",
"wordfreq",
"wordfreq[cjk]",
],
classifiers=[
"Programming Language :: Python :: 3",
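For context (not part of this commit): the `[cjk]` extra gives wordfreq the tokenizers it needs to look up Chinese and Japanese words, which the frequency checks in detector.py below rely on. A minimal sketch, assuming wordfreq is installed via `pip install "wordfreq[cjk]"`:

# Example only (not part of the diff above).
from wordfreq import word_frequency

# CJK lookups need the extra tokenizers pulled in by wordfreq[cjk].
print(word_frequency("日本人", "ja"))  # ~9.6e-05 per the notebook below
print(word_frequency("日本人", "zh"))  # ~7.8e-05 per the notebook below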
121 changes: 79 additions & 42 deletions split-lang-benchmark.ipynb
@@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
@@ -27,21 +27,9 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\admin\\.conda\\envs\\melotts\\lib\\site-packages\\wtpsplit\\__init__.py:45: DeprecationWarning: You are using WtP, the old sentence segmentation model. It is highly encouraged to use SaT instead due to strongly improved performance and efficiency. See https://github.com/segment-any-text/wtpsplit for more info. To ignore this warning, set ignore_legacy_warning=True.\n",
" warnings.warn(\n",
"c:\\Users\\admin\\.conda\\envs\\melotts\\lib\\site-packages\\sklearn\\base.py:376: InconsistentVersionWarning: Trying to unpickle estimator LogisticRegression from version 1.2.2 when using version 1.5.0. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
"https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
" warnings.warn(\n"
]
}
],
"outputs": [],
"source": [
"from wtpsplit import SaT, WtP\n",
"sat = SaT(\"sat-1l-sm\")\n",
@@ -52,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@@ -61,14 +49,14 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:00<00:00, 166.68it/s]\n"
"100%|██████████| 1/1 [00:00<00:00, 124.43it/s]\n"
]
},
{
@@ -77,7 +65,7 @@
"['你', '喜欢看', 'アニメ', '吗']"
]
},
"execution_count": 14,
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
@@ -88,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
@@ -183,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
@@ -225,7 +213,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 86,
"metadata": {},
"outputs": [
{
@@ -372,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
@@ -386,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 88,
"metadata": {},
"outputs": [
{
@@ -460,69 +448,118 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 116,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\admin\\.conda\\envs\\melotts\\lib\\site-packages\\wtpsplit\\__init__.py:45: DeprecationWarning: You are using WtP, the old sentence segmentation model. It is highly encouraged to use SaT instead due to strongly improved performance and efficiency. See https://github.com/segment-any-text/wtpsplit for more info. To ignore this warning, set ignore_legacy_warning=True.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Vielen ', 'Dank ', 'merci ', 'beaucoup ', 'for your help.']\n",
"de:Vielen |nl:Dank |fr:merci |fr:beaucoup |en:for your help.|\n",
"0.5008478164672852\n",
"de:Vielen|nl:Dank|fr:merci|fr:beaucoup|en:for|en:your|en:help.|\n",
"['Ich ', 'bin müde ', 'je suis fatigué ', 'and ', 'I ', 'need some rest']\n",
"\n",
"0.8507270812988281\n",
"['Ich', 'bin', 'müde', 'je', 'suis', 'fatigué', 'and', 'I', 'need', 'some', 'rest']\n",
"Ich: ['de:1.0']\n",
"\n",
"bin: ['en:0.394521027803421', 'id:0.21292246878147125', 'tr:0.15927168726921082', 'ms:0.09677533060312271', 'eo:0.030221346765756607', 'jv:0.023466553539037704', 'sq:0.013604077510535717', 'sv:0.012493844144046307']\n",
"\n",
"müde: ['de:0.9626638293266296', 'tr:0.026752416044473648']\n",
"\n",
"je: ['sr:0.8350609540939331', 'fr:0.15938909351825714']\n",
"\n",
"suis: ['fr:0.9970543384552002']\n",
"\n",
"fatigué: ['fr:0.7745229601860046', 'es:0.14570455253124237', 'nl:0.03426196426153183', 'de:0.014598727226257324']\n",
"\n",
"and: ['en:0.9983668923377991']\n",
"\n",
"I: ['en:0.9979130625724792']\n",
"\n",
"need: ['en:0.9885499477386475']\n",
"\n",
"some: ['en:0.993599534034729']\n",
"\n",
"rest: ['en:0.46527764201164246', 'nl:0.05554379150271416', 'ja:0.039999693632125854', 'pl:0.033725958317518234', 'ar:0.028888104483485222', 'sv:0.028634123504161835', 'no:0.024481678381562233', 'id:0.022417806088924408', 'tr:0.02039233408868313', 'zh:0.014322302304208279']\n",
"\n",
"\n",
"0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\admin\\.conda\\envs\\melotts\\lib\\site-packages\\sklearn\\base.py:376: InconsistentVersionWarning: Trying to unpickle estimator LogisticRegression from version 1.2.2 when using version 1.5.0. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
"https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
" warnings.warn(\n"
]
}
],
"source": [
"from wtpsplit import WtP\n",
"from time import time\n",
"from datetime import datetime\n",
"text = \"Vielen Dank merci beaucoup for your help.\"\n",
"\n",
"# text = \"Vielen Dank merci beaucoup for your help.\"\n",
"text = \"Ich bin müde je suis fatigué and I need some rest\"\n",
"# text = \"日语使用者应超过一亿三千万人\"\n",
"# text = \"我是 VGroupChatBot,一个旨在支持多人通信的助手,通过可视化消息来帮助团队成员更好地交流。我可以帮助团队成员更好地整理和共享信息,特别是在讨论、会议和Brainstorming等情况下。你好我的名字是西野くまですmy name is bob很高兴认识你どうぞよろしくお願いいたします「こんにちは」是什么意思。我的名字是西野くまです。I am from Tokyo, 日本の首都。今天的天气非常好\"\n",
"time1 = datetime.now().timestamp()\n",
"wtp = WtP('wtp-bert-mini')\n",
"substrings = wtp.split(text_or_texts=text, threshold=1e-4)\n",
"substrings = wtp.split(text_or_texts=text, threshold=2e-3)\n",
"print(substrings)\n",
"for substring in substrings:\n",
" # lang = lingua_lang_detect_all(substring)\n",
" lang = fast_lang_detect(substring)\n",
" print(f\"{lang}:{substring}\",end='|')\n",
" lang = lingua_lang_detect_all(substring)\n",
" # print(f\"{lang}:{substring}\",end='|')\n",
"print()\n",
"time2 = datetime.now().timestamp()\n",
"\n",
"print(time2 - time1)\n",
"\n",
"from split_lang import LangSplitter\n",
"lang_splitter = LangSplitter()\n",
"substrings = lang_splitter._parse_without_zh_ja(text=text)\n",
"substrings = text.split(' ')\n",
"# substrings = lang_splitter._parse_zh_ja(text=text)\n",
"# substrings = lang_splitter._parse_without_zh_ja(text=text)\n",
"print(substrings)\n",
"for substring in substrings:\n",
" # lang = lingua_lang_detect_all(substring)\n",
" lang = fast_lang_detect(substring)\n",
" print(f\"{lang}:{substring}\",end='|')\n",
" lang = fast_langdetect.detect_multilingual(substring, low_memory=False, k=10, threshold=0.01)\n",
" lang_list = [f\"{item['lang']}:{item['score']}\" for item in lang]\n",
" print(f\"{substring}: {lang_list}\",end='\\n\\n')\n",
"time3 = datetime.now().timestamp()\n",
"print()\n",
"print(time3 - time2)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 117,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9.55e-05\n",
"7.82e-05\n",
"1.221227621483376\n"
"1.66e-05\n",
"0.00126\n",
"0.013174603174603174\n"
]
}
],
"source": [
"from wordfreq import word_frequency\n",
"ja_freq = word_frequency('日本人', 'ja')\n",
"zh_freq = word_frequency('日本人', 'zh')\n",
"ja_freq = word_frequency('bin ', 'en')\n",
"zh_freq = word_frequency('bin ', 'de')\n",
"print(ja_freq)\n",
"print(zh_freq)\n",
"print(ja_freq / zh_freq)"
2 changes: 1 addition & 1 deletion split_lang/detect_lang/__init__.py
@@ -1,5 +1,5 @@
from .detector import (
detect_lang_combined,
possible_detection_list,
is_word_freq_higher_in_ja,
is_word_freq_higher_in_lang_b,
)
40 changes: 16 additions & 24 deletions split_lang/detect_lang/detector.py
@@ -1,20 +1,11 @@
import logging
from typing import List
import fast_langdetect
from lingua import LanguageDetectorBuilder
from wordfreq import word_frequency

from ..model import LangSectionType
from ..split.utils import contains_ja


all_detector = (
LanguageDetectorBuilder.from_all_languages()
.with_preloaded_language_models()
.build()
)


logger = logging.getLogger(__name__)


@@ -24,13 +15,6 @@ def fast_lang_detect(text: str) -> str:
return result


def lingua_lang_detect_all(text: str) -> str:
language = all_detector.detect_language_of(text=text)
if language is None:
return "x"
return language.iso_code_639_1.name.lower()


# For example '衬衫' cannot be detected by `langdetect`, and `fast_langdetect` will detect it as 'en'
def detect_lang_combined(text: str, lang_section_type: LangSectionType) -> str:
if lang_section_type is LangSectionType.ZH_JA:
Expand All @@ -41,20 +25,28 @@ def detect_lang_combined(text: str, lang_section_type: LangSectionType) -> str:


def possible_detection_list(text) -> List[str]:
languages = []
languages.append(fast_lang_detect(text))
languages.append(lingua_lang_detect_all(text))
languages = [
item["lang"]
for item in fast_langdetect.detect_multilingual(
text,
low_memory=False,
k=5,
threshold=0.01,
)
]
return languages
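
# --- Example usage (illustration only, not part of this commit) ---
# detect_multilingual returns ranked candidates such as
# [{'lang': 'de', 'score': 0.96}, {'lang': 'tr', 'score': 0.03}],
# so possible_detection_list keeps only the language codes.
if __name__ == "__main__":
    # Candidate lists match the benchmark notebook above.
    print(possible_detection_list("müde"))  # e.g. ['de', 'tr']
    print(possible_detection_list("suis"))  # e.g. ['fr']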


def _detect_word_freq_in_lang(word: str, lang: str) -> float:
return word_frequency(word=word, lang=lang)


def is_word_freq_higher_in_ja(word: str) -> bool:
word_freq_ja = _detect_word_freq_in_lang(word=word, lang="ja")
word_freq_zh = _detect_word_freq_in_lang(word=word, lang="zh")
if word_freq_zh == 0:
def is_word_freq_higher_in_lang_b(word: str, lang_a: str, lang_b: str) -> bool:
if lang_a == "x" or lang_b == "x":
return False
word_freq_lang_b = _detect_word_freq_in_lang(word=word, lang=lang_b)
word_freq_lang_a = _detect_word_freq_in_lang(word=word, lang=lang_a)
if word_freq_lang_a == 0:
return False
    # 0.8 means the word is either more frequent in lang_b, or frequently used in both languages
return (word_freq_ja / word_freq_zh) > 0.8
return (word_freq_lang_b / word_freq_lang_a) > 0.8
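
# --- Example usage (illustration only, not part of this commit) ---
# The 0.8 ratio test with the '日本人' frequencies from the notebook above:
# ja ~9.6e-05 and zh ~7.8e-05 give a ratio of ~1.22, so the word counts as
# at least as common in Japanese as in Chinese.
print(is_word_freq_higher_in_lang_b("日本人", lang_a="zh", lang_b="ja"))  # True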