diff --git a/CHANGES.rst b/CHANGES.rst
index 66a72174..a86e5811 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -3,6 +3,11 @@ All notable changes to this project will be documented in this file.
The format is based on `Keep a Changelog `_,
and this project adheres to `Semantic Versioning `_.
+Added
+*****
+
+- OIDC `prompt=create` support. :issue:`185` :pr:`164`
+
Fixed
*****
@@ -13,6 +18,7 @@ Fixed
Added
*****
+
- ``THEME`` can be a relative path
[0.0.39] - 2023-12-15
diff --git a/canaille/oidc/endpoints.py b/canaille/oidc/endpoints.py
index 2796cac8..e09dc1a1 100644
--- a/canaille/oidc/endpoints.py
+++ b/canaille/oidc/endpoints.py
@@ -34,6 +34,7 @@
from .oauth import require_oauth
from .oauth import RevocationEndpoint
from .utils import SCOPE_DETAILS
+from .well_known import openid_configuration
bp = Blueprint("endpoints", __name__, url_prefix="/oauth")
@@ -54,6 +55,23 @@ def authorize():
if not client:
abort(400, "Invalid client.")
+ # https://openid.net/specs/openid-connect-prompt-create-1_0.html#name-authorization-request
+ # If the OpenID Provider receives a prompt value that it does
+ # not support (not declared in the prompt_values_supported
+ # metadata field) the OP SHOULD respond with an HTTP 400 (Bad
+ # Request) status code and an error value of invalid_request.
+ # It is RECOMMENDED that the OP return an error_description
+ # value identifying the invalid parameter value.
+ if (
+ request.args.get("prompt")
+ and request.args["prompt"]
+ not in openid_configuration()["prompt_values_supported"]
+ ):
+ return {
+ "error": "invalid_request",
+ "error_description": f"prompt '{request.args['prompt'] }' value is not supported",
+ }, 400
+
user = current_user()
requested_scopes = request.args.get("scope", "").split(" ")
allowed_scopes = client.get_allowed_scope(requested_scopes).split(" ")
@@ -65,6 +83,10 @@ def authorize():
return jsonify({"error": "login_required"})
session["redirect-after-login"] = request.url
+
+ if request.args.get("prompt") == "create":
+ return redirect(url_for("core.account.join"))
+
return redirect(url_for("core.auth.login"))
if not user.can_use_oidc:
diff --git a/canaille/oidc/well_known.py b/canaille/oidc/well_known.py
index 72704341..df03fd23 100644
--- a/canaille/oidc/well_known.py
+++ b/canaille/oidc/well_known.py
@@ -1,4 +1,5 @@
from flask import Blueprint
+from flask import current_app
from flask import g
from flask import jsonify
from flask import request
@@ -76,7 +77,8 @@ def openid_configuration():
],
"subject_types_supported": ["pairwise", "public"],
"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"],
- "prompt_values_supported": ["none"],
+ "prompt_values_supported": ["none"]
+ + (["create"] if current_app.config.get("ENABLE_REGISTRATION") else []),
}
diff --git a/doc/specifications.rst b/doc/specifications.rst
index 8b5c6851..c61e5a61 100644
--- a/doc/specifications.rst
+++ b/doc/specifications.rst
@@ -40,7 +40,7 @@ OpenID Connect
- ❌ `OpenID Connect Back Channel Logout `_
- ❌ `OpenID Connect Back Channel Authentication Flow `_
- ❌ `OpenID Connect Core Error Code unmet_authentication_requirements `_
-- ❌ `Initiating User Registration via OpenID Connect 1.0 `_
+- ✅ `Initiating User Registration via OpenID Connect 1.0 `_
Comparison with other providers
===============================
diff --git a/tests/core/test_registration.py b/tests/core/test_registration.py
index c7f39305..bd5080d3 100644
--- a/tests/core/test_registration.py
+++ b/tests/core/test_registration.py
@@ -71,6 +71,10 @@ def test_registration_with_email_validation(testclient, backend, smtpd):
res.form["family_name"] = "newuser"
res = res.form.submit()
+ assert res.flashes == [
+ ("success", "Your account has been created successfully."),
+ ]
+
user = models.User.get()
assert user
user.delete()
diff --git a/tests/oidc/conftest.py b/tests/oidc/conftest.py
index ed5d3fc8..969a4fbb 100644
--- a/tests/oidc/conftest.py
+++ b/tests/oidc/conftest.py
@@ -41,7 +41,7 @@ def configuration(configuration, keypair):
@pytest.fixture
-def client(testclient, other_client, backend):
+def client(testclient, trusted_client, backend):
c = models.Client(
client_id=gen_salt(24),
client_name="Some client",
@@ -69,7 +69,7 @@ def client(testclient, other_client, backend):
token_endpoint_auth_method="client_secret_basic",
post_logout_redirect_uris=["https://mydomain.tld/disconnected"],
)
- c.audience = [c, other_client]
+ c.audience = [c, trusted_client]
c.save()
yield c
@@ -77,7 +77,7 @@ def client(testclient, other_client, backend):
@pytest.fixture
-def other_client(testclient, backend):
+def trusted_client(testclient, backend):
c = models.Client(
client_id=gen_salt(24),
client_name="Some other client",
@@ -104,6 +104,7 @@ def other_client(testclient, backend):
jwks_uri="https://myotherdomain.tld/jwk",
token_endpoint_auth_method="client_secret_basic",
post_logout_redirect_uris=["https://myotherdomain.tld/disconnected"],
+ preconsent=True,
)
c.audience = [c]
c.save()
diff --git a/tests/oidc/test_authorization_code_flow.py b/tests/oidc/test_authorization_code_flow.py
index 8af21eb5..dc2a5cd3 100644
--- a/tests/oidc/test_authorization_code_flow.py
+++ b/tests/oidc/test_authorization_code_flow.py
@@ -13,7 +13,7 @@
def test_authorization_code_flow(
- testclient, logged_user, client, keypair, other_client
+ testclient, logged_user, client, keypair, trusted_client
):
assert not models.Consent.query()
@@ -81,13 +81,13 @@ def test_authorization_code_flow(
claims = jwt.decode(access_token, keypair[1])
assert claims["sub"] == logged_user.user_name
assert claims["name"] == logged_user.formatted_name
- assert claims["aud"] == [client.client_id, other_client.client_id]
+ assert claims["aud"] == [client.client_id, trusted_client.client_id]
id_token = res.json["id_token"]
claims = jwt.decode(id_token, keypair[1])
assert claims["sub"] == logged_user.user_name
assert claims["name"] == logged_user.formatted_name
- assert claims["aud"] == [client.client_id, other_client.client_id]
+ assert claims["aud"] == [client.client_id, trusted_client.client_id]
res = testclient.get(
"/oauth/userinfo",
@@ -114,7 +114,7 @@ def test_invalid_client(testclient, logged_user, keypair):
def test_authorization_code_flow_with_redirect_uri(
- testclient, logged_user, client, keypair, other_client
+ testclient, logged_user, client, keypair, trusted_client
):
assert not models.Consent.query()
@@ -161,7 +161,7 @@ def test_authorization_code_flow_with_redirect_uri(
def test_authorization_code_flow_preconsented(
- testclient, logged_user, client, keypair, other_client
+ testclient, logged_user, client, keypair, trusted_client
):
assert not models.Consent.query()
@@ -209,7 +209,7 @@ def test_authorization_code_flow_preconsented(
claims = jwt.decode(id_token, keypair[1])
assert logged_user.user_name == claims["sub"]
assert logged_user.formatted_name == claims["name"]
- assert [client.client_id, other_client.client_id] == claims["aud"]
+ assert [client.client_id, trusted_client.client_id] == claims["aud"]
res = testclient.get(
"/oauth/userinfo",
@@ -584,7 +584,7 @@ def test_authorization_code_flow_when_consent_already_given_but_for_a_smaller_sc
def test_authorization_code_flow_but_user_cannot_use_oidc(
- testclient, user, client, keypair, other_client
+ testclient, user, client, keypair, trusted_client
):
testclient.app.config["ACL"]["DEFAULT"]["PERMISSIONS"] = []
user.reload()
@@ -645,16 +645,17 @@ def test_nonce_not_required_in_oauth_requests(testclient, logged_user, client):
def test_authorization_code_request_scope_too_large(
- testclient, logged_user, keypair, other_client
+ testclient, logged_user, keypair, client
):
assert not models.Consent.query()
- assert "email" not in other_client.scope
+ client.scope = ["openid", "profile", "groups"]
+ client.save()
res = testclient.get(
"/oauth/authorize",
params=dict(
response_type="code",
- client_id=other_client.client_id,
+ client_id=client.client_id,
scope="openid profile email",
nonce="somenonce",
),
@@ -671,7 +672,7 @@ def test_authorization_code_request_scope_too_large(
"profile",
}
- consents = models.Consent.query(client=other_client, subject=logged_user)
+ consents = models.Consent.query(client=client, subject=logged_user)
assert set(consents[0].scope) == {
"openid",
"profile",
@@ -683,15 +684,15 @@ def test_authorization_code_request_scope_too_large(
grant_type="authorization_code",
code=code,
scope="openid profile email groups address phone",
- redirect_uri=other_client.redirect_uris[0],
+ redirect_uri=client.redirect_uris[0],
),
- headers={"Authorization": f"Basic {client_credentials(other_client)}"},
+ headers={"Authorization": f"Basic {client_credentials(client)}"},
status=200,
)
access_token = res.json["access_token"]
token = models.Token.get(access_token=access_token)
- assert token.client == other_client
+ assert token.client == client
assert token.subject == logged_user
assert set(token.scope) == {
"openid",
diff --git a/tests/oidc/test_authorization_prompt.py b/tests/oidc/test_authorization_prompt.py
index ac337843..b0f53652 100644
--- a/tests/oidc/test_authorization_prompt.py
+++ b/tests/oidc/test_authorization_prompt.py
@@ -2,11 +2,14 @@
Tests the behavior of Canaille depending on the OIDC 'prompt' parameter.
https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
"""
+import datetime
import uuid
from urllib.parse import parse_qs
from urllib.parse import urlsplit
from canaille.app import models
+from canaille.core.account import RegistrationPayload
+from flask import url_for
def test_prompt_none(testclient, logged_user, client):
@@ -98,3 +101,125 @@ def test_prompt_no_consent(testclient, logged_user, client):
status=200,
)
assert "consent_required" == res.json.get("error")
+
+
+def test_prompt_create_logged(testclient, logged_user, client):
+ """
+ If prompt=create and user is already logged in,
+ then go straight to the consent page.
+ """
+ testclient.app.config["ENABLE_REGISTRATION"] = True
+
+ consent = models.Consent(
+ consent_id=str(uuid.uuid4()),
+ client=client,
+ subject=logged_user,
+ scope=["openid", "profile"],
+ )
+ consent.save()
+
+ res = testclient.get(
+ "/oauth/authorize",
+ params=dict(
+ response_type="code",
+ client_id=client.client_id,
+ scope="openid profile",
+ nonce="somenonce",
+ prompt="create",
+ ),
+ status=302,
+ )
+ assert res.location.startswith(client.redirect_uris[0])
+
+ consent.delete()
+
+
+def test_prompt_create_registration_disabled(testclient, trusted_client, smtpd):
+ """
+ If prompt=create but Canaille registration is disabled,
+ an error response should be returned.
+
+ If the OpenID Provider receives a prompt value that it does
+ not support (not declared in the prompt_values_supported
+ metadata field) the OP SHOULD respond with an HTTP 400 (Bad
+ Request) status code and an error value of invalid_request.
+ It is RECOMMENDED that the OP return an error_description
+ value identifying the invalid parameter value.
+ """
+ res = testclient.get(
+ "/oauth/authorize",
+ params=dict(
+ response_type="code",
+ client_id=trusted_client.client_id,
+ scope="openid profile",
+ nonce="somenonce",
+ prompt="create",
+ ),
+ status=400,
+ )
+ assert res.json == {
+ "error": "invalid_request",
+ "error_description": "prompt 'create' value is not supported",
+ }
+
+
+def test_prompt_create_not_logged(testclient, trusted_client, smtpd):
+ """
+ If prompt=create and user is not logged in,
+ then display the registration form.
+ Check that the user is correctly redirected to
+ the client page after the registration process.
+ """
+ testclient.app.config["ENABLE_REGISTRATION"] = True
+
+ res = testclient.get(
+ "/oauth/authorize",
+ params=dict(
+ response_type="code",
+ client_id=trusted_client.client_id,
+ scope="openid profile",
+ nonce="somenonce",
+ prompt="create",
+ ),
+ )
+
+ # Display the registration form
+ res = res.follow()
+ res.form["email"] = "foo@bar.com"
+ res = res.form.submit()
+
+ # Checks the registration mail is sent
+ assert len(smtpd.messages) == 1
+
+ # Simulate a click on the validation link in the mail
+ payload = RegistrationPayload(
+ creation_date_isoformat=datetime.datetime.now(
+ datetime.timezone.utc
+ ).isoformat(),
+ user_name="",
+ user_name_editable=True,
+ email="foo@bar.com",
+ groups=[],
+ )
+ registration_url = url_for(
+ "core.account.registration",
+ data=payload.b64(),
+ hash=payload.build_hash(),
+ _external=True,
+ )
+
+ # Fill the user creation form
+ res = testclient.get(registration_url)
+ res.form["user_name"] = "newuser"
+ res.form["password1"] = "password"
+ res.form["password2"] = "password"
+ res.form["family_name"] = "newuser"
+ res = res.form.submit()
+
+ assert res.flashes == [
+ ("success", "Your account has been created successfully."),
+ ]
+
+ # Return to the client
+ res = res.follow()
+ assert res.location.startswith(trusted_client.redirect_uris[0])
diff --git a/tests/oidc/test_client_admin.py b/tests/oidc/test_client_admin.py
index 6b42ff42..cfcef631 100644
--- a/tests/oidc/test_client_admin.py
+++ b/tests/oidc/test_client_admin.py
@@ -21,7 +21,7 @@ def test_client_list(testclient, client, logged_admin):
res.mustcontain(client.client_name)
-def test_client_list_pagination(testclient, logged_admin, client, other_client):
+def test_client_list_pagination(testclient, logged_admin, client, trusted_client):
res = testclient.get("/admin/client")
res.mustcontain("2 items")
clients = []
@@ -67,18 +67,18 @@ def test_client_list_bad_pages(testclient, logged_admin):
)
-def test_client_list_search(testclient, logged_admin, client, other_client):
+def test_client_list_search(testclient, logged_admin, client, trusted_client):
res = testclient.get("/admin/client")
res.mustcontain("2 items")
res.mustcontain(client.client_name)
- res.mustcontain(other_client.client_name)
+ res.mustcontain(trusted_client.client_name)
form = res.forms["search"]
form["query"] = "other"
res = form.submit()
res.mustcontain("1 item")
- res.mustcontain(other_client.client_name)
+ res.mustcontain(trusted_client.client_name)
res.mustcontain(no=client.client_name)
@@ -144,7 +144,7 @@ def test_add_missing_fields(testclient, logged_admin):
) in res.flashes
-def test_client_edit(testclient, client, logged_admin, other_client):
+def test_client_edit(testclient, client, logged_admin, trusted_client):
res = testclient.get("/admin/client/edit/" + client.client_id)
data = {
"client_name": "foobar",
@@ -162,7 +162,7 @@ def test_client_edit(testclient, client, logged_admin, other_client):
"software_version": "1",
"jwk": "jwk",
"jwks_uri": "https://foo.bar/jwks.json",
- "audience": [client.id, other_client.id],
+ "audience": [client.id, trusted_client.id],
"preconsent": True,
"post_logout_redirect_uris-0": "https://foo.bar/disconnected",
}
@@ -196,12 +196,12 @@ def test_client_edit(testclient, client, logged_admin, other_client):
assert client.software_version == "1"
assert client.jwk == "jwk"
assert client.jwks_uri == "https://foo.bar/jwks.json"
- assert client.audience == [client, other_client]
+ assert client.audience == [client, trusted_client]
assert not client.preconsent
assert client.post_logout_redirect_uris == ["https://foo.bar/disconnected"]
-def test_client_edit_missing_fields(testclient, client, logged_admin, other_client):
+def test_client_edit_missing_fields(testclient, client, logged_admin, trusted_client):
res = testclient.get("/admin/client/edit/" + client.client_id)
res.forms["clientaddform"]["client_name"] = ""
res = res.forms["clientaddform"].submit(name="action", value="edit")
@@ -255,7 +255,7 @@ def test_client_delete_invalid_client(testclient, logged_admin, client):
)
-def test_client_edit_preauth(testclient, client, logged_admin, other_client):
+def test_client_edit_preauth(testclient, client, logged_admin, trusted_client):
assert not client.preconsent
res = testclient.get("/admin/client/edit/" + client.client_id)
@@ -275,7 +275,7 @@ def test_client_edit_preauth(testclient, client, logged_admin, other_client):
assert not client.preconsent
-def test_client_edit_invalid_uri(testclient, client, logged_admin, other_client):
+def test_client_edit_invalid_uri(testclient, client, logged_admin, trusted_client):
res = testclient.get("/admin/client/edit/" + client.client_id)
res.forms["clientaddform"]["client_uri"] = "invalid"
res = res.forms["clientaddform"].submit(status=200, name="action", value="edit")
diff --git a/tests/oidc/test_hybrid_flow.py b/tests/oidc/test_hybrid_flow.py
index e8bffc1a..911f7339 100644
--- a/tests/oidc/test_hybrid_flow.py
+++ b/tests/oidc/test_hybrid_flow.py
@@ -46,7 +46,7 @@ def test_oauth_hybrid(testclient, backend, user, client):
assert res.json["name"] == "John (johnny) Doe"
-def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_client):
+def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, trusted_client):
res = testclient.get(
"/oauth/authorize",
params=dict(
@@ -75,7 +75,7 @@ def test_oidc_hybrid(testclient, backend, logged_user, client, keypair, other_cl
claims = jwt.decode(id_token, keypair[1])
assert logged_user.user_name == claims["sub"]
assert logged_user.formatted_name == claims["name"]
- assert [client.client_id, other_client.client_id] == claims["aud"]
+ assert [client.client_id, trusted_client.client_id] == claims["aud"]
res = testclient.get(
"/oauth/userinfo",
diff --git a/tests/oidc/test_implicit_flow.py b/tests/oidc/test_implicit_flow.py
index 4cc3d945..567c5c10 100644
--- a/tests/oidc/test_implicit_flow.py
+++ b/tests/oidc/test_implicit_flow.py
@@ -50,7 +50,7 @@ def test_oauth_implicit(testclient, user, client):
client.save()
-def test_oidc_implicit(testclient, keypair, user, client, other_client):
+def test_oidc_implicit(testclient, keypair, user, client, trusted_client):
client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none"
@@ -88,7 +88,7 @@ def test_oidc_implicit(testclient, keypair, user, client, other_client):
claims = jwt.decode(id_token, keypair[1])
assert user.user_name == claims["sub"]
assert user.formatted_name == claims["name"]
- assert [client.client_id, other_client.client_id] == claims["aud"]
+ assert [client.client_id, trusted_client.client_id] == claims["aud"]
res = testclient.get(
"/oauth/userinfo",
@@ -104,7 +104,7 @@ def test_oidc_implicit(testclient, keypair, user, client, other_client):
def test_oidc_implicit_with_group(
- testclient, keypair, user, client, foo_group, other_client
+ testclient, keypair, user, client, foo_group, trusted_client
):
client.grant_types = ["token id_token"]
client.token_endpoint_auth_method = "none"
@@ -143,7 +143,7 @@ def test_oidc_implicit_with_group(
claims = jwt.decode(id_token, keypair[1])
assert user.user_name == claims["sub"]
assert user.formatted_name == claims["name"]
- assert [client.client_id, other_client.client_id] == claims["aud"]
+ assert [client.client_id, trusted_client.client_id] == claims["aud"]
assert ["foo"] == claims["groups"]
res = testclient.get(
diff --git a/tests/oidc/test_token_introspection.py b/tests/oidc/test_token_introspection.py
index adf149ab..c3334a6b 100644
--- a/tests/oidc/test_token_introspection.py
+++ b/tests/oidc/test_token_introspection.py
@@ -58,7 +58,7 @@ def test_token_invalid(testclient, client):
assert {"active": False} == res.json
-def test_full_flow(testclient, logged_user, client, user, other_client):
+def test_full_flow(testclient, logged_user, client, user, trusted_client):
res = testclient.get(
"/oauth/authorize",
params=dict(
@@ -103,7 +103,7 @@ def test_full_flow(testclient, logged_user, client, user, other_client):
headers={"Authorization": f"Basic {client_credentials(client)}"},
status=200,
)
- assert set(res.json["aud"]) == {client.client_id, other_client.client_id}
+ assert set(res.json["aud"]) == {client.client_id, trusted_client.client_id}
assert res.json["active"]
assert res.json["client_id"] == client.client_id
assert res.json["token_type"] == token.type
diff --git a/tests/oidc/test_well_known.py b/tests/oidc/test_well_known.py
index 6ad2c2e2..cd7fea68 100644
--- a/tests/oidc/test_well_known.py
+++ b/tests/oidc/test_well_known.py
@@ -100,3 +100,9 @@ def test_openid_configuration(testclient):
"userinfo_endpoint": "http://canaille.test/oauth/userinfo",
"prompt_values_supported": ["none"],
}
+
+
+def test_openid_configuration_prompt_value_create(testclient):
+ testclient.app.config["ENABLE_REGISTRATION"] = True
+ res = testclient.get("/.well-known/openid-configuration", status=200).json
+ assert "create" in res["prompt_values_supported"]