Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Commit

Permalink
Add weighing for cos sim to queries, similar to 9c6bd48
Browse files Browse the repository at this point in the history
  • Loading branch information
ra1nb0rn committed Jan 25, 2024
1 parent f06c357 commit 06825e7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 22 deletions.
41 changes: 29 additions & 12 deletions cpe_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DB_URI, DB_CONN_MEM = 'file:cpedb?mode=memory&cache=shared', None
TEXT_TO_VECTOR_RE = re.compile(r"[\w+\.]+")
CPE_TERM_WEIGHT_EXP_FACTOR = -0.08
QUERY_TERM_WEIGHT_EXP_FACTOR = -0.25
GET_ALL_CPES_RE = re.compile(r'(.*);.*;.*')
VERSION_MATCH_ZE_RE = re.compile(r'\b([\d]+\.?){1,4}\b')
VERSION_MATCH_CPE_CREATION_RE = re.compile(r'\b((\d[\da-zA-Z\.]{0,6})([\+\-\.\_][\da-zA-Z\.]+){0,4})[^\w\n]*$')
Expand Down Expand Up @@ -382,19 +383,19 @@ def _get_alternative_queries(init_queries):
alt_queries_mapping[query].append(query + ' getbootstrap')

# check for different variants of js library names, e.g. 'moment.js' vs. 'momentjs' vs. 'moment js'
query_words = query.split()
if 'js ' in query or ' js' in query or query.endswith('js'):
words = query.split()
alt_queries = []
for i, word in enumerate(words):
for i, word in enumerate(query_words):
word = word.strip()
new_query_words1, new_query_words2 = [], []
if word == 'js' and i > 0:
new_query_words1 = words[:i-1] + [words[i-1] + 'js']
new_query_words2 = words[:i-1] + [words[i-1] + '.js']
new_query_words1 = query_words[:i-1] + [query_words[i-1] + 'js']
new_query_words2 = query_words[:i-1] + [query_words[i-1] + '.js']
elif word.endswith('.js') or word.endswith('js'):
if i > 0:
new_query_words1 += words[:i]
new_query_words2 += words[:i]
new_query_words1 += query_words[:i]
new_query_words2 += query_words[:i]
if word.endswith('.js'):
new_query_words1 += [word[:-len('.js')]] + ['js']
new_query_words2 += [word[:-len('.js')] + 'js']
Expand All @@ -403,9 +404,9 @@ def _get_alternative_queries(init_queries):
new_query_words2 += [word[:-len('js')] + '.js']

if new_query_words1:
if i < len(words) - 1:
new_query_words1 += words[i+1:]
new_query_words2 += words[i+1:]
if i < len(query_words) - 1:
new_query_words1 += query_words[i+1:]
new_query_words2 += query_words[i+1:]
alt_queries.append(' '.join(new_query_words1))
alt_queries.append(' '.join(new_query_words2))

Expand Down Expand Up @@ -452,7 +453,16 @@ def _get_alternative_queries(init_queries):
pot_alt_query = ' '.join(pot_alt_query_parts)

if pot_alt_query != query.strip():
alt_queries_mapping[query].append(pot_alt_query)
alt_queries_mapping[query].append(pot_alt_query) # w/o including orig query words
for word in query.split():
if word not in pot_alt_query:
pot_alt_query += ' ' + word
alt_queries_mapping[query].append(pot_alt_query) # w/ including orig query words

# add alt query in case likely subversion is split from main version by a space
if len(query_words) > 2 and len(query_words[-1]) < 7:
alt_queries_mapping[query].append(query + ' ' + query_words[-2] + query_words[-1])
alt_queries_mapping[query].append(' '.join(query_words[:-2]) + ' ' + query_words[-2] + query_words[-1])

# zero extend versions, e.g. 'Apache httpd 2.4' --> 'Apache httpd 2.4.0'
version_match = VERSION_MATCH_ZE_RE.search(query)
Expand Down Expand Up @@ -494,9 +504,16 @@ def _search_cpes(queries_raw, count, threshold, keep_data_in_memory=False, confi
most_similar = {}
all_query_words = set()
for query in queries:
query_tf = Counter(TEXT_TO_VECTOR_RE.findall(query))
words_query = TEXT_TO_VECTOR_RE.findall(query)
word_weights_query = {}
for i, word in enumerate(words_query):
if word not in word_weights_query:
word_weights_query[word] = math.exp(QUERY_TERM_WEIGHT_EXP_FACTOR * i)

# compute query's cosine vector for similarity comparison
query_tf = Counter(words_query)
for term, tf in query_tf.items():
query_tf[term] = tf / len(query_tf)
query_tf[term] = word_weights_query[term] * (tf / len(query_tf))
all_query_words |= set(query_tf.keys())
query_abs = math.sqrt(sum([cnt**2 for cnt in query_tf.values()]))
query_infos[query] = (query_tf, query_abs)
Expand Down
30 changes: 20 additions & 10 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_search_wp_572(self):
self.maxDiff = None
query = 'WordPress 5.7.2'
test_best_match_cpe = 'cpe:2.3:a:wordpress:wordpress:5.7.2:*:*:*:*:*:*:*'
test_best_match_score = 0.9686485306860276
test_best_match_score = 0.9919023838421461
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -18,7 +18,7 @@ def test_search_apache_2425(self):
self.maxDiff = None
query = 'Apache 2.4.25'
test_best_match_cpe = 'cpe:2.3:a:apache:http_server:2.4.25:*:*:*:*:*:*:*'
test_best_match_score = 0.7427124236405216
test_best_match_score = 0.7702103477897395
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -27,7 +27,7 @@ def test_search_proftpd_133c(self):
self.maxDiff = None
query = 'Proftpd 1.3.3c'
test_best_match_cpe = 'cpe:2.3:a:proftpd:proftpd:1.3.3:c:*:*:*:*:*:*'
test_best_match_score = 0.829017833421458
test_best_match_score = 0.9226616585163939
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -36,7 +36,7 @@ def test_search_thingsboard_341(self):
self.maxDiff = None
query = 'Thingsboard 3.4.1'
test_best_match_cpe = 'cpe:2.3:a:thingsboard:thingsboard:3.4.1:*:*:*:*:*:*:*'
test_best_match_score = 0.9686485306860276
test_best_match_score = 0.9919023838421461
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -45,7 +45,7 @@ def test_search_redis_323(self):
self.maxDiff = None
query = 'Redis 3.2.3'
test_best_match_cpe = 'cpe:2.3:a:redis:redis:3.2.3:*:*:*:*:*:*:*'
test_best_match_score = 0.9686485306860276
test_best_match_score = 0.9919023838421461
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -54,7 +54,7 @@ def test_search_piwik_045(self):
self.maxDiff = None
query = 'Piwik 0.4.5'
test_best_match_cpe = 'cpe:2.3:a:piwik:piwik:0.4.5:*:*:*:*:*:*:*'
test_best_match_score = 0.9686485306860276
test_best_match_score = 0.9919023838421461
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -63,7 +63,7 @@ def test_search_vmware_spring_framework_5326(self):
self.maxDiff = None
query = 'VMWare Spring Framework 5.3.26'
test_best_match_cpe = 'cpe:2.3:a:vmware:spring_framework:5.3.26:*:*:*:*:*:*:*'
test_best_match_score = 0.996033093730958
test_best_match_score = 0.9836819689304376
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -72,7 +72,7 @@ def test_search_zulip_48(self):
self.maxDiff = None
query = 'Zulip 4.8'
test_best_match_cpe = 'cpe:2.3:a:zulip:zulip:4.8:*:*:*:*:*:*:*'
test_best_match_score = 0.9686485306860276
test_best_match_score = 0.9919023838421461
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -81,7 +81,7 @@ def test_search_electron_1317(self):
self.maxDiff = None
query = 'Electron 13.1.7'
test_best_match_cpe = 'cpe:2.3:a:electronjs:electron:13.1.7:*:*:*:*:*:*:*'
test_best_match_score = 0.7817733882696567
test_best_match_score = 0.7796549134972258
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)
Expand All @@ -90,10 +90,20 @@ def test_search_blackice_agent_for_server_30(self):
self.maxDiff = None
query = 'BlackIce Agent for Server 3.0'
test_best_match_cpe = 'cpe:2.3:a:iss:blackice_agent_for_server:3.0:*:*:*:*:*:*:*'
test_best_match_score = 0.8665018147937851
test_best_match_score = 0.8503750787877568
result = search_cpes(queries=[query])
self.assertEqual(result[query][0][0], test_best_match_cpe)
self.assertAlmostEqual(result[query][0][1], test_best_match_score)

def test_search_datatables_194(self):
self.maxDiff = None
query = 'datatables 1.9.4'
expected_best_results = [('cpe:2.3:a:datatables:datatables.net:1.10.0:-:*:*:*:node.js:*:*', 0.43397267978578885), ('cpe:2.3:a:sprymedia:datatables:1.9.2:*:*:*:*:jquery:*:*', 0.40060727459547485)]
result = search_cpes(queries=[query])
for i in range(2):
self.assertEqual(result[query][i][0], expected_best_results[i][0])
self.assertAlmostEqual(result[query][i][1], expected_best_results[i][1])


if __name__ == '__main__':
unittest.main()

0 comments on commit 06825e7

Please sign in to comment.