From 4325e54d99bd78d673cf996d485d8e02dfdb8ed7 Mon Sep 17 00:00:00 2001
From: mrupp
Date: Wed, 13 Mar 2024 14:28:51 +0100
Subject: [PATCH] Various fixes and improvements

- Remove unused statements
- Remove 'CREATE_SQL_STATEMENTS_FILE' as config attribute
- Fix bug in cpe_search with not getting all results
- Update test cases
- Refactor and improve code
---
 config.json                   |  1 -
 config_mariadb.json           |  1 -
 cpe_search.py                 | 77 ++++++++++++++++++-----------
 database_wrapper_functions.py | 26 ++++++------
 test.py                       |  2 +
 5 files changed, 55 insertions(+), 52 deletions(-)

diff --git a/config.json b/config.json
index 2008d4f..8391686 100644
--- a/config.json
+++ b/config.json
@@ -1,7 +1,6 @@
 {
     "DATABASE_NAME": "cpe-search-dictionary.db3",
     "DEPRECATED_CPES_FILE": "deprecated-cpes.json",
-    "CREATE_SQL_STATEMENTS_FILE": "create_sql_statements.json",
     "NVD_API_KEY": "",
     "DATABASE": {
         "TYPE": "sqlite"
diff --git a/config_mariadb.json b/config_mariadb.json
index 86c2c6e..1f7c6a5 100644
--- a/config_mariadb.json
+++ b/config_mariadb.json
@@ -1,7 +1,6 @@
 {
     "DATABASE_NAME": "cpe_search_dictionary",
     "DEPRECATED_CPES_FILE": "deprecated-cpes.json",
-    "CREATE_SQL_STATEMENTS_FILE": "create_sql_statements.json",
     "NVD_API_KEY": "",
     "DATABASE":{
         "TYPE": "mariadb",
diff --git a/cpe_search.py b/cpe_search.py
index 3dcf2e0..c571df8 100755
--- a/cpe_search.py
+++ b/cpe_search.py
@@ -10,21 +10,23 @@
 import sys
 import time

-try:
-    from database_wrapper_functions import *
-except:
-    from .database_wrapper_functions import *
-
 try: # use ujson if available
     import ujson as json
 except ModuleNotFoundError:
     import json

+# direct import when run as standalone script and relative import otherwise
+try:
+    from database_wrapper_functions import *
+except:
+    from .database_wrapper_functions import *
+

 # Constants
 SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
 CPE_API_URL = "https://services.nvd.nist.gov/rest/json/cpes/2.0/"
 DEFAULT_CONFIG_FILE = os.path.join(SCRIPT_DIR, 'config.json')
+CREATE_SQL_STATEMENTS_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'create_sql_statements.json')
 DB_URI, DB_CONN_MEM = 'file:cpedb?mode=memory&cache=shared', None
 CONNECTION_POOL_SIZE = os.cpu_count() # should be equal to number of cpu cores? (https://dba.stackexchange.com/a/305726)
 TEXT_TO_VECTOR_RE = re.compile(r"[\w+\.]+")
@@ -83,6 +85,10 @@ def load_config_dict(_dict, db_type):
     return config


+def is_safe_database_name(db_name):
+    return all([c.isalnum() or c in ('-', '_') for c in db_name])
+
+
 async def api_request(headers, params, requestno):
     '''Perform request to API for one task'''

@@ -288,24 +294,25 @@ async def update(nvd_api_key=None, config=None):

     db_type = config['DATABASE']['TYPE']
     db_name = config['DATABASE_NAME']
-    if os.path.isfile(db_name):
-        os.remove(db_name)

     if db_type == 'sqlite':
+        if os.path.isfile(db_name):
+            os.remove(db_name)
         os.makedirs(os.path.dirname(db_name), exist_ok=True)

-    # get connection
-    try:
-        db_conn = get_database_connection(config['DATABASE'], db_name)
-    except:
-        db_conn = get_database_connection(config['DATABASE'], '')
-    db_cursor = db_conn.cursor()
     if db_type == 'mariadb':
+        if not is_safe_database_name(db_name):
+            print('Potential malicious database name detected. Abort creation of database')
+            return False
+        db_conn = get_database_connection(config['DATABASE'], '')
+        db_cursor = db_conn.cursor()
         db_cursor.execute(f'CREATE OR REPLACE DATABASE {db_name};')
         db_cursor.execute(f'use {db_name};')
+    else:
+        db_conn = get_database_connection(config['DATABASE'], db_name)
+        db_cursor = db_conn.cursor()

     # create tables
-    with open(config['CREATE_SQL_STATEMENTS_FILE']) as f:
+    with open(CREATE_SQL_STATEMENTS_FILE) as f:
         create_sql_statements = json.loads(f.read())
     db_cursor.execute(create_sql_statements['TABLES']['CPE_ENTRIES'][db_type])
     db_cursor.execute(create_sql_statements['TABLES']['TERMS_TO_ENTRIES'][db_type])
@@ -597,32 +604,30 @@ def _search_cpes(queries_raw, count, threshold, keep_data_in_memory=False, confi
             all_cpe_entry_ids.append(int(eid))

     # iterate over all retrieved CPE infos and find best matching CPEs for queries
-    iterator = []
+    all_cpe_infos = []
+    # limiting number of max_results_per_query boosts performance of MariaDB
     max_results_per_query = 1000
     remaining = len(all_cpe_entry_ids)
     is_one_iter_enough = remaining <= max_results_per_query
+
+    db_cursor = conn.cursor()
     while remaining > 0:
-        db_cursor = conn.cursor()
-        if remaining > max_results_per_query:
-            count_params_in_str = max_results_per_query
-        else:
-            count_params_in_str = remaining
+        count_params_in_str = min(remaining, max_results_per_query)
         param_in_str = ('?,' * count_params_in_str)[:-1]
-        if keep_data_in_memory or not is_one_iter_enough:
-            db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
-            db_cursor.execute(db_query, all_cpe_entry_ids[remaining-count_params_in_str:remaining])
-            cpe_infos = []
-            if db_cursor:
-                cpe_infos = db_cursor.fetchall()
-            iterator += cpe_infos
-        else:
-            db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
-            db_cursor.execute(db_query, all_cpe_entry_ids[remaining-count_params_in_str:remaining])
-            iterator = db_cursor
+        db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
+        db_cursor.execute(db_query, all_cpe_entry_ids[remaining-count_params_in_str:remaining])
+        cpe_infos = []
+        if db_cursor:
+            cpe_infos = db_cursor.fetchall()
+        all_cpe_infos += cpe_infos
         remaining -= max_results_per_query

-    for cpe_info in iterator:
+    # same order needed for test repeatability
+    if os.environ.get('IS_CPE_SEARCH_TEST', 'false') == 'true':
+        all_cpe_infos = sorted(all_cpe_infos)
+
+    for cpe_info in all_cpe_infos:
         cpe, cpe_tf, cpe_abs = cpe_info
         cpe_tf = json.loads(cpe_tf)
         cpe_abs = float(cpe_abs)
@@ -816,12 +821,10 @@ def get_all_cpes(keep_data_in_memory=False, config=None):
         config = _load_config()

     if keep_data_in_memory:
-        init_memdb()
-        conn = sqlite3.connect(DB_URI, uri=True)
-        db_cursor = conn.cursor()
+        conn = init_memdb(config)
     else:
         conn = get_database_connection(config['DATABASE'], config['DATABASE_NAME'], uri=True)
-        db_cursor = conn.cursor()
+    db_cursor = conn.cursor()

     db_cursor.execute('SELECT cpe FROM cpe_entries')
     cpes = [cpe[0] for cpe in db_cursor]
diff --git a/database_wrapper_functions.py b/database_wrapper_functions.py
index fbb5f57..9347256 100644
--- a/database_wrapper_functions.py
+++ b/database_wrapper_functions.py
@@ -1,25 +1,25 @@
 import sqlite3

 try: # only use mariadb module if installed
     import mariadb
-except:
+except ImportError:
     pass

-SERVERLESS_DATABASES = ['sqlite']
 def get_database_connection(config_database_keys, database_name, uri=False):
     '''Return a database connection object, initialized with the given config'''

     database_type = config_database_keys['TYPE']
     db_conn = None
-    match database_type:
-        case 'sqlite':
-            db_conn = sqlite3.connect(database_name, uri=uri)
-        case 'mariadb':
-            db_conn = mariadb.connect(
-                user=config_database_keys['USER'],
-                password=config_database_keys['PASSWORD'],
-                host=config_database_keys['HOST'],
-                port=config_database_keys['PORT'],
-                database=database_name
-            )
+    if database_type == 'sqlite':
+        db_conn = sqlite3.connect(database_name, uri=uri)
+    elif database_type == 'mariadb':
+        db_conn = mariadb.connect(
+            user=config_database_keys['USER'],
+            password=config_database_keys['PASSWORD'],
+            host=config_database_keys['HOST'],
+            port=config_database_keys['PORT'],
+            database=database_name
+        )
+    else:
+        raise(Exception('Invalid database type %s given' % (database_type)))
     return db_conn
\ No newline at end of file
diff --git a/test.py b/test.py
index 15a4630..6faa077 100644
--- a/test.py
+++ b/test.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3

+import os
 import unittest

 from cpe_search import search_cpes
@@ -106,4 +107,5 @@ def test_search_datatables_194(self):


 if __name__ == '__main__':
+    os.environ['IS_CPE_SEARCH_TEST'] = 'true'
     unittest.main()
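
Note (illustrative, not part of the patch itself): the new is_safe_database_name() helper is what keeps the interpolated DDL in update() safe, since db_name comes straight from the config and is spliced into f'CREATE OR REPLACE DATABASE {db_name};'. A minimal sketch of its behavior, with made-up database names for demonstration:

    # sketch only: mirrors the helper added to cpe_search.py in this patch
    def is_safe_database_name(db_name):
        return all([c.isalnum() or c in ('-', '_') for c in db_name])

    print(is_safe_database_name('cpe_search_dictionary'))            # True  -> CREATE DATABASE proceeds
    print(is_safe_database_name('cpe_db; DROP DATABASE mysql; --'))  # False -> update() prints a warning and returns False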