diff --git a/README.md b/README.md
index 0e590b1..2d86b7c 100644
--- a/README.md
+++ b/README.md
@@ -6,9 +6,14 @@ This is a Python package that integrates Superset and CommCare HQ.
 Local Development
 -----------------
 
-Follow below instructions.
+### Preparing CommCare HQ
 
-### Setup env
+The 'User configurable reports UI' feature flag must be enabled for the
+domain in CommCare HQ, even if the data sources to be imported were
+created by Report Builder, not a UCR.
+
+
+### Setting up a dev environment
 
 While doing development on top of this integration, it's useful to
 install this via `pip -e` option so that any changes made get reflected
@@ -51,11 +56,12 @@ directly without another `pip install`.
 Read through the initialization instructions at
 https://superset.apache.org/docs/installation/installing-superset-from-scratch/#installing-and-initializing-superset.
 
-Create the database. These instructions assume that PostgreSQL is
-running on localhost, and that its user is "commcarehq". Adapt
-accordingly:
+Create a database for Superset, and a database for storing data from
+CommCare HQ. Adapt the username and database names to suit your
+environment.
 ```bash
-$ createdb -h localhost -p 5432 -U commcarehq superset_meta
+$ createdb -h localhost -p 5432 -U postgres superset
+$ createdb -h localhost -p 5432 -U postgres superset_hq_data
 ```
 
 Set the following environment variables:
@@ -64,10 +70,17 @@ $ export FLASK_APP=superset
 $ export SUPERSET_CONFIG_PATH=/path/to/superset_config.py
 ```
 
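+A minimal `superset_config.py` sketch for the two databases created
+above (assumed values for illustration; a real config will need other
+settings too, such as authentication). The `SQLALCHEMY_BINDS` key must
+be "HQ Data", the name this codebase uses for the CommCare HQ data
+database, and the "Creating a migration" section below copies this URI
+into Alembic's configuration.
+```python
+# superset_config.py -- minimal sketch, assumed values
+SQLALCHEMY_DATABASE_URI = "postgresql://postgres@localhost:5432/superset"
+
+# Bind used by hq-superset for data synced from CommCare HQ
+SQLALCHEMY_BINDS = {
+    "HQ Data": "postgresql://postgres@localhost:5432/superset_hq_data",
+}
+```
+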
-Initialize the database. Create an administrator. Create default roles
+Set this environment variable to allow OAuth 2.0 authentication with
+CommCare HQ over insecure HTTP. (DO NOT USE THIS IN PRODUCTION.)
+```bash
+$ export AUTHLIB_INSECURE_TRANSPORT=1
+```
+
+Initialize the databases. Create an administrator. Create default roles
 and permissions:
 ```bash
 $ superset db upgrade
+$ superset db upgrade --directory hq_superset/migrations/
 $ superset fab create-admin
 $ superset load_examples  # (Optional)
 $ superset init
@@ -78,28 +91,16 @@ You should now be able to run superset using the `superset run` command:
 ```bash
 $ superset run -p 8088 --with-threads --reload --debugger
 ```
-However, OAuth login does not work yet as hq-superset needs a Postgres
-database created to store CommCare HQ data.
-
-### Create a Postgres Database Connection for storing HQ data
-
-- Create a Postgres database. e.g.
-  ```bash
-  $ createdb -h localhost -p 5432 -U commcarehq hq_data
-  ```
-- Log into Superset as the admin user created in the Superset
-  installation and initialization. Note that you will need to update
-  `AUTH_TYPE = AUTH_DB` to log in as admin user. `AUTH_TYPE` should be
-  otherwise set to `AUTH_OAUTH`.
-- Go to 'Data' -> 'Databases' or http://127.0.0.1:8088/databaseview/list/
-- Create a database connection by clicking '+ DATABASE' button at the top.
-- The name of the DISPLAY NAME should be 'HQ Data' exactly, as this is
-  the name by which this codebase refers to the Postgres DB.
-
-OAuth integration should now be working. You can log in as a CommCare
-HQ web user.
+You can now log in as a CommCare HQ web user.
+
+In order for CommCare HQ to sync data source changes, you will need to
+allow OAuth 2.0 authentication over insecure HTTP. (DO NOT USE THIS IN
+PRODUCTION.) Set this environment variable in your CommCare HQ Django
+server. (Yes, it's "OAUTHLIB" this time, not "AUTHLIB" as before.)
+```bash
+$ export OAUTHLIB_INSECURE_TRANSPORT=1
+```
 
 
 ### Importing UCRs using Redis and Celery
@@ -129,6 +130,41 @@ code you want to test will need to be in a module whose dependencies
 don't include Superset.
 
 
+### Creating a migration
+
+You will need to create an Alembic migration for any new SQLAlchemy
+models that you add. The Superset CLI should allow you to do this:
+
+```shell
+$ superset db revision --autogenerate -m "Add table for Foo model"
+```
+
+However, problems with this approach have occurred in the past. You
+might have more success by using Alembic directly. You will need to
+modify the configuration a little to do this:
+
+1. Copy the "HQ_DATA" database URI from `superset_config.py`.
+
+2. Paste it as the value of `sqlalchemy.url` in
+   `hq_superset/migrations/alembic.ini`.
+
+3. Edit `env.py` and comment out the following lines:
+   ```
+   hq_data_uri = current_app.config['SQLALCHEMY_BINDS'][HQ_DATA]
+   decoded_uri = urllib.parse.unquote(hq_data_uri)
+   config.set_main_option('sqlalchemy.url', decoded_uri)
+   ```
+
+Those changes will allow Alembic to connect to the "HQ Data" database
+without the need to instantiate Superset's Flask app. You can now
+autogenerate your new table with:
+
+```shell
+$ cd hq_superset/migrations/
+$ alembic revision --autogenerate -m "Add table for Foo model"
+```
+
+
 Upgrading Superset
 ------------------
 
diff --git a/hq_superset/__init__.py b/hq_superset/__init__.py
index 54e63f1..6d92ac7 100644
--- a/hq_superset/__init__.py
+++ b/hq_superset/__init__.py
@@ -8,10 +8,14 @@ def flask_app_mutator(app):
     # Import the views (which assumes the app is initialized) here
     # return
     from superset.extensions import appbuilder
+    from . import api, hq_domain, views, oauth2_server
-    from . import hq_domain, views
     appbuilder.add_view(views.HQDatasourceView, 'Update HQ Datasource', menu_cond=lambda *_: False)
     appbuilder.add_view(views.SelectDomainView, 'Select a Domain', menu_cond=lambda *_: False)
+    appbuilder.add_api(api.OAuth)
+    appbuilder.add_api(api.DataSetChangeAPI)
+    oauth2_server.config_oauth2(app)
+
     app.before_request_funcs.setdefault(None, []).append(
         hq_domain.before_request_hook
     )
@@ -40,4 +44,4 @@ def override_jinja2_template_loader(app):
         'images'
     ))
     blueprint = Blueprint('Static', __name__, static_url_path='/static/images', static_folder=images_path)
-    app.register_blueprint(blueprint)
\ No newline at end of file
+    app.register_blueprint(blueprint)
diff --git a/hq_superset/api.py b/hq_superset/api.py
new file mode 100644
index 0000000..5473c58
--- /dev/null
+++ b/hq_superset/api.py
@@ -0,0 +1,69 @@
+import json
+from http import HTTPStatus
+
+from flask import jsonify, request
+from flask_appbuilder.api import BaseApi, expose
+from sqlalchemy.orm.exc import NoResultFound
+from superset.superset_typing import FlaskResponse
+from superset.views.base import (
+    handle_api_exception,
+    json_error_response,
+    json_success,
+)
+
+from .models import DataSetChange
+from .oauth2_server import authorization, require_oauth
+
+
+class OAuth(BaseApi):
+
+    def __init__(self):
+        super().__init__()
+        self.route_base = "/oauth"
+
+    @expose("/token", methods=('POST',))
+    def issue_access_token(self):
+        try:
+            response = authorization.create_token_response()
+        except NoResultFound:
+            return jsonify({"error": "Invalid client"}), 401
+
+        if response.status_code >= 400:
+            return response
+
+        data = json.loads(response.data.decode("utf-8"))
+        return jsonify(data)
+
+
+class DataSetChangeAPI(BaseApi):
+    """
+    Accepts changes to datasets from CommCare HQ data forwarding
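+
+    Rows in the corresponding dataset whose ``doc_id`` matches the
+    payload are replaced with the rows given in ``data``. A sketch of
+    the expected JSON body, assumed from the fields of the
+    ``DataSetChange`` dataclass in ``models.py``:
+
+        {
+            "data_source_id": "<data source / table name>",
+            "doc_id": "<source document ID>",
+            "data": [{"doc_id": "<source document ID>", "...": "..."}]
+        }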
""" + + MAX_REQUEST_LENGTH = 10 * 1024 * 1024 # reject JSON requests > 10MB + + def __init__(self): + self.route_base = '/hq_webhook' + self.default_view = 'post_dataset_change' + super().__init__() + + @expose('/change/', methods=('POST',)) + @handle_api_exception + @require_oauth() + def post_dataset_change(self) -> FlaskResponse: + if request.content_length > self.MAX_REQUEST_LENGTH: + return json_error_response( + HTTPStatus.REQUEST_ENTITY_TOO_LARGE.description, + status=HTTPStatus.REQUEST_ENTITY_TOO_LARGE.value, + ) + + try: + request_json = json.loads(request.get_data(as_text=True)) + change = DataSetChange(**request_json) + change.update_dataset() + return json_success('Dataset updated') + except json.JSONDecodeError: + return json_error_response( + 'Invalid JSON syntax', + status=HTTPStatus.BAD_REQUEST.value, + ) diff --git a/hq_superset/const.py b/hq_superset/const.py new file mode 100644 index 0000000..86c7e26 --- /dev/null +++ b/hq_superset/const.py @@ -0,0 +1,2 @@ +# The name of the database for storing data related to CommCare HQ +HQ_DATA = "HQ Data" diff --git a/hq_superset/hq_domain.py b/hq_superset/hq_domain.py index 6fc5a12..de137d4 100644 --- a/hq_superset/hq_domain.py +++ b/hq_superset/hq_domain.py @@ -26,6 +26,8 @@ def after_request_hook(response): "AuthDBView.login", "SelectDomainView.list", "SelectDomainView.select", + "OAuth.issue_access_token", + "DataSetChangeAPI.post_dataset_change", "appbuilder.static", "static", ] @@ -39,7 +41,10 @@ def is_user_admin(): def ensure_domain_selected(): # Check if a hq_domain cookie is set # Ensure necessary roles, permissions and DB schemas are created for the domain - if is_user_admin() or (request.url_rule and request.url_rule.endpoint in DOMAIN_EXCLUDED_VIEWS): + if is_user_admin() or ( + request.url_rule + and request.url_rule.endpoint in DOMAIN_EXCLUDED_VIEWS + ): return hq_domain = request.cookies.get('hq_domain') valid_domains = user_domains() diff --git a/hq_superset/hq_requests.py b/hq_superset/hq_requests.py new file mode 100644 index 0000000..26392c8 --- /dev/null +++ b/hq_superset/hq_requests.py @@ -0,0 +1,30 @@ +import superset +from hq_superset.oauth import get_valid_cchq_oauth_token + + +class HQRequest: + + def __init__(self, url): + self.url = url + + @property + def oauth_token(self): + return get_valid_cchq_oauth_token() + + @property + def commcare_provider(self): + return superset.appbuilder.sm.oauth_remotes["commcare"] + + @property + def api_base_url(self): + return self.commcare_provider.api_base_url + + @property + def absolute_url(self): + return f"{self.api_base_url}{self.url}" + + def get(self): + return self.commcare_provider.get(self.url, token=self.oauth_token) + + def post(self, data): + return self.commcare_provider.post(self.url, data=data, token=self.oauth_token) diff --git a/hq_superset/hq_url.py b/hq_superset/hq_url.py new file mode 100644 index 0000000..aca7e58 --- /dev/null +++ b/hq_superset/hq_url.py @@ -0,0 +1,25 @@ +""" +Functions that return URLs on CommCare HQ +""" + + +def datasource_details(domain, datasource_id): + return f"a/{domain}/api/v0.5/ucr_data_source/{datasource_id}/" + + +def datasource_list(domain): + return f"a/{domain}/api/v0.5/ucr_data_source/" + + +def datasource_export(domain, datasource_id): + return ( + f"a/{domain}/configurable_reports/data_sources/export/{datasource_id}/" + "?format=csv" + ) + + +def datasource_subscribe(domain, datasource_id): + return ( + f"a/{domain}/configurable_reports/data_sources/subscribe/" + f"{datasource_id}/" + ) diff --git 
a/hq_superset/migrations/README b/hq_superset/migrations/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/hq_superset/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/hq_superset/migrations/alembic.ini b/hq_superset/migrations/alembic.ini new file mode 100644 index 0000000..f01502d --- /dev/null +++ b/hq_superset/migrations/alembic.ini @@ -0,0 +1,115 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = . + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(hour).2d-%%(minute).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +# prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to hq_superset/migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:hq_superset/migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/hq_superset/migrations/env.py b/hq_superset/migrations/env.py new file mode 100644 index 0000000..168bfde --- /dev/null +++ b/hq_superset/migrations/env.py @@ -0,0 +1,77 @@ +import urllib.parse +from logging.config import fileConfig + +from alembic import context +from flask import current_app +from sqlalchemy import engine_from_config, pool + +from hq_superset.const import HQ_DATA +from hq_superset.models import HQClient + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +hq_data_uri = current_app.config['SQLALCHEMY_BINDS'][HQ_DATA] +decoded_uri = urllib.parse.unquote(hq_data_uri) +config.set_main_option('sqlalchemy.url', decoded_uri) + +# add your model's MetaData object here for 'autogenerate' support +target_metadata = HQClient.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/hq_superset/migrations/script.py.mako b/hq_superset/migrations/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/hq_superset/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/hq_superset/migrations/versions/2024-02-24_23-53_56d0467ff6ff_added_oauth_tables.py b/hq_superset/migrations/versions/2024-02-24_23-53_56d0467ff6ff_added_oauth_tables.py new file mode 100644 index 0000000..4c9469a --- /dev/null +++ b/hq_superset/migrations/versions/2024-02-24_23-53_56d0467ff6ff_added_oauth_tables.py @@ -0,0 +1,67 @@ +"""Added OAuth tables + +Revision ID: 56d0467ff6ff +Revises: +Create Date: 2024-02-24 23:53:10.289606 +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '56d0467ff6ff' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'hq_oauth_client', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('client_id_issued_at', sa.Integer(), nullable=False), + sa.Column('client_secret_expires_at', sa.Integer(), nullable=False), + sa.Column('client_metadata', sa.Text(), nullable=True), + sa.Column('domain', sa.String(length=255), nullable=False), + sa.Column('client_secret', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('domain'), + info={'bind_key': 'HQ Data'}, + ) + op.create_index( + op.f('ix_hq_oauth_client_client_id'), + 'hq_oauth_client', + ['client_id'], + unique=False, + ) + op.create_table( + 'hq_oauth_token', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('token_type', sa.String(length=40), nullable=True), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('issued_at', sa.Integer(), nullable=False), + sa.Column('access_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('refresh_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('expires_in', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('access_token'), + info={'bind_key': 'HQ Data'}, + ) + op.create_index( + op.f('ix_hq_oauth_token_refresh_token'), + 'hq_oauth_token', + ['refresh_token'], + unique=False, + ) + + +def downgrade() -> None: + op.drop_table('hq_oauth_token') + op.drop_index( + op.f('ix_hq_oauth_client_client_id'), table_name='hq_oauth_client' + ) + op.drop_table('hq_oauth_client') diff --git a/hq_superset/models.py b/hq_superset/models.py new file mode 100644 index 0000000..b3e15cc --- /dev/null +++ b/hq_superset/models.py @@ -0,0 +1,115 @@ +import secrets +import string +import time +import uuid +from dataclasses import dataclass +from typing import Any + +from authlib.integrations.sqla_oauth2 import ( + OAuth2ClientMixin, + OAuth2TokenMixin, +) +from cryptography.fernet import MultiFernet +from sqlalchemy import update +from superset import db + +from .const import HQ_DATA +from .utils import cast_data_for_table, get_fernet_keys, get_hq_database + + +@dataclass +class DataSetChange: + data_source_id: str + doc_id: str + data: list[dict[str, Any]] + + def update_dataset(self): + database = get_hq_database() + try: + sqla_table = next(( + table for table in database.tables + if table.table_name == self.data_source_id + )) + except StopIteration: + raise ValueError(f'{self.data_source_id} table not found.') + table = sqla_table.get_sqla_table_object() + + with ( + database.get_sqla_engine_with_context() as engine, + engine.connect() as connection, + connection.begin() # Commit on leaving context + ): + delete_stmt = table.delete().where(table.c.doc_id == self.doc_id) + connection.execute(delete_stmt) + if self.data: + rows = list(cast_data_for_table(self.data, table)) + insert_stmt = table.insert().values(rows) + connection.execute(insert_stmt) + + +class HQClient(db.Model, OAuth2ClientMixin): + __bind_key__ = HQ_DATA + __tablename__ = 'hq_oauth_client' + + domain = db.Column(db.String(255), primary_key=True) + client_secret = db.Column(db.String(255)) # more chars for encryption + + def 
get_client_secret(self): + keys = get_fernet_keys() + fernet = MultiFernet(keys) + + ciphertext_bytes = self.client_secret.encode('utf-8') + plaintext_bytes = fernet.decrypt(ciphertext_bytes) + return plaintext_bytes.decode('utf-8') + + def set_client_secret(self, plaintext): + keys = get_fernet_keys() + fernet = MultiFernet(keys) + + plaintext_bytes = plaintext.encode('utf-8') + ciphertext_bytes = fernet.encrypt(plaintext_bytes) + self.client_secret = ciphertext_bytes.decode('utf-8') + + def check_client_secret(self, plaintext): + return self.get_client_secret() == plaintext + + def revoke_tokens(self): + revoked_at = int(time.time()) + stmt = ( + update(Token) + .where(Token.client_id == self.client_id) + .where(Token.access_token_revoked_at == 0) + .values(access_token_revoked_at=revoked_at) + ) + db.session.execute(stmt) + db.session.commit() + + @classmethod + def get_by_domain(cls, domain): + return db.session.query(HQClient).filter_by(domain=domain).first() + + @classmethod + def create_domain_client(cls, domain: str): + alphabet = string.ascii_letters + string.digits + client_secret = ''.join(secrets.choice(alphabet) for i in range(64)) + client = HQClient( + domain=domain, + client_id=str(uuid.uuid4()), + ) + client.set_client_secret(client_secret) + client.set_client_metadata({"grant_types": ["client_credentials"]}) + db.session.add(client) + db.session.commit() + return client + + +class Token(db.Model, OAuth2TokenMixin): + __bind_key__ = HQ_DATA + __tablename__ = 'hq_oauth_token' + + id = db.Column(db.Integer, primary_key=True) + + @property + def domain(self): + client = HQClient.get_by_client_id(self.client_id) + return client.domain diff --git a/hq_superset/oauth.py b/hq_superset/oauth.py index e5ef34c..dfad3e3 100644 --- a/hq_superset/oauth.py +++ b/hq_superset/oauth.py @@ -73,7 +73,9 @@ def get_valid_cchq_oauth_token(): # If token hasn't expired yet, return it expires_at = oauth_response.get("expires_at") - if expires_at > int(time.time()): + # TODO: RFC-6749 specifies "expires_in", not "expires_at". 
+ # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + if expires_at is None or expires_at > int(time.time()): return oauth_response # If the token has expired, get a new token using refresh_token diff --git a/hq_superset/oauth2_server.py b/hq_superset/oauth2_server.py new file mode 100644 index 0000000..f5e5c6e --- /dev/null +++ b/hq_superset/oauth2_server.py @@ -0,0 +1,51 @@ +from datetime import timedelta + +from authlib.integrations.flask_oauth2 import ( + AuthorizationServer, + ResourceProtector, +) +from authlib.integrations.sqla_oauth2 import ( + create_bearer_token_validator, + create_query_client_func, + create_revocation_endpoint, +) +from authlib.oauth2.rfc6749 import grants + +from .models import HQClient, Token, db + + +def save_token(token, request): + client = request.client + client.revoke_tokens() + + one_day = 24 * 60 * 60 + token = Token( + client_id=client.client_id, + token_type=token['token_type'], + access_token=token['access_token'], + scope=client.domain, + expires_in=one_day, + ) + db.session.add(token) + db.session.commit() + + +query_client = create_query_client_func(db.session, HQClient) +authorization = AuthorizationServer( + query_client=query_client, + save_token=save_token, +) +require_oauth = ResourceProtector() + + +def config_oauth2(app): + authorization.init_app(app) + authorization.register_grant(grants.ClientCredentialsGrant) + + # support revocation + revocation_cls = create_revocation_endpoint(db.session, Token) + authorization.register_endpoint(revocation_cls) + + # protect resource + bearer_cls = create_bearer_token_validator(db.session, Token) + require_oauth.register_token_validator(bearer_cls()) diff --git a/hq_superset/services.py b/hq_superset/services.py new file mode 100644 index 0000000..c39865a --- /dev/null +++ b/hq_superset/services.py @@ -0,0 +1,205 @@ +import logging +import os +from datetime import datetime +from urllib.parse import urljoin + +import pandas +import sqlalchemy +import superset +from flask import g, request, url_for +from sqlalchemy.dialects import postgresql +from superset import db +from superset.connectors.sqla.models import SqlaTable +from superset.extensions import cache_manager +from superset.sql_parse import Table + +from .hq_requests import HQRequest +from .hq_url import datasource_details, datasource_export, datasource_subscribe +from .models import HQClient +from .utils import ( + CCHQApiException, + convert_to_array, + get_column_dtypes, + get_datasource_file, + get_hq_database, + get_schema_name_for_domain, + parse_date, +) + +logger = logging.getLogger(__name__) + + +def download_datasource(domain, datasource_id): + hq_request = HQRequest(url=datasource_export(domain, datasource_id)) + response = hq_request.get() + if response.status_code != 200: + raise CCHQApiException("Error downloading the UCR export from HQ") + + filename = f"{datasource_id}_{datetime.now()}.zip" + path = os.path.join(superset.config.SHARED_DIR, filename) + with open(path, "wb") as f: + f.write(response.content) + + return path, len(response.content) + + +def get_datasource_defn(domain, datasource_id): + hq_request = HQRequest(url=datasource_details(domain, datasource_id)) + response = hq_request.get() + + if response.status_code != 200: + raise CCHQApiException("Error downloading the UCR definition from HQ") + return response.json() + + +def refresh_hq_datasource( + domain, + datasource_id, + display_name, + file_path, + datasource_defn, + user_id=None, +): + """ + Pulls the data from CommCare HQ and creates/replaces the + 
corresponding Superset dataset + """ + # See `CsvToDatabaseView.form_post()` in + # https://github.com/apache/superset/blob/master/superset/views/database/views.py + + def dataframe_to_sql(df, replace=False): + """ + Upload Pandas DataFrame ``df`` to ``database``. + """ + database.db_engine_spec.df_to_sql( + database, + csv_table, + df, + to_sql_kwargs={ + "if_exists": "replace" if replace else "append", + "dtype": sql_converters, + "index": False, + }, + ) + + database = get_hq_database() + schema = get_schema_name_for_domain(domain) + csv_table = Table(table=datasource_id, schema=schema) + column_dtypes, date_columns, array_columns = get_column_dtypes( + datasource_defn + ) + converters = { + column_name: convert_to_array for column_name in array_columns + } + sql_converters = { + # Assumes all array values will be of type TEXT + column_name: postgresql.ARRAY(sqlalchemy.types.TEXT) + for column_name in array_columns + } + + try: + with get_datasource_file(file_path) as csv_file: + dataframes = pandas.read_csv( + chunksize=10000, + filepath_or_buffer=csv_file, + encoding="utf-8", + parse_dates=date_columns, + date_parser=parse_date, + keep_default_na=True, + dtype=column_dtypes, + converters=converters, + iterator=True, + low_memory=True, + ) + dataframe_to_sql(next(dataframes), replace=True) + for df in dataframes: + dataframe_to_sql(df, replace=False) + + sqla_table = ( + db.session.query(SqlaTable) + .filter_by( + table_name=datasource_id, + schema=csv_table.schema, + database_id=database.id, + ) + .one_or_none() + ) + if sqla_table: + sqla_table.description = display_name + sqla_table.fetch_metadata() + if not sqla_table: + sqla_table = SqlaTable(table_name=datasource_id) + # Store display name from HQ into description since + # sqla_table.table_name stores datasource_id + sqla_table.description = display_name + sqla_table.database = database + sqla_table.database_id = database.id + if user_id: + user = superset.appbuilder.sm.get_user_by_id(user_id) + else: + user = g.user + sqla_table.owners = [user] + sqla_table.user_id = user.get_id() + sqla_table.schema = csv_table.schema + sqla_table.fetch_metadata() + db.session.add(sqla_table) + db.session.commit() + except Exception as ex: # pylint: disable=broad-except + db.session.rollback() + raise ex + + +def subscribe_to_hq_datasource(domain, datasource_id): + hq_client = HQClient.get_by_domain(domain) + if hq_client is None: + hq_client = HQClient.create_domain_client(domain) + + hq_request = HQRequest(url=datasource_subscribe(domain, datasource_id)) + webhook_url = urljoin( + request.root_url, + url_for('DataSetChangeAPI.post_dataset_change'), + ) + token_url = urljoin(request.root_url, url_for('OAuth.issue_access_token')) + response = hq_request.post({ + 'webhook_url': webhook_url, + 'token_url': token_url, + 'client_id': hq_client.client_id, + 'client_secret': hq_client.get_client_secret(), + }) + if response.status_code == 201: + return + if response.status_code < 500: + logger.error( + f"Failed to subscribe to data source {datasource_id} due to the following issue: {response.data}" + ) + if response.status_code >= 500: + logger.exception( + f"Failed to subscribe to data source {datasource_id} due to a remote server error" + ) + + +class AsyncImportHelper: + def __init__(self, domain, datasource_id): + self.domain = domain + self.datasource_id = datasource_id + + @property + def progress_key(self): + return f"{self.domain}_{self.datasource_id}_import_task_id" + + @property + def task_id(self): + return 
cache_manager.cache.get(self.progress_key) + + def is_import_in_progress(self): + if not self.task_id: + return False + from celery.result import AsyncResult + res = AsyncResult(self.task_id) + return not res.ready() + + def mark_as_in_progress(self, task_id): + cache_manager.cache.set(self.progress_key, task_id) + + def mark_as_complete(self): + cache_manager.cache.delete(self.progress_key) diff --git a/hq_superset/tasks.py b/hq_superset/tasks.py index cbf7f7e..66ed506 100644 --- a/hq_superset/tasks.py +++ b/hq_superset/tasks.py @@ -1,16 +1,12 @@ -import logging import os from superset.extensions import celery_app -from .utils import AsyncImportHelper - -logger = logging.getLogger(__name__) +from .services import AsyncImportHelper, refresh_hq_datasource @celery_app.task(name='refresh_hq_datasource_task') def refresh_hq_datasource_task(domain, datasource_id, display_name, export_path, datasource_defn, user_id): - from .views import refresh_hq_datasource try: refresh_hq_datasource(domain, datasource_id, display_name, export_path, datasource_defn, user_id) except Exception: diff --git a/hq_superset/templates/hq_datasource_list.html b/hq_superset/templates/hq_datasource_list.html index c80261b..298eaa9 100644 --- a/hq_superset/templates/hq_datasource_list.html +++ b/hq_superset/templates/hq_datasource_list.html @@ -1,5 +1,5 @@