Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Commit

Permalink
Various fixes and improvements
Browse files Browse the repository at this point in the history
- Remove unused statements
- Remove 'CREATE_SQL_STATEMENTS_FILE' as config attribute
- Fix bug in cpe_search with not getting all results
- Update test cases
- Refactor and improve code
  • Loading branch information
MRuppDev committed Mar 14, 2024
1 parent c65d7fa commit 4325e54
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 52 deletions.
1 change: 0 additions & 1 deletion config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"DATABASE_NAME": "cpe-search-dictionary.db3",
"DEPRECATED_CPES_FILE": "deprecated-cpes.json",
"CREATE_SQL_STATEMENTS_FILE": "create_sql_statements.json",
"NVD_API_KEY": "",
"DATABASE": {
"TYPE": "sqlite"
Expand Down
1 change: 0 additions & 1 deletion config_mariadb.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"DATABASE_NAME": "cpe_search_dictionary",
"DEPRECATED_CPES_FILE": "deprecated-cpes.json",
"CREATE_SQL_STATEMENTS_FILE": "create_sql_statements.json",
"NVD_API_KEY": "",
"DATABASE":{
"TYPE": "mariadb",
Expand Down
77 changes: 40 additions & 37 deletions cpe_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@
import sys
import time

try:
from database_wrapper_functions import *
except:
from .database_wrapper_functions import *

try: # use ujson if available
import ujson as json
except ModuleNotFoundError:
import json

# direct import when run as standalone script and relative import otherwise
try:
from database_wrapper_functions import *
except:
from .database_wrapper_functions import *


# Constants
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
CPE_API_URL = "https://services.nvd.nist.gov/rest/json/cpes/2.0/"
DEFAULT_CONFIG_FILE = os.path.join(SCRIPT_DIR, 'config.json')
CREATE_SQL_STATEMENTS_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'create_sql_statements.json')
DB_URI, DB_CONN_MEM = 'file:cpedb?mode=memory&cache=shared', None
CONNECTION_POOL_SIZE = os.cpu_count() # should be equal to number of cpu cores? (https://dba.stackexchange.com/a/305726)
TEXT_TO_VECTOR_RE = re.compile(r"[\w+\.]+")
Expand Down Expand Up @@ -83,6 +85,10 @@ def load_config_dict(_dict, db_type):
return config


def is_safe_database_name(db_name):
return all([c.isalnum() or c in ('-', '_') for c in db_name])


async def api_request(headers, params, requestno):
'''Perform request to API for one task'''

Expand Down Expand Up @@ -288,24 +294,25 @@ async def update(nvd_api_key=None, config=None):

db_type = config['DATABASE']['TYPE']
db_name = config['DATABASE_NAME']
if os.path.isfile(db_name):
os.remove(db_name)
if db_type == 'sqlite':
if os.path.isfile(db_name):
os.remove(db_name)
os.makedirs(os.path.dirname(db_name), exist_ok=True)

# get connection
try:
db_conn = get_database_connection(config['DATABASE'], db_name)
except:
db_conn = get_database_connection(config['DATABASE'], '')
db_cursor = db_conn.cursor()

if db_type == 'mariadb':
if not is_safe_database_name(db_name):
print('Potential malicious database name detected. Abort creation of database')
return False
db_conn = get_database_connection(config['DATABASE'], '')
db_cursor = db_conn.cursor()
db_cursor.execute(f'CREATE OR REPLACE DATABASE {db_name};')
db_cursor.execute(f'use {db_name};')
else:
db_conn = get_database_connection(config['DATABASE'], db_name)
db_cursor = db_conn.cursor()

# create tables
with open(config['CREATE_SQL_STATEMENTS_FILE']) as f:
with open(CREATE_SQL_STATEMENTS_FILE) as f:
create_sql_statements = json.loads(f.read())
db_cursor.execute(create_sql_statements['TABLES']['CPE_ENTRIES'][db_type])
db_cursor.execute(create_sql_statements['TABLES']['TERMS_TO_ENTRIES'][db_type])
Expand Down Expand Up @@ -597,32 +604,30 @@ def _search_cpes(queries_raw, count, threshold, keep_data_in_memory=False, confi
all_cpe_entry_ids.append(int(eid))

# iterate over all retrieved CPE infos and find best matching CPEs for queries
iterator = []
all_cpe_infos = []
# limiting number of max_results_per_query boosts performance of MariaDB
max_results_per_query = 1000
remaining = len(all_cpe_entry_ids)
is_one_iter_enough = remaining <= max_results_per_query

db_cursor = conn.cursor()
while remaining > 0:
db_cursor = conn.cursor()
if remaining > max_results_per_query:
count_params_in_str = max_results_per_query
else:
count_params_in_str = remaining
count_params_in_str = min(remaining, max_results_per_query)
param_in_str = ('?,' * count_params_in_str)[:-1]
if keep_data_in_memory or not is_one_iter_enough:
db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
db_cursor.execute(db_query, all_cpe_entry_ids[remaining-count_params_in_str:remaining])
cpe_infos = []
if db_cursor:
cpe_infos = db_cursor.fetchall()
iterator += cpe_infos
else:
db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
db_cursor.execute(db_query, all_cpe_entry_ids[remaining-count_params_in_str:remaining])
iterator = db_cursor

db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
db_cursor.execute(db_query, all_cpe_entry_ids[remaining-count_params_in_str:remaining])
cpe_infos = []
if db_cursor:
cpe_infos = db_cursor.fetchall()
all_cpe_infos += cpe_infos
remaining -= max_results_per_query

for cpe_info in iterator:
# same order needed for test repeatability
if os.environ.get('IS_CPE_SEARCH_TEST', 'false') == 'true':
all_cpe_infos = sorted(all_cpe_infos)

for cpe_info in all_cpe_infos:
cpe, cpe_tf, cpe_abs = cpe_info
cpe_tf = json.loads(cpe_tf)
cpe_abs = float(cpe_abs)
Expand Down Expand Up @@ -816,12 +821,10 @@ def get_all_cpes(keep_data_in_memory=False, config=None):
config = _load_config()

if keep_data_in_memory:
init_memdb()
conn = sqlite3.connect(DB_URI, uri=True)
db_cursor = conn.cursor()
conn = init_memdb(config)
else:
conn = get_database_connection(config['DATABASE'], config['DATABASE_NAME'], uri=True)
db_cursor = conn.cursor()
db_cursor = conn.cursor()

db_cursor.execute('SELECT cpe FROM cpe_entries')
cpes = [cpe[0] for cpe in db_cursor]
Expand Down
26 changes: 13 additions & 13 deletions database_wrapper_functions.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import sqlite3
try: # only use mariadb module if installed
import mariadb
except:
except ImportError:
pass

SERVERLESS_DATABASES = ['sqlite']

def get_database_connection(config_database_keys, database_name, uri=False):
'''Return a database connection object, initialized with the given config'''
database_type = config_database_keys['TYPE']

db_conn = None
match database_type:
case 'sqlite':
db_conn = sqlite3.connect(database_name, uri=uri)
case 'mariadb':
db_conn = mariadb.connect(
user=config_database_keys['USER'],
password=config_database_keys['PASSWORD'],
host=config_database_keys['HOST'],
port=config_database_keys['PORT'],
database=database_name
)
if database_type == 'sqlite':
db_conn = sqlite3.connect(database_name, uri=uri)
elif 'mariadb':
db_conn = mariadb.connect(
user=config_database_keys['USER'],
password=config_database_keys['PASSWORD'],
host=config_database_keys['HOST'],
port=config_database_keys['PORT'],
database=database_name
)
else:
raise(Exception('Invalid database type %s given' % (database_type)))
return db_conn
2 changes: 2 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import os
import unittest
from cpe_search import search_cpes

Expand Down Expand Up @@ -106,4 +107,5 @@ def test_search_datatables_194(self):


if __name__ == '__main__':
os.environ['IS_CPE_SEARCH_TEST'] = 'true'
unittest.main()

0 comments on commit 4325e54

Please sign in to comment.