Skip to content

Commit

Permalink
Use symmetric encryption for client secrets
Browse files Browse the repository at this point in the history
  • Loading branch information
kaapstorm committed Feb 20, 2024
1 parent bd7885a commit 191721d
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 30 deletions.
38 changes: 29 additions & 9 deletions hq_superset/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from typing import Any, Dict, Literal

from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin
from cryptography.fernet import MultiFernet
from flask import current_app
from superset import db
from werkzeug.security import check_password_hash, generate_password_hash

from hq_superset.const import HQ_DATA
from .const import HQ_DATA
from .utils import encoded


@dataclass
Expand All @@ -29,8 +31,29 @@ class HQClient(db.Model, OAuth2ClientMixin):

domain = db.Column(db.String(255), primary_key=True)

def check_client_secret(self, client_secret):
return check_password_hash(self.client_secret, client_secret)
def get_client_secret(self):
keys = [encoded(k, 'ascii') for k in current_app.config['FERNET_KEYS']]
fernet = MultiFernet(keys)

ciphertext_bytes = self.client_secret.encode('utf-8')
plaintext_bytes = fernet.decrypt(ciphertext_bytes)
return plaintext_bytes.decode('utf-8')

def set_client_secret(self, plaintext):
keys = [encoded(k, 'ascii') for k in current_app.config['FERNET_KEYS']]
fernet = MultiFernet(keys)

plaintext_bytes = plaintext.encode('utf-8')
ciphertext_bytes = fernet.encrypt(plaintext_bytes)
self.client_secret = ciphertext_bytes.decode('utf-8')

def check_client_secret(self, plaintext):
keys = [encoded(k, 'ascii') for k in current_app.config['FERNET_KEYS']]
fernet = MultiFernet(keys)

ciphertext_bytes = self.client_secret.encode('utf-8')
plaintext_bytes = plaintext.encode('utf-8')
return fernet.decrypt(ciphertext_bytes) == plaintext_bytes

def revoke_tokens(self):
tokens = db.session.execute(
Expand All @@ -53,18 +76,15 @@ def get_by_client_id(cls, client_id):
def create_domain_client(cls, domain: str):
alphabet = string.ascii_letters + string.digits
client_secret = ''.join(secrets.choice(alphabet) for i in range(64))

client = HQClient(
domain=domain,
client_id=str(uuid.uuid4()),
client_secret=generate_password_hash(client_secret),
)
client.set_client_secret(client_secret)
client.set_client_metadata({"grant_types": ["client_credentials"]})

db.session.add(client)
db.session.commit()

return client.client_id, client_secret
return client


class Token(db.Model):
Expand Down
42 changes: 21 additions & 21 deletions hq_superset/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,27 +151,27 @@ def to_sql(df, replace=False):


def subscribe_to_hq_datasource(domain, datasource_id):
if HQClient.get_by_domain(domain) is None:
hq_request = HQRequest(url=datasource_subscribe(domain, datasource_id))

client_id, client_secret = HQClient.create_domain_client(domain)

response = hq_request.post({
'webhook_url': f'{BASE_URL}/hq_webhook/change/',
'token_url': f'{BASE_URL}/oauth/token',
'client_id': client_id,
'client_secret': client_secret,
})
if response.status_code == 201:
return
if response.status_code < 500:
logger.error(
f"Failed to subscribe to data source {datasource_id} due to the following issue: {response.data}"
)
if response.status_code >= 500:
logger.exception(
f"Failed to subscribe to data source {datasource_id} due to a remote server error"
)
hq_client = HQClient.get_by_domain(domain)
if hq_client is None:
hq_client = HQClient.create_domain_client(domain)

hq_request = HQRequest(url=datasource_subscribe(domain, datasource_id))
response = hq_request.post({
'webhook_url': f'{BASE_URL}/hq_webhook/change/',
'token_url': f'{BASE_URL}/oauth/token',
'client_id': hq_client.client_id,
'client_secret': hq_client.get_client_secret(),
})
if response.status_code == 201:
return
if response.status_code < 500:
logger.error(
f"Failed to subscribe to data source {datasource_id} due to the following issue: {response.data}"
)
if response.status_code >= 500:
logger.exception(
f"Failed to subscribe to data source {datasource_id} due to a remote server error"
)


class AsyncImportHelper:
Expand Down
18 changes: 18 additions & 0 deletions hq_superset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,24 @@ def get_datasource_file(path):
yield zipfile.open(filename)


def encoded(string_maybe, encoding):
"""
Returns ``string_maybe`` encoded with ``encoding``, otherwise
returns it unchanged.
>>> encoded('abc', 'utf-8')
b'abc'
>>> encoded(b'abc', 'ascii')
b'abc'
>>> encoded(123, 'utf-8')
123
"""
if hasattr(string_maybe, 'encode'):
return string_maybe.encode(encoding)
return string_maybe


def convert_to_array(string_array):
"""
Converts the string representation of a list to a list.
Expand Down
14 changes: 14 additions & 0 deletions superset_config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
# $ openssl rand -base64 42
# SECRET_KEY = ...

# [Fernet](https://cryptography.io/en/latest/fernet/) (symmetric
# encryption) is used to encrypt and decrypt client secrets so that the
# same credentials can be used to subscribe to many data sources.
#
# FERNET_KEYS is a list of keys where the first key is the current one,
# the second is the previous one, etc. Encryption uses the first key.
# Decryption is attempted with each key in turn.
#
# To generate a key:
# >>> from cryptography.fernet import Fernet
# >>> Fernet.generate_key()
# Keys can be bytes or strings.
# FERNET_KEYS = [...]

AUTH_TYPE = AUTH_OAUTH # Authenticate with CommCare HQ
# AUTH_TYPE = AUTH_DB # Authenticate with Superset user DB

Expand Down

0 comments on commit 191721d

Please sign in to comment.