Skip to content

Commit

Permalink
Replace josepy with jwcrypto
Browse files Browse the repository at this point in the history
  • Loading branch information
tonial committed Oct 18, 2024
1 parent 132e20d commit 7374ddc
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 97 deletions.
38 changes: 21 additions & 17 deletions mozilla_django_oidc/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from django.contrib.auth.backends import ModelBackend
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
from django.urls import reverse
from django.utils.encoding import force_bytes, smart_bytes, smart_str
from django.utils.encoding import force_bytes, smart_str
from django.utils.module_loading import import_string
from josepy.b64 import b64decode
from josepy.jwk import JWK
from josepy.jws import JWS, Header
from jwcrypto.common import base64url_decode
from jwcrypto.jwk import JWK
from jwcrypto.jws import JWS, InvalidJWSSignature
from requests.auth import HTTPBasicAuth
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -127,10 +127,11 @@ def update_user(self, user, claims):

def _verify_jws(self, payload, key):
"""Verify the given JWS payload with the given key and return the payload"""
jws = JWS.from_compact(payload)
jws = JWS()
jws.deserialize(smart_str(payload))

try:
alg = jws.signature.combined.alg.name
alg = jws.jose_header["alg"]
except KeyError:
msg = "No alg value found in header"
raise SuspiciousOperation(msg)
Expand All @@ -143,13 +144,17 @@ def _verify_jws(self, payload, key):
raise SuspiciousOperation(msg)

if isinstance(key, str):
# Use smart_bytes here since the key string comes from settings.
jwk = JWK.load(smart_bytes(key))
try:
jwk = JWK.from_pem(force_bytes(key))
except ValueError:
jwk = JWK.from_password(key)
else:
# The key is a json returned from the IDP JWKS endpoint.
jwk = JWK.from_json(key)
jwk = JWK(**key)

if not jws.verify(jwk):
try:
jws.verify(jwk)
except InvalidJWSSignature:
msg = "JWS token verification failed."
raise SuspiciousOperation(msg)

Expand All @@ -167,17 +172,16 @@ def retrieve_matching_jwk(self, token):
jwks = response_jwks.json()

# Compute the current header from the given token to find a match
jws = JWS.from_compact(token)
json_header = jws.signature.protected
header = Header.json_loads(json_header)
jws = JWS()
jws.deserialize(smart_str(token))

key = None
for jwk in jwks["keys"]:
if import_from_settings("OIDC_VERIFY_KID", True) and jwk[
"kid"
] != smart_str(header.kid):
] != smart_str(jws.jose_header["kid"]):
continue
if "alg" in jwk and jwk["alg"] != smart_str(header.alg):
if "alg" in jwk and jwk["alg"] != smart_str(jws.jose_header["alg"]):
continue
key = jwk
if key is None:
Expand All @@ -188,11 +192,11 @@ def get_payload_data(self, token, key):
"""Helper method to get the payload of the JWT token."""
if self.get_settings("OIDC_ALLOW_UNSECURED_JWT", False):
header, payload_data, signature = token.split(b".")
header = json.loads(smart_str(b64decode(header)))
header = json.loads(base64url_decode(smart_str(header)))

# If config allows unsecured JWTs check the header and return the decoded payload
if "alg" in header and header["alg"] == "none":
return b64decode(payload_data)
return base64url_decode(smart_str(payload_data))

# By default fallback to verify JWT signatures
return self._verify_jws(token, key)
Expand Down
35 changes: 2 additions & 33 deletions mozilla_django_oidc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from hashlib import sha256
from urllib.request import parse_http_list, parse_keqv_list

# Make it obvious that these aren't the usual base64 functions
import josepy.b64
from jwcrypto.common import base64url_encode
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured

Expand Down Expand Up @@ -55,36 +54,6 @@ def is_authenticated(user):
return user.is_authenticated


def base64_url_encode(bytes_like_obj):
"""Return a URL-Safe, base64 encoded version of bytes_like_obj
Implements base64urlencode as described in
https://datatracker.ietf.org/doc/html/rfc7636#appendix-A
"""

s = josepy.b64.b64encode(bytes_like_obj).decode("ascii") # base64 encode
# the josepy base64 encoder (strips '='s padding) automatically
s = s.replace("+", "-") # 62nd char of encoding
s = s.replace("/", "_") # 63rd char of encoding

return s


def base64_url_decode(string_like_obj):
"""Return the bytes encoded in a URL-Safe, base64 encoded string.
Implements inverse of base64urlencode as described in
https://datatracker.ietf.org/doc/html/rfc7636#appendix-A
This function is not used by the OpenID client; it's just for testing PKCE related functions.
"""
s = string_like_obj

s = s.replace("_", "/") # 63rd char of encoding
s = s.replace("-", "+") # 62nd char of encoding
b = josepy.b64.b64decode(s) # josepy base64 encoder (decodes without '='s padding)

return b


def generate_code_challenge(code_verifier, method):
"""Return a code_challege, which proves knowledge of the code_verifier.
The code challenge is generated according to method which must be one
Expand All @@ -99,7 +68,7 @@ def generate_code_challenge(code_verifier, method):
return code_verifier

elif method == "S256":
return base64_url_encode(sha256(code_verifier.encode("ascii")).digest())
return base64url_encode(sha256(code_verifier.encode("ascii")).digest())

else:
raise ValueError("code challenge method must be 'plain' or 'S256'.")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_disallowed_unsecured_token(self):
)
)

with self.assertRaises(KeyError):
with self.assertRaises(SuspiciousOperation):
self.backend.get_payload_data(token, None)

@override_settings(OIDC_ALLOW_UNSECURED_JWT=True)
Expand Down
46 changes: 0 additions & 46 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from mozilla_django_oidc.utils import (
absolutify,
add_state_and_verifier_and_nonce_to_session,
base64_url_decode,
base64_url_encode,
generate_code_challenge,
import_from_settings,
)
Expand Down Expand Up @@ -52,50 +50,6 @@ def test_absolutify_path_host_injection(self):
self.assertEqual(url, "https://testserver/evil.com/foo/bar")


class Base64URLEncodeTestCase(TestCase):
def test_base64_url_encode(self):
"""
Tests creating a url-safe base64 encoded string from bytes.
Source: https://datatracker.ietf.org/doc/html/rfc7636#appendix-A
"""
data = bytes((3, 236, 255, 224, 193))
encoded = base64_url_encode(data)

# Using base64.b64encode() returns b'A+z/4ME='.
# Our implementation should strip tailing '='s padding.
# and replace '+' with '-' and '/' with '_'.
self.assertEqual(encoded, "A-z_4ME")

# Decoding should return the original data.
decoded = base64_url_decode(encoded)
self.assertEqual(decoded, data)

def test_base64_url_encode_empty_input(self):
"""
Tests creating a url-safe base64 encoded string from an empty bytes instance.
"""
data = bytes()
encoded = base64_url_encode(data)
self.assertEqual(encoded, "")

decoded = base64_url_decode(encoded)
self.assertEqual(decoded, data)

def test_base64_url_encode_double_padding(self):
"""
Test encoding a string whoose base64.b64encode encoding ends with '=='.
"""
data = bytes((3, 236, 255, 224, 193, 222, 22))
encoded = base64_url_encode(data)

# Using base64.b64encode() returns b'A+z/4MHeFg=='.
self.assertEqual(encoded, "A-z_4MHeFg")

# Decoding should return the original data.
decoded = base64_url_decode(encoded)
self.assertEqual(decoded, data)


class PKCECodeVerificationTestCase(TestCase):
def test_generate_code_challenge(self):
"""
Expand Down

0 comments on commit 7374ddc

Please sign in to comment.