diff --git a/README.md b/README.md index 0e590b1..2d86b7c 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,14 @@ This is a Python package that integrates Superset and CommCare HQ. Local Development ----------------- -Follow below instructions. +### Preparing CommCare HQ -### Setup env +The 'User configurable reports UI' feature flag must be enabled for the +domain in CommCare HQ, even if the data sources to be imported were +created by Report Builder, not a UCR. + + +### Setting up a dev environment While doing development on top of this integration, it's useful to install this via `pip -e` option so that any changes made get reflected @@ -51,11 +56,12 @@ directly without another `pip install`. Read through the initialization instructions at https://superset.apache.org/docs/installation/installing-superset-from-scratch/#installing-and-initializing-superset. -Create the database. These instructions assume that PostgreSQL is -running on localhost, and that its user is "commcarehq". Adapt -accordingly: +Create a database for Superset, and a database for storing data from +CommCare HQ. Adapt the username and database names to suit your +environment. ```bash -$ createdb -h localhost -p 5432 -U commcarehq superset_meta +$ createdb -h localhost -p 5432 -U postgres superset +$ createdb -h localhost -p 5432 -U postgres superset_hq_data ``` Set the following environment variables: @@ -64,10 +70,17 @@ $ export FLASK_APP=superset $ export SUPERSET_CONFIG_PATH=/path/to/superset_config.py ``` -Initialize the database. Create an administrator. Create default roles +Set this environment variable to allow OAuth 2.0 authentication with +CommCare HQ over insecure HTTP. (DO NOT USE THIS IN PRODUCTION.) +```bash +$ export AUTHLIB_INSECURE_TRANSPORT=1 +``` + +Initialize the databases. Create an administrator. Create default roles and permissions: ```bash $ superset db upgrade +$ superset db upgrade --directory hq_superset/migrations/ $ superset fab create-admin $ superset load_examples # (Optional) $ superset init @@ -78,28 +91,16 @@ You should now be able to run superset using the `superset run` command: ```bash $ superset run -p 8088 --with-threads --reload --debugger ``` -However, OAuth login does not work yet as hq-superset needs a Postgres -database created to store CommCare HQ data. - -### Create a Postgres Database Connection for storing HQ data - -- Create a Postgres database. e.g. - ```bash - $ createdb -h localhost -p 5432 -U commcarehq hq_data - ``` -- Log into Superset as the admin user created in the Superset - installation and initialization. Note that you will need to update - `AUTH_TYPE = AUTH_DB` to log in as admin user. `AUTH_TYPE` should be - otherwise set to `AUTH_OAUTH`. -- Go to 'Data' -> 'Databases' or http://127.0.0.1:8088/databaseview/list/ -- Create a database connection by clicking '+ DATABASE' button at the top. -- The name of the DISPLAY NAME should be 'HQ Data' exactly, as this is - the name by which this codebase refers to the Postgres DB. - -OAuth integration should now be working. You can log in as a CommCare -HQ web user. +You can now log in as a CommCare HQ web user. +In order for CommCare HQ to sync data source changes, you will need to +allow OAuth 2.0 authentication over insecure HTTP. (DO NOT USE THIS IN +PRODUCTION.) Set this environment variable in your CommCare HQ Django +server. (Yes, it's "OAUTHLIB" this time, not "AUTHLIB" as before.) 
+```bash +$ export OAUTHLIB_INSECURE_TRANSPORT=1 +``` ### Importing UCRs using Redis and Celery @@ -129,6 +130,41 @@ code you want to test will need to be in a module whose dependencies don't include Superset. +### Creating a migration + +You will need to create an Alembic migration for any new SQLAlchemy +models that you add. The Superset CLI should allow you to do this: + +```shell +$ superset db revision --autogenerate -m "Add table for Foo model" +``` + +However, problems with this approach have occurred in the past. You +might have more success by using Alembic directly. You will need to +modify the configuration a little to do this: + +1. Copy the "HQ_DATA" database URI from `superset_config.py`. + +2. Paste it as the value of `sqlalchemy.url` in + `hq_superset/migrations/alembic.ini`. + +3. Edit `env.py` and comment out the following lines: + ``` + hq_data_uri = current_app.config['SQLALCHEMY_BINDS'][HQ_DATA] + decoded_uri = urllib.parse.unquote(hq_data_uri) + config.set_main_option('sqlalchemy.url', decoded_uri) + ``` + +Those changes will allow Alembic to connect to the "HD Data" database +without the need to instantiate Superset's Flask app. You can now +autogenerate your new table with: + +```shell +$ cd hq_superset/migrations/ +$ alembic revision --autogenerate -m "Add table for Foo model" +``` + + Upgrading Superset ------------------ diff --git a/hq_superset/__init__.py b/hq_superset/__init__.py index 54e63f1..6d92ac7 100644 --- a/hq_superset/__init__.py +++ b/hq_superset/__init__.py @@ -8,10 +8,14 @@ def flask_app_mutator(app): # Import the views (which assumes the app is initialized) here # return from superset.extensions import appbuilder + from . import api, hq_domain, views, oauth2_server - from . import hq_domain, views appbuilder.add_view(views.HQDatasourceView, 'Update HQ Datasource', menu_cond=lambda *_: False) appbuilder.add_view(views.SelectDomainView, 'Select a Domain', menu_cond=lambda *_: False) + appbuilder.add_api(api.OAuth) + appbuilder.add_api(api.DataSetChangeAPI) + oauth2_server.config_oauth2(app) + app.before_request_funcs.setdefault(None, []).append( hq_domain.before_request_hook ) @@ -40,4 +44,4 @@ def override_jinja2_template_loader(app): 'images' )) blueprint = Blueprint('Static', __name__, static_url_path='/static/images', static_folder=images_path) - app.register_blueprint(blueprint) \ No newline at end of file + app.register_blueprint(blueprint) diff --git a/hq_superset/api.py b/hq_superset/api.py new file mode 100644 index 0000000..5473c58 --- /dev/null +++ b/hq_superset/api.py @@ -0,0 +1,69 @@ +import json +from http import HTTPStatus + +from flask import jsonify, request +from flask_appbuilder.api import BaseApi, expose +from sqlalchemy.orm.exc import NoResultFound +from superset.superset_typing import FlaskResponse +from superset.views.base import ( + handle_api_exception, + json_error_response, + json_success, +) + +from .models import DataSetChange +from .oauth2_server import authorization, require_oauth + + +class OAuth(BaseApi): + + def __init__(self): + super().__init__() + self.route_base = "/oauth" + + @expose("/token", methods=('POST',)) + def issue_access_token(self): + try: + response = authorization.create_token_response() + except NoResultFound: + return jsonify({"error": "Invalid client"}), 401 + + if response.status_code >= 400: + return response + + data = json.loads(response.data.decode("utf-8")) + return jsonify(data) + + +class DataSetChangeAPI(BaseApi): + """ + Accepts changes to datasets from CommCare HQ data forwarding + 
""" + + MAX_REQUEST_LENGTH = 10 * 1024 * 1024 # reject JSON requests > 10MB + + def __init__(self): + self.route_base = '/hq_webhook' + self.default_view = 'post_dataset_change' + super().__init__() + + @expose('/change/', methods=('POST',)) + @handle_api_exception + @require_oauth() + def post_dataset_change(self) -> FlaskResponse: + if request.content_length > self.MAX_REQUEST_LENGTH: + return json_error_response( + HTTPStatus.REQUEST_ENTITY_TOO_LARGE.description, + status=HTTPStatus.REQUEST_ENTITY_TOO_LARGE.value, + ) + + try: + request_json = json.loads(request.get_data(as_text=True)) + change = DataSetChange(**request_json) + change.update_dataset() + return json_success('Dataset updated') + except json.JSONDecodeError: + return json_error_response( + 'Invalid JSON syntax', + status=HTTPStatus.BAD_REQUEST.value, + ) diff --git a/hq_superset/const.py b/hq_superset/const.py new file mode 100644 index 0000000..86c7e26 --- /dev/null +++ b/hq_superset/const.py @@ -0,0 +1,2 @@ +# The name of the database for storing data related to CommCare HQ +HQ_DATA = "HQ Data" diff --git a/hq_superset/hq_domain.py b/hq_superset/hq_domain.py index 6fc5a12..de137d4 100644 --- a/hq_superset/hq_domain.py +++ b/hq_superset/hq_domain.py @@ -26,6 +26,8 @@ def after_request_hook(response): "AuthDBView.login", "SelectDomainView.list", "SelectDomainView.select", + "OAuth.issue_access_token", + "DataSetChangeAPI.post_dataset_change", "appbuilder.static", "static", ] @@ -39,7 +41,10 @@ def is_user_admin(): def ensure_domain_selected(): # Check if a hq_domain cookie is set # Ensure necessary roles, permissions and DB schemas are created for the domain - if is_user_admin() or (request.url_rule and request.url_rule.endpoint in DOMAIN_EXCLUDED_VIEWS): + if is_user_admin() or ( + request.url_rule + and request.url_rule.endpoint in DOMAIN_EXCLUDED_VIEWS + ): return hq_domain = request.cookies.get('hq_domain') valid_domains = user_domains() diff --git a/hq_superset/hq_requests.py b/hq_superset/hq_requests.py new file mode 100644 index 0000000..26392c8 --- /dev/null +++ b/hq_superset/hq_requests.py @@ -0,0 +1,30 @@ +import superset +from hq_superset.oauth import get_valid_cchq_oauth_token + + +class HQRequest: + + def __init__(self, url): + self.url = url + + @property + def oauth_token(self): + return get_valid_cchq_oauth_token() + + @property + def commcare_provider(self): + return superset.appbuilder.sm.oauth_remotes["commcare"] + + @property + def api_base_url(self): + return self.commcare_provider.api_base_url + + @property + def absolute_url(self): + return f"{self.api_base_url}{self.url}" + + def get(self): + return self.commcare_provider.get(self.url, token=self.oauth_token) + + def post(self, data): + return self.commcare_provider.post(self.url, data=data, token=self.oauth_token) diff --git a/hq_superset/hq_url.py b/hq_superset/hq_url.py new file mode 100644 index 0000000..aca7e58 --- /dev/null +++ b/hq_superset/hq_url.py @@ -0,0 +1,25 @@ +""" +Functions that return URLs on CommCare HQ +""" + + +def datasource_details(domain, datasource_id): + return f"a/{domain}/api/v0.5/ucr_data_source/{datasource_id}/" + + +def datasource_list(domain): + return f"a/{domain}/api/v0.5/ucr_data_source/" + + +def datasource_export(domain, datasource_id): + return ( + f"a/{domain}/configurable_reports/data_sources/export/{datasource_id}/" + "?format=csv" + ) + + +def datasource_subscribe(domain, datasource_id): + return ( + f"a/{domain}/configurable_reports/data_sources/subscribe/" + f"{datasource_id}/" + ) diff --git 
a/hq_superset/migrations/README b/hq_superset/migrations/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/hq_superset/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/hq_superset/migrations/alembic.ini b/hq_superset/migrations/alembic.ini new file mode 100644 index 0000000..f01502d --- /dev/null +++ b/hq_superset/migrations/alembic.ini @@ -0,0 +1,115 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = . + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(hour).2d-%%(minute).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +# prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to hq_superset/migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:hq_superset/migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/hq_superset/migrations/env.py b/hq_superset/migrations/env.py new file mode 100644 index 0000000..168bfde --- /dev/null +++ b/hq_superset/migrations/env.py @@ -0,0 +1,77 @@ +import urllib.parse +from logging.config import fileConfig + +from alembic import context +from flask import current_app +from sqlalchemy import engine_from_config, pool + +from hq_superset.const import HQ_DATA +from hq_superset.models import HQClient + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +hq_data_uri = current_app.config['SQLALCHEMY_BINDS'][HQ_DATA] +decoded_uri = urllib.parse.unquote(hq_data_uri) +config.set_main_option('sqlalchemy.url', decoded_uri) + +# add your model's MetaData object here for 'autogenerate' support +target_metadata = HQClient.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/hq_superset/migrations/script.py.mako b/hq_superset/migrations/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/hq_superset/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/hq_superset/migrations/versions/2024-02-24_23-53_56d0467ff6ff_added_oauth_tables.py b/hq_superset/migrations/versions/2024-02-24_23-53_56d0467ff6ff_added_oauth_tables.py new file mode 100644 index 0000000..4c9469a --- /dev/null +++ b/hq_superset/migrations/versions/2024-02-24_23-53_56d0467ff6ff_added_oauth_tables.py @@ -0,0 +1,67 @@ +"""Added OAuth tables + +Revision ID: 56d0467ff6ff +Revises: +Create Date: 2024-02-24 23:53:10.289606 +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '56d0467ff6ff' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'hq_oauth_client', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('client_id_issued_at', sa.Integer(), nullable=False), + sa.Column('client_secret_expires_at', sa.Integer(), nullable=False), + sa.Column('client_metadata', sa.Text(), nullable=True), + sa.Column('domain', sa.String(length=255), nullable=False), + sa.Column('client_secret', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('domain'), + info={'bind_key': 'HQ Data'}, + ) + op.create_index( + op.f('ix_hq_oauth_client_client_id'), + 'hq_oauth_client', + ['client_id'], + unique=False, + ) + op.create_table( + 'hq_oauth_token', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('token_type', sa.String(length=40), nullable=True), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('issued_at', sa.Integer(), nullable=False), + sa.Column('access_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('refresh_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('expires_in', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('access_token'), + info={'bind_key': 'HQ Data'}, + ) + op.create_index( + op.f('ix_hq_oauth_token_refresh_token'), + 'hq_oauth_token', + ['refresh_token'], + unique=False, + ) + + +def downgrade() -> None: + op.drop_table('hq_oauth_token') + op.drop_index( + op.f('ix_hq_oauth_client_client_id'), table_name='hq_oauth_client' + ) + op.drop_table('hq_oauth_client') diff --git a/hq_superset/models.py b/hq_superset/models.py new file mode 100644 index 0000000..b3e15cc --- /dev/null +++ b/hq_superset/models.py @@ -0,0 +1,115 @@ +import secrets +import string +import time +import uuid +from dataclasses import dataclass +from typing import Any + +from authlib.integrations.sqla_oauth2 import ( + OAuth2ClientMixin, + OAuth2TokenMixin, +) +from cryptography.fernet import MultiFernet +from sqlalchemy import update +from superset import db + +from .const import HQ_DATA +from .utils import cast_data_for_table, get_fernet_keys, get_hq_database + + +@dataclass +class DataSetChange: + data_source_id: str + doc_id: str + data: list[dict[str, Any]] + + def update_dataset(self): + database = get_hq_database() + try: + sqla_table = next(( + table for table in database.tables + if table.table_name == self.data_source_id + )) + except StopIteration: + raise ValueError(f'{self.data_source_id} table not found.') + table = sqla_table.get_sqla_table_object() + + with ( + database.get_sqla_engine_with_context() as engine, + engine.connect() as connection, + connection.begin() # Commit on leaving context + ): + delete_stmt = table.delete().where(table.c.doc_id == self.doc_id) + connection.execute(delete_stmt) + if self.data: + rows = list(cast_data_for_table(self.data, table)) + insert_stmt = table.insert().values(rows) + connection.execute(insert_stmt) + + +class HQClient(db.Model, OAuth2ClientMixin): + __bind_key__ = HQ_DATA + __tablename__ = 'hq_oauth_client' + + domain = db.Column(db.String(255), primary_key=True) + client_secret = db.Column(db.String(255)) # more chars for encryption + + def 
get_client_secret(self): + keys = get_fernet_keys() + fernet = MultiFernet(keys) + + ciphertext_bytes = self.client_secret.encode('utf-8') + plaintext_bytes = fernet.decrypt(ciphertext_bytes) + return plaintext_bytes.decode('utf-8') + + def set_client_secret(self, plaintext): + keys = get_fernet_keys() + fernet = MultiFernet(keys) + + plaintext_bytes = plaintext.encode('utf-8') + ciphertext_bytes = fernet.encrypt(plaintext_bytes) + self.client_secret = ciphertext_bytes.decode('utf-8') + + def check_client_secret(self, plaintext): + return self.get_client_secret() == plaintext + + def revoke_tokens(self): + revoked_at = int(time.time()) + stmt = ( + update(Token) + .where(Token.client_id == self.client_id) + .where(Token.access_token_revoked_at == 0) + .values(access_token_revoked_at=revoked_at) + ) + db.session.execute(stmt) + db.session.commit() + + @classmethod + def get_by_domain(cls, domain): + return db.session.query(HQClient).filter_by(domain=domain).first() + + @classmethod + def create_domain_client(cls, domain: str): + alphabet = string.ascii_letters + string.digits + client_secret = ''.join(secrets.choice(alphabet) for i in range(64)) + client = HQClient( + domain=domain, + client_id=str(uuid.uuid4()), + ) + client.set_client_secret(client_secret) + client.set_client_metadata({"grant_types": ["client_credentials"]}) + db.session.add(client) + db.session.commit() + return client + + +class Token(db.Model, OAuth2TokenMixin): + __bind_key__ = HQ_DATA + __tablename__ = 'hq_oauth_token' + + id = db.Column(db.Integer, primary_key=True) + + @property + def domain(self): + client = HQClient.get_by_client_id(self.client_id) + return client.domain diff --git a/hq_superset/oauth.py b/hq_superset/oauth.py index e5ef34c..dfad3e3 100644 --- a/hq_superset/oauth.py +++ b/hq_superset/oauth.py @@ -73,7 +73,9 @@ def get_valid_cchq_oauth_token(): # If token hasn't expired yet, return it expires_at = oauth_response.get("expires_at") - if expires_at > int(time.time()): + # TODO: RFC-6749 specifies "expires_in", not "expires_at". 
+ # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + if expires_at is None or expires_at > int(time.time()): return oauth_response # If the token has expired, get a new token using refresh_token diff --git a/hq_superset/oauth2_server.py b/hq_superset/oauth2_server.py new file mode 100644 index 0000000..f5e5c6e --- /dev/null +++ b/hq_superset/oauth2_server.py @@ -0,0 +1,51 @@ +from datetime import timedelta + +from authlib.integrations.flask_oauth2 import ( + AuthorizationServer, + ResourceProtector, +) +from authlib.integrations.sqla_oauth2 import ( + create_bearer_token_validator, + create_query_client_func, + create_revocation_endpoint, +) +from authlib.oauth2.rfc6749 import grants + +from .models import HQClient, Token, db + + +def save_token(token, request): + client = request.client + client.revoke_tokens() + + one_day = 24 * 60 * 60 + token = Token( + client_id=client.client_id, + token_type=token['token_type'], + access_token=token['access_token'], + scope=client.domain, + expires_in=one_day, + ) + db.session.add(token) + db.session.commit() + + +query_client = create_query_client_func(db.session, HQClient) +authorization = AuthorizationServer( + query_client=query_client, + save_token=save_token, +) +require_oauth = ResourceProtector() + + +def config_oauth2(app): + authorization.init_app(app) + authorization.register_grant(grants.ClientCredentialsGrant) + + # support revocation + revocation_cls = create_revocation_endpoint(db.session, Token) + authorization.register_endpoint(revocation_cls) + + # protect resource + bearer_cls = create_bearer_token_validator(db.session, Token) + require_oauth.register_token_validator(bearer_cls()) diff --git a/hq_superset/services.py b/hq_superset/services.py new file mode 100644 index 0000000..c39865a --- /dev/null +++ b/hq_superset/services.py @@ -0,0 +1,205 @@ +import logging +import os +from datetime import datetime +from urllib.parse import urljoin + +import pandas +import sqlalchemy +import superset +from flask import g, request, url_for +from sqlalchemy.dialects import postgresql +from superset import db +from superset.connectors.sqla.models import SqlaTable +from superset.extensions import cache_manager +from superset.sql_parse import Table + +from .hq_requests import HQRequest +from .hq_url import datasource_details, datasource_export, datasource_subscribe +from .models import HQClient +from .utils import ( + CCHQApiException, + convert_to_array, + get_column_dtypes, + get_datasource_file, + get_hq_database, + get_schema_name_for_domain, + parse_date, +) + +logger = logging.getLogger(__name__) + + +def download_datasource(domain, datasource_id): + hq_request = HQRequest(url=datasource_export(domain, datasource_id)) + response = hq_request.get() + if response.status_code != 200: + raise CCHQApiException("Error downloading the UCR export from HQ") + + filename = f"{datasource_id}_{datetime.now()}.zip" + path = os.path.join(superset.config.SHARED_DIR, filename) + with open(path, "wb") as f: + f.write(response.content) + + return path, len(response.content) + + +def get_datasource_defn(domain, datasource_id): + hq_request = HQRequest(url=datasource_details(domain, datasource_id)) + response = hq_request.get() + + if response.status_code != 200: + raise CCHQApiException("Error downloading the UCR definition from HQ") + return response.json() + + +def refresh_hq_datasource( + domain, + datasource_id, + display_name, + file_path, + datasource_defn, + user_id=None, +): + """ + Pulls the data from CommCare HQ and creates/replaces the + 
corresponding Superset dataset + """ + # See `CsvToDatabaseView.form_post()` in + # https://github.com/apache/superset/blob/master/superset/views/database/views.py + + def dataframe_to_sql(df, replace=False): + """ + Upload Pandas DataFrame ``df`` to ``database``. + """ + database.db_engine_spec.df_to_sql( + database, + csv_table, + df, + to_sql_kwargs={ + "if_exists": "replace" if replace else "append", + "dtype": sql_converters, + "index": False, + }, + ) + + database = get_hq_database() + schema = get_schema_name_for_domain(domain) + csv_table = Table(table=datasource_id, schema=schema) + column_dtypes, date_columns, array_columns = get_column_dtypes( + datasource_defn + ) + converters = { + column_name: convert_to_array for column_name in array_columns + } + sql_converters = { + # Assumes all array values will be of type TEXT + column_name: postgresql.ARRAY(sqlalchemy.types.TEXT) + for column_name in array_columns + } + + try: + with get_datasource_file(file_path) as csv_file: + dataframes = pandas.read_csv( + chunksize=10000, + filepath_or_buffer=csv_file, + encoding="utf-8", + parse_dates=date_columns, + date_parser=parse_date, + keep_default_na=True, + dtype=column_dtypes, + converters=converters, + iterator=True, + low_memory=True, + ) + dataframe_to_sql(next(dataframes), replace=True) + for df in dataframes: + dataframe_to_sql(df, replace=False) + + sqla_table = ( + db.session.query(SqlaTable) + .filter_by( + table_name=datasource_id, + schema=csv_table.schema, + database_id=database.id, + ) + .one_or_none() + ) + if sqla_table: + sqla_table.description = display_name + sqla_table.fetch_metadata() + if not sqla_table: + sqla_table = SqlaTable(table_name=datasource_id) + # Store display name from HQ into description since + # sqla_table.table_name stores datasource_id + sqla_table.description = display_name + sqla_table.database = database + sqla_table.database_id = database.id + if user_id: + user = superset.appbuilder.sm.get_user_by_id(user_id) + else: + user = g.user + sqla_table.owners = [user] + sqla_table.user_id = user.get_id() + sqla_table.schema = csv_table.schema + sqla_table.fetch_metadata() + db.session.add(sqla_table) + db.session.commit() + except Exception as ex: # pylint: disable=broad-except + db.session.rollback() + raise ex + + +def subscribe_to_hq_datasource(domain, datasource_id): + hq_client = HQClient.get_by_domain(domain) + if hq_client is None: + hq_client = HQClient.create_domain_client(domain) + + hq_request = HQRequest(url=datasource_subscribe(domain, datasource_id)) + webhook_url = urljoin( + request.root_url, + url_for('DataSetChangeAPI.post_dataset_change'), + ) + token_url = urljoin(request.root_url, url_for('OAuth.issue_access_token')) + response = hq_request.post({ + 'webhook_url': webhook_url, + 'token_url': token_url, + 'client_id': hq_client.client_id, + 'client_secret': hq_client.get_client_secret(), + }) + if response.status_code == 201: + return + if response.status_code < 500: + logger.error( + f"Failed to subscribe to data source {datasource_id} due to the following issue: {response.data}" + ) + if response.status_code >= 500: + logger.exception( + f"Failed to subscribe to data source {datasource_id} due to a remote server error" + ) + + +class AsyncImportHelper: + def __init__(self, domain, datasource_id): + self.domain = domain + self.datasource_id = datasource_id + + @property + def progress_key(self): + return f"{self.domain}_{self.datasource_id}_import_task_id" + + @property + def task_id(self): + return 
cache_manager.cache.get(self.progress_key) + + def is_import_in_progress(self): + if not self.task_id: + return False + from celery.result import AsyncResult + res = AsyncResult(self.task_id) + return not res.ready() + + def mark_as_in_progress(self, task_id): + cache_manager.cache.set(self.progress_key, task_id) + + def mark_as_complete(self): + cache_manager.cache.delete(self.progress_key) diff --git a/hq_superset/tasks.py b/hq_superset/tasks.py index cbf7f7e..66ed506 100644 --- a/hq_superset/tasks.py +++ b/hq_superset/tasks.py @@ -1,16 +1,12 @@ -import logging import os from superset.extensions import celery_app -from .utils import AsyncImportHelper - -logger = logging.getLogger(__name__) +from .services import AsyncImportHelper, refresh_hq_datasource @celery_app.task(name='refresh_hq_datasource_task') def refresh_hq_datasource_task(domain, datasource_id, display_name, export_path, datasource_defn, user_id): - from .views import refresh_hq_datasource try: refresh_hq_datasource(domain, datasource_id, display_name, export_path, datasource_defn, user_id) except Exception: diff --git a/hq_superset/templates/hq_datasource_list.html b/hq_superset/templates/hq_datasource_list.html index c80261b..298eaa9 100644 --- a/hq_superset/templates/hq_datasource_list.html +++ b/hq_superset/templates/hq_datasource_list.html @@ -1,5 +1,5 @@
-  Import From CommCareHQ
+  Import from CommCare HQ
diff --git a/hq_superset/tests/test_views.py b/hq_superset/tests/test_views.py index 38499aa..4956772 100644 --- a/hq_superset/tests/test_views.py +++ b/hq_superset/tests/test_views.py @@ -8,10 +8,15 @@ from flask import redirect, session from sqlalchemy.sql import text +from hq_superset.services import download_datasource, refresh_hq_datasource from hq_superset.utils import ( SESSION_USER_DOMAINS_KEY, get_schema_name_for_domain, ) +from hq_superset.views import ( + ASYNC_DATASOURCE_IMPORT_LIMIT_IN_BYTES, + trigger_datasource_refresh, +) from .base_test import HQDBTestCase from .utils import TEST_DATASOURCE @@ -36,6 +41,7 @@ class UserMock(): def get_id(self): return self.user_id + class OAuthMock(): def __init__(self): @@ -106,6 +112,7 @@ def get(self, url, token): a3, 2021-11-22, 2022-01-19, 10, 2022-03-20, some_other_text2 """ + class TestViews(HQDBTestCase): def setUp(self): @@ -197,7 +204,7 @@ def test_non_user_domain_cant_be_selected(self): self.assertTrue('/domain/list' in response.request.path) self.logout(client) - @patch('hq_superset.views.get_valid_cchq_oauth_token', return_value={}) + @patch('hq_superset.oauth.get_valid_cchq_oauth_token', return_value={}) def test_datasource_list(self, *args): def _do_assert(datasources): self.assert_template_used("hq_datasource_list.html") @@ -229,13 +236,9 @@ def test_datasource_upload(self, *args): 'ds1' ) - @patch('hq_superset.views.get_valid_cchq_oauth_token', return_value={}) + @patch('hq_superset.oauth.get_valid_cchq_oauth_token', return_value={}) @patch('hq_superset.views.os.remove') def test_trigger_datasource_refresh(self, *args): - from hq_superset.views import ( - ASYNC_DATASOURCE_IMPORT_LIMIT_IN_BYTES, - trigger_datasource_refresh, - ) domain = 'test1' ds_name = 'ds_name' file_path = '/file_path' @@ -275,26 +278,32 @@ def _test_sync_or_async(ds_size, routing_method, user_id): None ) - @patch('hq_superset.views.get_valid_cchq_oauth_token', return_value={}) - def test_download_datasource(self, *args): - from hq_superset.views import download_datasource + @patch('hq_superset.oauth.get_valid_cchq_oauth_token', return_value={}) + @patch('hq_superset.tasks.subscribe_to_hq_datasource_task.delay') + @patch('hq_superset.hq_requests.HQRequest.get') + def test_download_datasource(self, hq_request_get_mock, subscribe_task_mock, *args): + hq_request_get_mock.return_value = MockResponse( + json_data=TEST_UCR_CSV_V1, + status_code=200, + ) ucr_id = self.oauth_mock.test1_datasources['objects'][0]['id'] - path, size = download_datasource(self.oauth_mock, '_', 'test1', ucr_id) + path, size = download_datasource('test1', ucr_id) + + subscribe_task_mock.assert_called_once_with( + 'test1', + ucr_id, + ) with open(path, 'rb') as f: self.assertEqual(pickle.load(f), TEST_UCR_CSV_V1) self.assertEqual(size, len(pickle.dumps(TEST_UCR_CSV_V1))) os.remove(path) - @patch('hq_superset.views.get_valid_cchq_oauth_token', return_value={}) + @patch('hq_superset.oauth.get_valid_cchq_oauth_token', return_value={}) def test_refresh_hq_datasource(self, *args): - - from hq_superset.views import refresh_hq_datasource - client = self.app.test_client() - ucr_id = self.oauth_mock.test1_datasources['objects'][0]['id'] ds_name = "ds1" - with patch("hq_superset.views.get_datasource_file") as csv_mock, \ - self.app.test_client() as client: + with patch("hq_superset.utils.get_datasource_file") as csv_mock, \ + self.app.test_client() as client: self.login(client) client.get('/domain/select/test1/', follow_redirects=True) @@ -335,3 +344,24 @@ def _test_upload(test_data, 
expected_output): self.assertEqual(response.status, "302 FOUND") client.get('/hq_datasource/list/', follow_redirects=True) self.assert_context('ucr_id_to_pks', {}) + + # def test_dataset_update(self): + # # The equivalent of something like: + # # + # # $ curl -X POST \ + # # -H "Content-Type: application/json" \ + # # -d '{"action": "upsert", "data_source_id": "abc123", "data": {"doc_id": "abc123"}}' \ + # # http://localhost:8088/hq_webhook/change/ + # + # ucr_id = self.oauth_mock.test1_datasources['objects'][0]['id'] + # ds_name = "ds1" + # with patch("hq_superset.views.get_datasource_file") as csv_mock, \ + # self.app.test_client() as client: + # + # self.login(client) + # + # def test_dataset_insert(self): + # pass + # + # def test_dataset_delete(self): + # pass diff --git a/hq_superset/tests/utils.py b/hq_superset/tests/utils.py index 67089b6..5136d73 100644 --- a/hq_superset/tests/utils.py +++ b/hq_superset/tests/utils.py @@ -1,13 +1,10 @@ from functools import wraps -from sqlalchemy.orm.exc import NoResultFound +from hq_superset.utils import get_hq_database -from hq_superset.utils import HQ_DB_CONNECTION_NAME, get_hq_database -# @pytest.fixture(scope="session", autouse=True) -# def manage_ucr_db(request): -# # setup_ucr_db() -# request.addfinalizer(clear_ucr_db) +class UnitTestingRequired(Exception): + pass def unit_testing_only(fn): @@ -24,18 +21,7 @@ def inner(*args, **kwargs): @unit_testing_only def setup_hq_db(): - import superset - from superset.commands.database.create import CreateDatabaseCommand - try: - get_hq_database() - except NoResultFound: - CreateDatabaseCommand( - { - 'sqlalchemy_uri': superset.app.config.get('HQ_DATA_DB'), - 'engine': 'PostgreSQL', - 'database_name': HQ_DB_CONNECTION_NAME - } - ).run() + get_hq_database() TEST_DATASOURCE = { diff --git a/hq_superset/utils.py b/hq_superset/utils.py index 4cd221a..0a46cab 100644 --- a/hq_superset/utils.py +++ b/hq_superset/utils.py @@ -1,40 +1,32 @@ -import os +import ast from contextlib import contextmanager from datetime import date, datetime +from functools import partial +from typing import Any, Generator from zipfile import ZipFile import pandas import sqlalchemy +from cryptography.fernet import Fernet +from flask import current_app from flask_login import current_user -from superset.extensions import cache_manager +from sqlalchemy.sql import TableClause +from superset.utils.database import get_or_create_db + +from .const import HQ_DATA DOMAIN_PREFIX = "hqdomain_" SESSION_USER_DOMAINS_KEY = "user_hq_domains" SESSION_OAUTH_RESPONSE_KEY = "oauth_response" -HQ_DB_CONNECTION_NAME = "HQ Data" - -ASYNC_DATASOURCE_IMPORT_LIMIT_IN_BYTES = 5_000_000 # ~5MB - - -def get_datasource_export_url(domain, datasource_id): - return f"a/{domain}/configurable_reports/data_sources/export/{datasource_id}/?format=csv" - - -def get_datasource_list_url(domain): - return f"a/{domain}/api/v0.5/ucr_data_source/" -def get_datasource_details_url(domain, datasource_id): - return f"a/{domain}/api/v0.5/ucr_data_source/{datasource_id}/" +class CCHQApiException(Exception): + pass def get_hq_database(): - # Todo; cache to avoid multiple lookups in single request - from superset import db - from superset.models.core import Database - - # Todo; get actual DB once that's implemented - return db.session.query(Database).filter_by(database_name=HQ_DB_CONNECTION_NAME).one() + db_uri = current_app.config['SQLALCHEMY_BINDS'][HQ_DATA] + return get_or_create_db(HQ_DATA, db_uri) def get_schema_name_for_domain(domain): @@ -104,33 +96,6 @@ def 
parse_date(date_str): return date_str -class AsyncImportHelper: - def __init__(self, domain, datasource_id): - self.domain = domain - self.datasource_id = datasource_id - - @property - def progress_key(self): - return f"{self.domain}_{self.datasource_id}_import_task_id" - - @property - def task_id(self): - return cache_manager.cache.get(self.progress_key) - - def is_import_in_progress(self): - if not self.task_id: - return False - from celery.result import AsyncResult - res = AsyncResult(self.task_id) - return not res.ready() - - def mark_as_in_progress(self, task_id): - cache_manager.cache.set(self.progress_key, task_id) - - def mark_as_complete(self): - cache_manager.cache.delete(self.progress_key) - - class DomainSyncUtil: def __init__(self, security_manager): @@ -187,24 +152,109 @@ def get_datasource_file(path): yield zipfile.open(filename) -def download_datasource(provider, oauth_token, domain, datasource_id): - import superset - datasource_url = get_datasource_export_url(domain, datasource_id) - response = provider.get(datasource_url, token=oauth_token) - if response.status_code != 200: - raise CCHQApiException("Error downloading the UCR export from HQ") +def get_fernet_keys(): + return [ + Fernet(encoded(key, 'ascii')) + for key in current_app.config['FERNET_KEYS'] + ] + + +def encoded(string_maybe, encoding): + """ + Returns ``string_maybe`` encoded with ``encoding``, otherwise + returns it unchanged. - filename = f"{datasource_id}_{datetime.now()}.zip" - path = os.path.join(superset.config.SHARED_DIR, filename) - with open(path, "wb") as f: - f.write(response.content) + >>> encoded('abc', 'utf-8') + b'abc' + >>> encoded(b'abc', 'ascii') + b'abc' + >>> encoded(123, 'utf-8') + 123 + + """ + if hasattr(string_maybe, 'encode'): + return string_maybe.encode(encoding) + return string_maybe + + +def convert_to_array(string_array): + """ + Converts the string representation of a list to a list. + >>> convert_to_array("['hello', 'world']") + ['hello', 'world'] - return path, len(response.content) + >>> convert_to_array("'hello', 'world'") + ['hello', 'world'] + >>> convert_to_array("[None]") + [] + + >>> convert_to_array("hello, world") + [] + """ + + def array_is_falsy(array_values): + return not array_values or array_values == [None] + + try: + array_values = ast.literal_eval(string_array) + except ValueError: + return [] + + if isinstance(array_values, tuple): + array_values = list(array_values) + + # Test for corner cases + if array_is_falsy(array_values): + return [] + + return array_values + + +def js_to_py_datetime(jsdt, preserve_tz=True): + """ + JavaScript UTC datetimes end in "Z". ``datetime.isoformat()`` + doesn't like it. + + >>> jsdt = '2024-02-24T14:01:25.397469Z' + >>> datetime.fromisoformat(jsdt) + Traceback (most recent call last): + ... + ValueError: Invalid isoformat string: '2024-02-24T14:01:25.397469Z' + >>> js_to_py_datetime(jsdt) + datetime.datetime(2024, 2, 24, 14, 1, 25, 397469, tzinfo=datetime.timezone.utc) + + >>> js_to_py_datetime(jsdt, preserve_tz=False) + datetime.datetime(2024, 2, 24, 14, 1, 25, 397469) + + """ + pydt = jsdt.replace('Z', '+00:00') if preserve_tz else jsdt.replace('Z', '') + return datetime.fromisoformat(pydt) + + +def cast_data_for_table( + data: list[dict[str, Any]], + table: TableClause, +) -> Generator[dict[str, Any], None, None]: + """ + Returns ``data`` with values cast in the correct data types for + the columns of ``table``. 
+ """ + cast_functions = { + # 'BIGINT': int, + # 'TEXT': str, + 'TIMESTAMP': partial(js_to_py_datetime, preserve_tz=False), + # TODO: What else? + } -def get_datasource_defn(provider, oauth_token, domain, datasource_id): - url = get_datasource_details_url(domain, datasource_id) - response = provider.get(url, token=oauth_token) - if response.status_code != 200: - raise CCHQApiException("Error downloading the UCR definition from HQ") - return response.json() + column_types = {c.name: str(c.type) for c in table.columns} + for row in data: + cast_row = {} + for column, value in row.items(): + type_name = column_types[column] + if type_name in cast_functions: + cast_func = cast_functions[type_name] + cast_row[column] = cast_func(value) + else: + cast_row[column] = value + yield cast_row diff --git a/hq_superset/views.py b/hq_superset/views.py index ff9e1c0..932b592 100644 --- a/hq_superset/views.py +++ b/hq_superset/views.py @@ -1,14 +1,11 @@ -import ast import logging import os -import pandas as pd -import sqlalchemy +import requests import superset from flask import Response, abort, flash, g, redirect, request, url_for from flask_appbuilder import expose from flask_appbuilder.security.decorators import has_access, permission_name -from sqlalchemy.dialects import postgresql from superset import db from superset.commands.dataset.delete import DeleteDatasetCommand from superset.commands.dataset.exceptions import ( @@ -17,76 +14,83 @@ DatasetNotFoundError, ) from superset.connectors.sqla.models import SqlaTable -from superset.models.core import Database -from superset.sql_parse import Table from superset.views.base import BaseSupersetView from .hq_domain import user_domains -from .oauth import get_valid_cchq_oauth_token -from .tasks import refresh_hq_datasource_task -from .utils import ( - ASYNC_DATASOURCE_IMPORT_LIMIT_IN_BYTES, +from .hq_url import datasource_list +from .hq_requests import HQRequest +from .services import ( AsyncImportHelper, - DomainSyncUtil, download_datasource, - get_column_dtypes, get_datasource_defn, - get_datasource_file, - get_datasource_list_url, - get_hq_database, - get_schema_name_for_domain, - parse_date, + refresh_hq_datasource, + subscribe_to_hq_datasource, ) +from .tasks import refresh_hq_datasource_task +from .utils import DomainSyncUtil, get_hq_database, get_schema_name_for_domain + +ASYNC_DATASOURCE_IMPORT_LIMIT_IN_BYTES = 5_000_000 # ~5MB logger = logging.getLogger(__name__) class HQDatasourceView(BaseSupersetView): - def __init__(self): self.route_base = "/hq_datasource/" self.default_view = "list_hq_datasources" super().__init__() def _ucr_id_to_pks(self): - tables = ( - db.session.query(SqlaTable) - .filter_by( - schema=get_schema_name_for_domain(g.hq_domain), - database_id=get_hq_database().id, - ) + tables = db.session.query(SqlaTable).filter_by( + schema=get_schema_name_for_domain(g.hq_domain), + database_id=get_hq_database().id, ) - return { - table.table_name: table.id - for table in tables.all() - } + return {table.table_name: table.id for table in tables.all()} @expose("/update/", methods=["GET"]) def create_or_update(self, datasource_id): - # Fetches data for a datasource from HQ and creates/updates a superset table + # Fetches data for a datasource from HQ and creates/updates a + # Superset table display_name = request.args.get("name") - res = trigger_datasource_refresh(g.hq_domain, datasource_id, display_name) + res = trigger_datasource_refresh( + g.hq_domain, datasource_id, display_name + ) return res @expose("/list/", methods=["GET"]) def 
list_hq_datasources(self): - datasource_list_url = get_datasource_list_url(g.hq_domain) - provider = superset.appbuilder.sm.oauth_remotes["commcare"] - oauth_token = get_valid_cchq_oauth_token() - response = provider.get(datasource_list_url, token=oauth_token) + hq_request = HQRequest(url=datasource_list(g.hq_domain)) + try: + response = hq_request.get() + except requests.exceptions.ConnectionError as err: + return Response( + "Unable to connect to CommCare HQ " + f"at {hq_request.absolute_url}", + status=400 + ) + if response.status_code == 403: return Response(status=403) if response.status_code != 200: - url = f"{provider.api_base_url}{datasource_list_url}" - return Response(response=f"There was an error in fetching datasources from CommCareHQ at {url}", status=400) + try: + msg = response.json()['error'] + except: # pylint: disable=E722 + msg = '' + return Response( + "There was an error in fetching datasources from CommCare HQ " + f"at {hq_request.absolute_url}: {response.status_code} {msg}", + status=400 + ) hq_datasources = response.json() for ds in hq_datasources['objects']: - ds['is_import_in_progress'] = AsyncImportHelper(g.hq_domain, ds['id']).is_import_in_progress() + ds['is_import_in_progress'] = AsyncImportHelper( + g.hq_domain, ds['id'] + ).is_import_in_progress() return self.render_template( "hq_datasource_list.html", hq_datasources=hq_datasources, ucr_id_to_pks=self._ucr_id_to_pks(), - hq_base_url=provider.api_base_url + hq_base_url=hq_request.api_base_url ) @expose("/delete/", methods=["GET"]) @@ -104,182 +108,72 @@ def delete(self, datasource_pk): str(ex), exc_info=True, ) - return abort(description=str(ex)) + return abort(400, description=str(ex)) return redirect("/tablemodelview/list/") -class CCHQApiException(Exception): - pass - - -def convert_to_array(string_array): - """ - Converts the string representation of a list to a list. - >>> convert_to_array("['hello', 'world']") - ['hello', 'world'] - - >>> convert_to_array("'hello', 'world'") - ['hello', 'world'] - - >>> convert_to_array("[None]") - [] - - >>> convert_to_array("hello, world") - [] - """ - - def array_is_falsy(array_values): - return not array_values or array_values == [None] - - try: - array_values = ast.literal_eval(string_array) - except ValueError: - return [] - - if isinstance(array_values, tuple): - array_values = list(array_values) - - # Test for corner cases - if array_is_falsy(array_values): - return [] - - return array_values - - def trigger_datasource_refresh(domain, datasource_id, display_name): if AsyncImportHelper(domain, datasource_id).is_import_in_progress(): - flash("The datasource is already being imported in the background. " - "Please wait for it to finish before retrying", - "warning") + flash( + "The datasource is already being imported in the background. 
" + "Please wait for it to finish before retrying.", + "warning", + ) return redirect("/tablemodelview/list/") - provider = superset.appbuilder.sm.oauth_remotes["commcare"] - token = get_valid_cchq_oauth_token() - path, size = download_datasource(provider, token, domain, datasource_id) - datasource_defn = get_datasource_defn(provider, token, domain, datasource_id) + subscribe_to_hq_datasource(domain, datasource_id) + path, size = download_datasource(domain, datasource_id) + datasource_defn = get_datasource_defn(domain, datasource_id) if size < ASYNC_DATASOURCE_IMPORT_LIMIT_IN_BYTES: - response = refresh_hq_datasource( + refresh_hq_datasource( domain, datasource_id, display_name, path, datasource_defn, None - ) + ) os.remove(path) - return response + return redirect("/tablemodelview/list/") else: limit_in_mb = int(ASYNC_DATASOURCE_IMPORT_LIMIT_IN_BYTES / 1000000) - flash("The datasource is being refreshed in the background as it is" - f" larger than {limit_in_mb} MB. This may take a while, please wait for it to finish", - "info") - return queue_refresh_task(domain, datasource_id, display_name, path, datasource_defn, g.user.get_id()) + flash( + "The datasource is being refreshed in the background as it is " + f"larger than {limit_in_mb} MB. This may take a while, please " + "wait for it to finish.", + "info", + ) + return queue_refresh_task( + domain, + datasource_id, + display_name, + path, + datasource_defn, + g.user.get_id(), + ) -def queue_refresh_task(domain, datasource_id, display_name, export_path, datasource_defn, user_id): +def queue_refresh_task( + domain, + datasource_id, + display_name, + export_path, + datasource_defn, + user_id, +): task_id = refresh_hq_datasource_task.delay( - domain, datasource_id, display_name, export_path, datasource_defn, g.user.get_id() + domain, + datasource_id, + display_name, + export_path, + datasource_defn, + g.user.get_id(), ).task_id AsyncImportHelper(domain, datasource_id).mark_as_in_progress(task_id) return redirect("/tablemodelview/list/") -def refresh_hq_datasource(domain, datasource_id, display_name, file_path, datasource_defn, user_id=None): - """ - Pulls the data from CommCare HQ and creates/replaces the - corresponding Superset dataset - """ - database = get_hq_database() - schema = get_schema_name_for_domain(domain) - csv_table = Table(table=datasource_id, schema=schema) - column_dtypes, date_columns, array_columns = get_column_dtypes(datasource_defn) - - converters = {column_name: convert_to_array for column_name in array_columns} - # TODO: can we assume all array values will be of type TEXT? - sqlconverters = {column_name: postgresql.ARRAY(sqlalchemy.types.TEXT) for column_name in array_columns} - - def to_sql(df, replace=False): - database.db_engine_spec.df_to_sql( - database, - csv_table, - df, - to_sql_kwargs={ - "if_exists": "replace" if replace else "append", - "dtype": sqlconverters, - }, - ) - - try: - with get_datasource_file(file_path) as csv_file: - - _iter = pd.read_csv( - chunksize=10000, - filepath_or_buffer=csv_file, - encoding="utf-8", - parse_dates=date_columns, - date_parser=parse_date, - keep_default_na=True, - dtype=column_dtypes, - converters=converters, - iterator=True, - low_memory=True, - ) - - to_sql(next(_iter), replace=True) - - for df in _iter: - to_sql(df, replace=False) - - - # Connect table to the database that should be used for exploration. - # E.g. if hive was used to upload a csv, presto will be a better option - # to explore the table. 
- expore_database = database - explore_database_id = database.explore_database_id - if explore_database_id: - expore_database = ( - db.session.query(Database) - .filter_by(id=explore_database_id) - .one_or_none() - or database - ) - - sqla_table = ( - db.session.query(SqlaTable) - .filter_by( - table_name=datasource_id, - schema=csv_table.schema, - database_id=expore_database.id, - ) - .one_or_none() - ) - if sqla_table: - sqla_table.description = display_name - sqla_table.fetch_metadata() - if not sqla_table: - sqla_table = SqlaTable(table_name=datasource_id) - # Store display name from HQ into description since - # sqla_table.table_name stores datasource_id - sqla_table.description = display_name - sqla_table.database = expore_database - sqla_table.database_id = database.id - if user_id: - user = superset.appbuilder.sm.get_user_by_id(user_id) - else: - user = g.user - sqla_table.owners = [user] - sqla_table.user_id = user.get_id() - sqla_table.schema = csv_table.schema - sqla_table.fetch_metadata() - db.session.add(sqla_table) - db.session.commit() - except Exception as ex: # pylint: disable=broad-except - db.session.rollback() - raise ex - - # superset.appbuilder.sm.add_permission_role(role, sqla_table.get_perm()) - return redirect("/tablemodelview/list/") - - class SelectDomainView(BaseSupersetView): """ - Select a Domain view, all roles that have 'profile' access on 'core.Superset' view can access this + Select a Domain view, all roles that have 'profile' access on + 'core.Superset' view can access this """ + # re-use core.Superset view's permission name class_permission_name = "Superset" @@ -295,16 +189,21 @@ def list(self): return self.render_template( 'select_domain.html', next=request.args.get('next'), - domains=user_domains() + domains=user_domains(), ) @expose('/select//', methods=['GET']) @has_access @permission_name("profile") def select(self, hq_domain): - response = redirect(request.args.get('next') or self.appbuilder.get_url_for_index) + response = redirect( + request.args.get('next') or self.appbuilder.get_url_for_index + ) if hq_domain not in user_domains(): - flash('Please select a valid domain to access this page.', 'warning') + flash( + 'Please select a valid domain to access this page.', + 'warning', + ) return redirect(url_for('SelectDomainView.list', next=request.url)) response.set_cookie('hq_domain', hq_domain) DomainSyncUtil(superset.appbuilder.sm).sync_domain_role(hq_domain) diff --git a/superset_config.example.py b/superset_config.example.py index a8b6ebd..eadc6f0 100644 --- a/superset_config.example.py +++ b/superset_config.example.py @@ -13,17 +13,34 @@ from sentry_sdk.integrations.flask import FlaskIntegration from hq_superset import flask_app_mutator, oauth - +from hq_superset.const import HQ_DATA # Use a tool to generate a sufficiently random string, e.g. # $ openssl rand -base64 42 # SECRET_KEY = ... +# [Fernet](https://cryptography.io/en/latest/fernet/) (symmetric +# encryption) is used to encrypt and decrypt client secrets so that the +# same credentials can be used to subscribe to many data sources. +# +# FERNET_KEYS is a list of keys where the first key is the current one, +# the second is the previous one, etc. Encryption uses the first key. +# Decryption is attempted with each key in turn. +# +# To generate a key: +# >>> from cryptography.fernet import Fernet +# >>> Fernet.generate_key() +# Keys can be bytes or strings. +# FERNET_KEYS = [...] 
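# A minimal sketch (illustration only, not part of this config) of how
# hq-superset uses the keys above: they are wrapped in a MultiFernet, so
# client secrets encrypted with an older key keep decrypting after a new
# key is rotated in at the front of FERNET_KEYS.
# >>> from cryptography.fernet import Fernet, MultiFernet
# >>> old_key = Fernet(Fernet.generate_key())
# >>> new_key = Fernet(Fernet.generate_key())
# >>> ciphertext = old_key.encrypt(b'client secret')  # written before rotation
# >>> MultiFernet([new_key, old_key]).decrypt(ciphertext)
# b'client secret'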
+ AUTH_TYPE = AUTH_OAUTH # Authenticate with CommCare HQ # AUTH_TYPE = AUTH_DB # Authenticate with Superset user DB -# Override this to reflect your local Postgres DB -SQLALCHEMY_DATABASE_URI = 'postgresql://postgres:postgres@localhost:5433/superset_meta' +# Override these for your databases for Superset and HQ Data +SQLALCHEMY_DATABASE_URI = 'postgresql://postgres:postgres@localhost:5432/superset' +SQLALCHEMY_BINDS = { + HQ_DATA: 'postgresql://postgres:postgres@localhost:5432/superset_hq_data' +} # Populate with oauth credentials from your local CommCareHQ OAUTH_PROVIDERS = [ @@ -86,6 +103,7 @@ class CeleryConfig: + accept_content = ['pickle'] broker_url = _REDIS_URL imports = ( 'superset.sql_lab', @@ -122,6 +140,10 @@ class CeleryConfig: 'pt': {'flag':'pt', 'name':'Portuguese'} } +OAUTH2_TOKEN_EXPIRES_IN = { + 'client_credentials': 86400, +} + # CommCare Analytics extensions FLASK_APP_MUTATOR = flask_app_mutator CUSTOM_SECURITY_MANAGER = oauth.CommCareSecurityManager
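A rough sketch of the round trip these changes enable, from the point of
view of a client such as CommCare HQ: fetch a bearer token from the new
`/oauth/token` endpoint using the client credentials issued by
`HQClient.create_domain_client()`, then post a data set change to
`/hq_webhook/change/`. The host and credentials below are placeholders,
and the payload simply mirrors the `DataSetChange` dataclass, so treat
this as an illustration rather than part of the package.

```python
import requests

BASE_URL = "http://localhost:8088"  # the dev server from the README; placeholder

# Placeholder credentials. Real values are created by
# HQClient.create_domain_client() when a data source is subscribed:
# client_id is a UUID, client_secret a random 64-character string.
CLIENT_ID = "2c26b46b-68ff-4191-9991-example0000"
CLIENT_SECRET = "not-a-real-secret"

# 1. Client-credentials grant against OAuth.issue_access_token()
token_response = requests.post(
    f"{BASE_URL}/oauth/token",
    data={"grant_type": "client_credentials"},
    auth=(CLIENT_ID, CLIENT_SECRET),
)
access_token = token_response.json()["access_token"]

# 2. Send one document's rows to DataSetChangeAPI.post_dataset_change().
#    The body mirrors the DataSetChange dataclass: data_source_id, doc_id
#    and a list of row dicts.
change = {
    "data_source_id": "abc123",
    "doc_id": "doc456",
    "data": [
        {
            "doc_id": "doc456",
            "name": "example",
            "received_on": "2024-02-24T14:01:25.397469Z",
        },
    ],
}
response = requests.post(
    f"{BASE_URL}/hq_webhook/change/",
    json=change,
    headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
```

With `OAUTH2_TOKEN_EXPIRES_IN` set as above, the token is good for 24
hours; after that the client simply repeats step 1.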