From ee2a16d5148a72796f85b46ba5235c7bf70fda53 Mon Sep 17 00:00:00 2001 From: NirTatcher <75395024+NirTatcher@users.noreply.github.com> Date: Thu, 18 Jul 2024 21:13:17 +0300 Subject: [PATCH] SSH SQL Query utility (#1102) Added SSH util function we can connect through SSH and query through, added a test for that. Added documentation about using SQL Mirror support for ActionNetwork which lets us query our data/connect to our data we have in ActionNetwork (all tables read only access). --- docs/action_network.rst | 31 ++++++- parsons/action_network/action_network.py | 1 - parsons/utilities/ssh_utilities.py | 81 +++++++++++++++++++ requirements-dev.txt | 1 + requirements.txt | 2 + setup.py | 1 + .../test_action_network.py | 6 +- test/test_utilities/test_ssh_utilities.py | 66 +++++++++++++++ 8 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 parsons/utilities/ssh_utilities.py create mode 100644 test/test_utilities/test_ssh_utilities.py diff --git a/docs/action_network.rst b/docs/action_network.rst index d9e9bd4b5..a8056c218 100644 --- a/docs/action_network.rst +++ b/docs/action_network.rst @@ -6,7 +6,8 @@ Overview ******** `Action Network `_ is an online tool for storing information -and organizing volunteers and donors. It is used primarily for digital organizing and event mangement. For more information, see `Action Network developer docs `_ +and organizing volunteers and donors. It is used primarily for digital organizing and event mangement. For more information, see `Action Network developer docs `_, `SQL Mirror developer docs `_ + .. note:: Authentication @@ -99,6 +100,34 @@ You can then call various endpoints: # Get a specific wrapper specific_wrapper = an.get_wrapper('wrapper_id') +*********** +SQL Mirror +*********** + +.. code-block:: python + + from parsons.utilities.ssh_utilities import query_through_ssh + + # Define SSH and database parameters + ssh_host = 'ssh.example.com' + ssh_port = 22 + ssh_username = 'user' + ssh_password = 'pass' + db_host = 'db.example.com' + db_port = 5432 + db_name = 'testdb' + db_username = 'dbuser' + db_password = 'dbpass' + query = 'SELECT * FROM table' + + # Use the function to query through SSH + result = query_through_ssh( + ssh_host, ssh_port, ssh_username, ssh_password, + db_host, db_port, db_name, db_username, db_password, query + ) + + # Output the result + print(result) *** API diff --git a/parsons/action_network/action_network.py b/parsons/action_network/action_network.py index cdd68cc70..4145a611f 100644 --- a/parsons/action_network/action_network.py +++ b/parsons/action_network/action_network.py @@ -3,7 +3,6 @@ import re import warnings from typing import Dict, List, Union - from parsons import Table from parsons.utilities import check_env from parsons.utilities.api_connector import APIConnector diff --git a/parsons/utilities/ssh_utilities.py b/parsons/utilities/ssh_utilities.py new file mode 100644 index 000000000..d8466fb91 --- /dev/null +++ b/parsons/utilities/ssh_utilities.py @@ -0,0 +1,81 @@ +import logging +import sshtunnel +import psycopg2 + + +def query_through_ssh( + ssh_host, + ssh_port, + ssh_username, + ssh_password, + db_host, + db_port, + db_name, + db_username, + db_password, + query, +): + """ + `Args:` + ssh_host: + The host for the SSH connection + ssh_port: + The port for the SSH connection + ssh_username: + The username for the SSH connection + ssh_password: + The password for the SSH connection + db_host: + The host for the db connection + db_port: + The port for the db connection + db_name: + The name of the db database + db_username: + The username for the db database + db_password: + The password for the db database + query: + The SQL query to execute + + `Returns:` + A list of records resulting from the query or None if something went wrong + """ + output = None + server = None + con = None + try: + server = sshtunnel.SSHTunnelForwarder( + (ssh_host, int(ssh_port)), + ssh_username=ssh_username, + ssh_password=ssh_password, + remote_bind_address=(db_host, int(db_port)), + ) + server.start() + logging.info("SSH tunnel established successfully.") + + con = psycopg2.connect( + host="localhost", + port=server.local_bind_port, + database=db_name, + user=db_username, + password=db_password, + ) + logging.info("Database connection established successfully.") + + cursor = con.cursor() + cursor.execute(query) + records = cursor.fetchall() + output = records + logging.info(f"Query executed successfully: {records}") + except Exception as e: + logging.error(f"Error during query execution: {e}") + raise e + finally: + if con: + con.close() + logging.info("Database connection closed.") + if server: + server.stop() + logging.info("SSH tunnel closed.") + return output diff --git a/requirements-dev.txt b/requirements-dev.txt index e9e22ab7a..205ae8fbf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,3 +8,4 @@ pytest-mock==3.12.0 pytest==8.1.1 requests-mock==1.11.0 testfixtures==8.1.0 + diff --git a/requirements.txt b/requirements.txt index 219c3017c..262b9f958 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,3 +46,5 @@ xmltodict==0.11.0 jinja2>=3.0.2 selenium==3.141.0 us==3.1.1 +sshtunnel==0.4.0 + diff --git a/setup.py b/setup.py index e11b6bd1c..8b300e5fd 100644 --- a/setup.py +++ b/setup.py @@ -63,6 +63,7 @@ def main(): "smtp": ["validate-email"], "targetsmart": ["xmltodict"], "twilio": ["twilio"], + "ssh": ["sshtunnel", "psycopg2-binary>=2.9.9", "sqlalchemy >= 1.4.22, != 1.4.33, < 2.0.0"] } extras_require["all"] = sorted({lib for libs in extras_require.values() for lib in libs}) else: diff --git a/test/test_action_network/test_action_network.py b/test/test_action_network/test_action_network.py index 4531fbc50..63e18243c 100644 --- a/test/test_action_network/test_action_network.py +++ b/test/test_action_network/test_action_network.py @@ -1,7 +1,8 @@ import unittest import requests_mock import json -from parsons import Table, ActionNetwork +from parsons import Table +from parsons.action_network import ActionNetwork from test.utils import assert_matching_tables @@ -4293,8 +4294,7 @@ def test_get_wrapper(self, m): self.fake_wrapper, ) - # Unique ID Lists - + # Unique ID Lists @requests_mock.Mocker() def test_get_unique_id_lists(self, m): m.get( diff --git a/test/test_utilities/test_ssh_utilities.py b/test/test_utilities/test_ssh_utilities.py new file mode 100644 index 000000000..5fe5a6a1d --- /dev/null +++ b/test/test_utilities/test_ssh_utilities.py @@ -0,0 +1,66 @@ +import unittest +from unittest.mock import patch, MagicMock +from parsons.utilities.ssh_utilities import query_through_ssh + + +class TestSSHTunnelUtility(unittest.TestCase): + @patch("parsons.utilities.ssh_utilities.sshtunnel.SSHTunnelForwarder") + @patch("parsons.utilities.ssh_utilities.psycopg2.connect") + def test_query_through_ssh(self, mock_connect, mock_tunnel): + # Setup mock for SSHTunnelForwarder + mock_tunnel_instance = MagicMock() + mock_tunnel.return_value = mock_tunnel_instance + mock_tunnel_instance.start.return_value = None + mock_tunnel_instance.stop.return_value = None + mock_tunnel_instance.local_bind_port = 12345 + + # Setup mock for psycopg2.connect + mock_conn_instance = MagicMock() + mock_connect.return_value = mock_conn_instance + mock_cursor = MagicMock() + mock_conn_instance.cursor.return_value = mock_cursor + mock_cursor.fetchall.return_value = [("row1",), ("row2",)] + + # Define the parameters for the test + ssh_host = "ssh.example.com" + ssh_port = 22 + ssh_username = "user" + ssh_password = "pass" + db_host = "db.example.com" + db_port = 5432 + db_name = "testdb" + db_username = "dbuser" + db_password = "dbpass" + query = "SELECT * FROM table" + + # Execute the function under test + result = query_through_ssh( + ssh_host, + ssh_port, + ssh_username, + ssh_password, + db_host, + db_port, + db_name, + db_username, + db_password, + query, + ) + + # Assert that the result is as expected + self.assertEqual(result, [("row1",), ("row2",)]) + mock_tunnel.assert_called_once_with( + (ssh_host, ssh_port), + ssh_username=ssh_username, + ssh_password=ssh_password, + remote_bind_address=(db_host, db_port), + ) + mock_connect.assert_called_once_with( + host="localhost", port=12345, database=db_name, user=db_username, password=db_password + ) + mock_cursor.execute.assert_called_once_with(query) + mock_cursor.fetchall.assert_called_once() + + +if __name__ == "__main__": + unittest.main()