diff --git a/CHANGES.rst b/CHANGES.rst index 66a72174..a86e5811 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,11 @@ All notable changes to this project will be documented in this file. The format is based on `Keep a Changelog `_, and this project adheres to `Semantic Versioning `_. +Added +***** + +- OIDC `prompt=create` support. :issue:`185` :pr:`164` + Fixed ***** @@ -13,6 +18,7 @@ Fixed Added ***** + - ``THEME`` can be a relative path [0.0.39] - 2023-12-15 diff --git a/canaille/oidc/endpoints.py b/canaille/oidc/endpoints.py index 2796cac8..e09dc1a1 100644 --- a/canaille/oidc/endpoints.py +++ b/canaille/oidc/endpoints.py @@ -34,6 +34,7 @@ from .oauth import require_oauth from .oauth import RevocationEndpoint from .utils import SCOPE_DETAILS +from .well_known import openid_configuration bp = Blueprint("endpoints", __name__, url_prefix="/oauth") @@ -54,6 +55,23 @@ def authorize(): if not client: abort(400, "Invalid client.") + # https://openid.net/specs/openid-connect-prompt-create-1_0.html#name-authorization-request + # If the OpenID Provider receives a prompt value that it does + # not support (not declared in the prompt_values_supported + # metadata field) the OP SHOULD respond with an HTTP 400 (Bad + # Request) status code and an error value of invalid_request. + # It is RECOMMENDED that the OP return an error_description + # value identifying the invalid parameter value. + if ( + request.args.get("prompt") + and request.args["prompt"] + not in openid_configuration()["prompt_values_supported"] + ): + return { + "error": "invalid_request", + "error_description": f"prompt '{request.args['prompt'] }' value is not supported", + }, 400 + user = current_user() requested_scopes = request.args.get("scope", "").split(" ") allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ") @@ -65,6 +83,10 @@ def authorize(): return jsonify({"error": "login_required"}) session["redirect-after-login"] = request.url + + if request.args.get("prompt") == "create": + return redirect(url_for("core.account.join")) + return redirect(url_for("core.auth.login")) if not user.can_use_oidc: diff --git a/canaille/oidc/well_known.py b/canaille/oidc/well_known.py index 72704341..df03fd23 100644 --- a/canaille/oidc/well_known.py +++ b/canaille/oidc/well_known.py @@ -1,4 +1,5 @@ from flask import Blueprint +from flask import current_app from flask import g from flask import jsonify from flask import request @@ -76,7 +77,8 @@ def openid_configuration(): ], "subject_types_supported": ["pairwise", "public"], "id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"], - "prompt_values_supported": ["none"], + "prompt_values_supported": ["none"] + + (["create"] if current_app.config.get("ENABLE_REGISTRATION") else []), } diff --git a/doc/specifications.rst b/doc/specifications.rst index 8b5c6851..c61e5a61 100644 --- a/doc/specifications.rst +++ b/doc/specifications.rst @@ -40,7 +40,7 @@ OpenID Connect - ❌ `OpenID Connect Back Channel Logout `_ - ❌ `OpenID Connect Back Channel Authentication Flow `_ - ❌ `OpenID Connect Core Error Code unmet_authentication_requirements `_ -- ❌ `Initiating User Registration via OpenID Connect 1.0 `_ +- ✅ `Initiating User Registration via OpenID Connect 1.0 `_ Comparison with other providers =============================== diff --git a/tests/core/test_registration.py b/tests/core/test_registration.py index c7f39305..bd5080d3 100644 --- a/tests/core/test_registration.py +++ b/tests/core/test_registration.py @@ -71,6 +71,10 @@ def test_registration_with_email_validation(testclient, backend, smtpd): res.form["family_name"] = "newuser" res = res.form.submit() + assert res.flashes == [ + ("success", "Your account has been created successfully."), + ] + user = models.User.get() assert user user.delete() diff --git a/tests/oidc/conftest.py b/tests/oidc/conftest.py index ed5d3fc8..969a4fbb 100644 --- a/tests/oidc/conftest.py +++ b/tests/oidc/conftest.py @@ -41,7 +41,7 @@ def configuration(configuration, keypair): @pytest.fixture -def client(testclient, other_client, backend): +def client(testclient, trusted_client, backend): c = models.Client( client_id=gen_salt(24), client_name="Some client", @@ -69,7 +69,7 @@ def client(testclient, other_client, backend): token_endpoint_auth_method="client_secret_basic", post_logout_redirect_uris=["https://mydomain.tld/disconnected"], ) - c.audience = [c, other_client] + c.audience = [c, trusted_client] c.save() yield c @@ -77,7 +77,7 @@ def client(testclient, other_client, backend): @pytest.fixture -def other_client(testclient, backend): +def trusted_client(testclient, backend): c = models.Client( client_id=gen_salt(24), client_name="Some other client", @@ -104,6 +104,7 @@ def other_client(testclient, backend): jwks_uri="https://myotherdomain.tld/jwk", token_endpoint_auth_method="client_secret_basic", post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"], + preconsent=True, ) c.audience = [c] c.save() diff --git a/tests/oidc/test_authorization_code_flow.py b/tests/oidc/test_authorization_code_flow.py index 8af21eb5..dc2a5cd3 100644 --- a/tests/oidc/test_authorization_code_flow.py +++ b/tests/oidc/test_authorization_code_flow.py @@ -13,7 +13,7 @@ def test_authorization_code_flow( - testclient, logged_user, client, keypair, other_client + testclient, logged_user, client, keypair, trusted_client ): assert not models.Consent.query() @@ -81,13 +81,13 @@ def test_authorization_code_flow( claims = jwt.decode(access_token, keypair[1]) assert claims["sub"] == logged_user.user_name assert claims["name"] == logged_user.formatted_name - assert claims["aud"] == [client.client_id, other_client.client_id] + assert claims["aud"] == [client.client_id, trusted_client.client_id] id_token = res.json["id_token"] claims = jwt.decode(id_token, keypair[1]) assert claims["sub"] == logged_user.user_name assert claims["name"] == logged_user.formatted_name - assert claims["aud"] == [client.client_id, other_client.client_id] + assert claims["aud"] == [client.client_id, trusted_client.client_id] res = testclient.get( "/oauth/userinfo", @@ -114,7 +114,7 @@ def test_invalid_client(testclient, logged_user, keypair): def test_authorization_code_flow_with_redirect_uri( - testclient, logged_user, client, keypair, other_client + testclient, logged_user, client, keypair, trusted_client ): assert not models.Consent.query() @@ -161,7 +161,7 @@ def test_authorization_code_flow_with_redirect_uri( def test_authorization_code_flow_preconsented( - testclient, logged_user, client, keypair, other_client + testclient, logged_user, client, keypair, trusted_client ): assert not models.Consent.query() @@ -209,7 +209,7 @@ def test_authorization_code_flow_preconsented( claims = jwt.decode(id_token, keypair[1]) assert logged_user.user_name == claims["sub"] assert logged_user.formatted_name == claims["name"] - assert [client.client_id, other_client.client_id] == claims["aud"] + assert [client.client_id, trusted_client.client_id] == claims["aud"] res = testclient.get( "/oauth/userinfo", @@ -584,7 +584,7 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc def test_authorization_code_flow_but_user_cannot_use_oidc( - testclient, user, client, keypair, other_client + testclient, user, client, keypair, trusted_client ): testclient.app.config["ACL"]["DEFAULT"]["PERMISSIONS"] = [] user.reload() @@ -645,16 +645,17 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client): def test_authorization_code_request_scope_too_large( - testclient, logged_user, keypair, other_client + testclient, logged_user, keypair, client ): assert not models.Consent.query() - assert "email" not in other_client.scope + client.scope = ["openid", "profile", "groups"] + client.save() res = testclient.get( "/oauth/authorize", params=dict( response_type="code", - client_id=other_client.client_id, + client_id=client.client_id, scope="openid profile email", nonce="somenonce", ), @@ -671,7 +672,7 @@ def test_authorization_code_request_scope_too_large( "profile", } - consents = models.Consent.query(client=other_client, subject=logged_user) + consents = models.Consent.query(client=client, subject=logged_user) assert set(consents[0].scope) == { "openid", "profile", @@ -683,15 +684,15 @@ def test_authorization_code_request_scope_too_large( grant_type="authorization_code", code=code, scope="openid profile email groups address phone", - redirect_uri=other_client.redirect_uris[0], + redirect_uri=client.redirect_uris[0], ), - headers={"Authorization": f"Basic {client_credentials(other_client)}"}, + headers={"Authorization": f"Basic {client_credentials(client)}"}, status=200, ) access_token = res.json["access_token"] token = models.Token.get(access_token=access_token) - assert token.client == other_client + assert token.client == client assert token.subject == logged_user assert set(token.scope) == { "openid", diff --git a/tests/oidc/test_authorization_prompt.py b/tests/oidc/test_authorization_prompt.py index ac337843..b0f53652 100644 --- a/tests/oidc/test_authorization_prompt.py +++ b/tests/oidc/test_authorization_prompt.py @@ -2,11 +2,14 @@ Tests the behavior of Canaille depending on the OIDC 'prompt' parameter. https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint """ +import datetime import uuid from urllib.parse import parse_qs from urllib.parse import urlsplit from canaille.app import models +from canaille.core.account import RegistrationPayload +from flask import url_for def test_prompt_none(testclient, logged_user, client): @@ -98,3 +101,125 @@ def test_prompt_no_consent(testclient, logged_user, client): status=200, ) assert "consent_required" == res.json.get("error") + + +def test_prompt_create_logged(testclient, logged_user, client): + """ + If prompt=create and user is already logged in, + then go straight to the consent page. + """ + testclient.app.config["ENABLE_REGISTRATION"] = True + + consent = models.Consent( + consent_id=str(uuid.uuid4()), + client=client, + subject=logged_user, + scope=["openid", "profile"], + ) + consent.save() + + res = testclient.get( + "/oauth/authorize", + params=dict( + response_type="code", + client_id=client.client_id, + scope="openid profile", + nonce="somenonce", + prompt="create", + ), + status=302, + ) + assert res.location.startswith(client.redirect_uris[0]) + + consent.delete() + + +def test_prompt_create_registration_disabled(testclient, trusted_client, smtpd): + """ + If prompt=create but Canaille registration is disabled, + an error response should be returned. + + If the OpenID Provider receives a prompt value that it does + not support (not declared in the prompt_values_supported + metadata field) the OP SHOULD respond with an HTTP 400 (Bad + Request) status code and an error value of invalid_request. + It is RECOMMENDED that the OP return an error_description + value identifying the invalid parameter value. + """ + res = testclient.get( + "/oauth/authorize", + params=dict( + response_type="code", + client_id=trusted_client.client_id, + scope="openid profile", + nonce="somenonce", + prompt="create", + ), + status=400, + ) + assert res.json == { + "error": "invalid_request", + "error_description": "prompt 'create' value is not supported", + } + + +def test_prompt_create_not_logged(testclient, trusted_client, smtpd): + """ + If prompt=create and user is not logged in, + then display the registration form. + Check that the user is correctly redirected to + the client page after the registration process. + """ + testclient.app.config["ENABLE_REGISTRATION"] = True + + res = testclient.get( + "/oauth/authorize", + params=dict( + response_type="code", + client_id=trusted_client.client_id, + scope="openid profile", + nonce="somenonce", + prompt="create", + ), + ) + + # Display the registration form + res = res.follow() + res.form["email"] = "foo@bar.com" + res = res.form.submit() + + # Checks the registration mail is sent + assert len(smtpd.messages) == 1 + + # Simulate a click on the validation link in the mail + payload = RegistrationPayload( + creation_date_isoformat=datetime.datetime.now( + datetime.timezone.utc + ).isoformat(), + user_name="", + user_name_editable=True, + email="foo@bar.com", + groups=[], + ) + registration_url = url_for( + "core.account.registration", + data=payload.b64(), + hash=payload.build_hash(), + _external=True, + ) + + # Fill the user creation form + res = testclient.get(registration_url) + res.form["user_name"] = "newuser" + res.form["password1"] = "password" + res.form["password2"] = "password" + res.form["family_name"] = "newuser" + res = res.form.submit() + + assert res.flashes == [ + ("success", "Your account has been created successfully."), + ] + + # Return to the client + res = res.follow() + assert res.location.startswith(trusted_client.redirect_uris[0]) diff --git a/tests/oidc/test_client_admin.py b/tests/oidc/test_client_admin.py index 6b42ff42..cfcef631 100644 --- a/tests/oidc/test_client_admin.py +++ b/tests/oidc/test_client_admin.py @@ -21,7 +21,7 @@ def test_client_list(testclient, client, logged_admin): res.mustcontain(client.client_name) -def test_client_list_pagination(testclient, logged_admin, client, other_client): +def test_client_list_pagination(testclient, logged_admin, client, trusted_client): res = testclient.get("/admin/client") res.mustcontain("2 items") clients = [] @@ -67,18 +67,18 @@ def test_client_list_bad_pages(testclient, logged_admin): ) -def test_client_list_search(testclient, logged_admin, client, other_client): +def test_client_list_search(testclient, logged_admin, client, trusted_client): res = testclient.get("/admin/client") res.mustcontain("2 items") res.mustcontain(client.client_name) - res.mustcontain(other_client.client_name) + res.mustcontain(trusted_client.client_name) form = res.forms["search"] form["query"] = "other" res = form.submit() res.mustcontain("1 item") - res.mustcontain(other_client.client_name) + res.mustcontain(trusted_client.client_name) res.mustcontain(no=client.client_name) @@ -144,7 +144,7 @@ def test_add_missing_fields(testclient, logged_admin): ) in res.flashes -def test_client_edit(testclient, client, logged_admin, other_client): +def test_client_edit(testclient, client, logged_admin, trusted_client): res = testclient.get("/admin/client/edit/" + client.client_id) data = { "client_name": "foobar", @@ -162,7 +162,7 @@ def test_client_edit(testclient, client, logged_admin, other_client): "software_version": "1", "jwk": "jwk", "jwks_uri": "https://foo.bar/jwks.json", - "audience": [client.id, other_client.id], + "audience": [client.id, trusted_client.id], "preconsent": True, "post_logout_redirect_uris-0": "https://foo.bar/disconnected", } @@ -196,12 +196,12 @@ def test_client_edit(testclient, client, logged_admin, other_client): assert client.software_version == "1" assert client.jwk == "jwk" assert client.jwks_uri == "https://foo.bar/jwks.json" - assert client.audience == [client, other_client] + assert client.audience == [client, trusted_client] assert not client.preconsent assert client.post_logout_redirect_uris == ["https://foo.bar/disconnected"] -def test_client_edit_missing_fields(testclient, client, logged_admin, other_client): +def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_client): res = testclient.get("/admin/client/edit/" + client.client_id) res.forms["clientaddform"]["client_name"] = "" res = res.forms["clientaddform"].submit(name="action", value="edit") @@ -255,7 +255,7 @@ def test_client_delete_invalid_client(testclient, logged_admin, client): ) -def test_client_edit_preauth(testclient, client, logged_admin, other_client): +def test_client_edit_preauth(testclient, client, logged_admin, trusted_client): assert not client.preconsent res = testclient.get("/admin/client/edit/" + client.client_id) @@ -275,7 +275,7 @@ def test_client_edit_preauth(testclient, client, logged_admin, other_client): assert not client.preconsent -def test_client_edit_invalid_uri(testclient, client, logged_admin, other_client): +def test_client_edit_invalid_uri(testclient, client, logged_admin, trusted_client): res = testclient.get("/admin/client/edit/" + client.client_id) res.forms["clientaddform"]["client_uri"] = "invalid" res = res.forms["clientaddform"].submit(status=200, name="action", value="edit") diff --git a/tests/oidc/test_hybrid_flow.py b/tests/oidc/test_hybrid_flow.py index e8bffc1a..911f7339 100644 --- a/tests/oidc/test_hybrid_flow.py +++ b/tests/oidc/test_hybrid_flow.py @@ -46,7 +46,7 @@ def test_oauth_hybrid(testclient, backend, user, client): assert res.json["name"] == "John (johnny) Doe" -def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_client): +def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, trusted_client): res = testclient.get( "/oauth/authorize", params=dict( @@ -75,7 +75,7 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_cl claims = jwt.decode(id_token, keypair[1]) assert logged_user.user_name == claims["sub"] assert logged_user.formatted_name == claims["name"] - assert [client.client_id, other_client.client_id] == claims["aud"] + assert [client.client_id, trusted_client.client_id] == claims["aud"] res = testclient.get( "/oauth/userinfo", diff --git a/tests/oidc/test_implicit_flow.py b/tests/oidc/test_implicit_flow.py index 4cc3d945..567c5c10 100644 --- a/tests/oidc/test_implicit_flow.py +++ b/tests/oidc/test_implicit_flow.py @@ -50,7 +50,7 @@ def test_oauth_implicit(testclient, user, client): client.save() -def test_oidc_implicit(testclient, keypair, user, client, other_client): +def test_oidc_implicit(testclient, keypair, user, client, trusted_client): client.grant_types = ["token id_token"] client.token_endpoint_auth_method = "none" @@ -88,7 +88,7 @@ def test_oidc_implicit(testclient, keypair, user, client, other_client): claims = jwt.decode(id_token, keypair[1]) assert user.user_name == claims["sub"] assert user.formatted_name == claims["name"] - assert [client.client_id, other_client.client_id] == claims["aud"] + assert [client.client_id, trusted_client.client_id] == claims["aud"] res = testclient.get( "/oauth/userinfo", @@ -104,7 +104,7 @@ def test_oidc_implicit(testclient, keypair, user, client, other_client): def test_oidc_implicit_with_group( - testclient, keypair, user, client, foo_group, other_client + testclient, keypair, user, client, foo_group, trusted_client ): client.grant_types = ["token id_token"] client.token_endpoint_auth_method = "none" @@ -143,7 +143,7 @@ def test_oidc_implicit_with_group( claims = jwt.decode(id_token, keypair[1]) assert user.user_name == claims["sub"] assert user.formatted_name == claims["name"] - assert [client.client_id, other_client.client_id] == claims["aud"] + assert [client.client_id, trusted_client.client_id] == claims["aud"] assert ["foo"] == claims["groups"] res = testclient.get( diff --git a/tests/oidc/test_token_introspection.py b/tests/oidc/test_token_introspection.py index adf149ab..c3334a6b 100644 --- a/tests/oidc/test_token_introspection.py +++ b/tests/oidc/test_token_introspection.py @@ -58,7 +58,7 @@ def test_token_invalid(testclient, client): assert {"active": False} == res.json -def test_full_flow(testclient, logged_user, client, user, other_client): +def test_full_flow(testclient, logged_user, client, user, trusted_client): res = testclient.get( "/oauth/authorize", params=dict( @@ -103,7 +103,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client): headers={"Authorization": f"Basic {client_credentials(client)}"}, status=200, ) - assert set(res.json["aud"]) == {client.client_id, other_client.client_id} + assert set(res.json["aud"]) == {client.client_id, trusted_client.client_id} assert res.json["active"] assert res.json["client_id"] == client.client_id assert res.json["token_type"] == token.type diff --git a/tests/oidc/test_well_known.py b/tests/oidc/test_well_known.py index 6ad2c2e2..cd7fea68 100644 --- a/tests/oidc/test_well_known.py +++ b/tests/oidc/test_well_known.py @@ -100,3 +100,9 @@ def test_openid_configuration(testclient): "userinfo_endpoint": "http://canaille.test/oauth/userinfo", "prompt_values_supported": ["none"], } + + +def test_openid_configuration_prompt_value_create(testclient): + testclient.app.config["ENABLE_REGISTRATION"] = True + res = testclient.get("/.well-known/openid-configuration", status=200).json + assert "create" in res["prompt_values_supported"]