Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support paths in BASE during routing #451

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ config:
sub_hash_salt: randomSALTvalue

provider:
issuer: <base_url>
client_registration_supported: Yes
response_types_supported: ["code", "id_token token"]
subject_types_supported: ["pairwise"]
Expand Down
8 changes: 7 additions & 1 deletion src/satosa/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"""

from ..attribute_mapping import AttributeMapper
from ..util import join_paths

from urllib.parse import urlparse


class BackendModule(object):
Expand All @@ -29,8 +32,11 @@ def __init__(self, auth_callback_func, internal_attributes, base_url, name):
self.auth_callback_func = auth_callback_func
self.internal_attributes = internal_attributes
self.converter = AttributeMapper(internal_attributes)
self.base_url = base_url
self.base_url = base_url.rstrip("/") if base_url else ""
self.base_path = urlparse(self.base_url).path.lstrip("/")
self.name = name
self.endpoint_baseurl = join_paths(self.base_url, self.name)
self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/")

def start_auth(self, context, internal_request):
"""
Expand Down
9 changes: 7 additions & 2 deletions src/satosa/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid

from saml2.s_utils import UnknownSystemEntity
from urllib.parse import urlparse

from satosa import util
from satosa.response import BadRequest
Expand Down Expand Up @@ -52,6 +53,8 @@ def __init__(self, config):
"""
self.config = config

base_path = urlparse(self.config["BASE"]).path.lstrip("/")

logger.info("Loading backend modules...")
backends = load_backends(self.config, self._auth_resp_callback_func,
self.config["INTERNAL_ATTRIBUTES"])
Expand All @@ -77,8 +80,10 @@ def __init__(self, config):
self.config["BASE"]))
self._link_micro_services(self.response_micro_services, self._auth_resp_finish)

self.module_router = ModuleRouter(frontends, backends,
self.request_micro_services + self.response_micro_services)
self.module_router = ModuleRouter(frontends,
backends,
self.request_micro_services + self.response_micro_services,
base_path)

def _link_micro_services(self, micro_services, finisher):
if not micro_services:
Expand Down
4 changes: 0 additions & 4 deletions src/satosa/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ def path(self, p):
raise ValueError("path can't start with '/'")
self._path = p

def target_entity_id_from_path(self):
target_entity_id = self.path.split("/")[1]
return target_entity_id

def decorate(self, key, value):
"""
Add information to the context
Expand Down
11 changes: 10 additions & 1 deletion src/satosa/frontends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
Holds a base class for frontend modules used in the SATOSA proxy.
"""
from ..attribute_mapping import AttributeMapper
from ..util import join_paths

from urllib.parse import urlparse


class FrontendModule(object):
Expand All @@ -14,17 +17,23 @@ def __init__(self, auth_req_callback_func, internal_attributes, base_url, name):
:type auth_req_callback_func:
(satosa.context.Context, satosa.internal.InternalData) -> satosa.response.Response
:type internal_attributes: dict[str, dict[str, str | list[str]]]
:type base_url: str
:type name: str

:param auth_req_callback_func: Callback should be called by the module after the
authorization response has been processed.
:param internal_attributes: attribute mapping
:param base_url: base url of the proxy
:param name: name of the plugin
"""
self.auth_req_callback_func = auth_req_callback_func
self.internal_attributes = internal_attributes
self.converter = AttributeMapper(internal_attributes)
self.base_url = base_url
self.base_url = base_url or ""
self.base_path = urlparse(self.base_url).path.lstrip("/")
self.name = name
self.endpoint_baseurl = join_paths(self.base_url, self.name)
self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/")

def handle_authn_response(self, context, internal_resp):
"""
Expand Down
50 changes: 36 additions & 14 deletions src/satosa/frontends/openid_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ..response import BadRequest, Created
from ..response import SeeOther, Response
from ..response import Unauthorized
from ..util import rndstr
from ..util import join_paths, rndstr

import satosa.logging_util as lu
from satosa.internal import InternalData
Expand All @@ -62,7 +62,8 @@ def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url,

self.config = conf
provider_config = self.config["provider"]
provider_config["issuer"] = base_url
if not provider_config.get("issuer"):
provider_config["issuer"] = base_url

self.signing_key = RSAKey(
key=rsa_load(self.config["signing_key_path"]),
Expand Down Expand Up @@ -97,7 +98,6 @@ def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url,
else:
cdb = {}

self.endpoint_baseurl = "{}/{}".format(self.base_url, self.name)
self.provider = _create_provider(
provider_config,
self.endpoint_baseurl,
Expand Down Expand Up @@ -173,6 +173,19 @@ def register_endpoints(self, backend_names):
:rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
:raise ValueError: if more than one backend is configured
"""
# See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
#
# We skip the scheme + host + port of the issuer URL, because we can only map the
# path for the provider config endpoint. We are safe to use urlparse().path here,
# because for issuer OIDC allows only https URLs without query and fragment parts.
issuer = self.provider.configuration_information["issuer"]
autoconf_path = ".well-known/openid-configuration"
provider_config = (
"^{}$".format(join_paths(urlparse(issuer).path.lstrip("/"), autoconf_path)),
self.provider_config,
)
jwks_uri = ("^{}/jwks$".format(self.endpoint_basepath), self.jwks)

backend_name = None
if len(backend_names) != 1:
# only supports one backend since there currently is no way to publish multiple authorization endpoints
Expand All @@ -189,40 +202,49 @@ def register_endpoints(self, backend_names):
else:
backend_name = backend_names[0]

provider_config = ("^.well-known/openid-configuration$", self.provider_config)
jwks_uri = ("^{}/jwks$".format(self.name), self.jwks)

if backend_name:
# if there is only one backend, include its name in the path so the default routing can work
auth_endpoint = "{}/{}/{}/{}".format(self.base_url, backend_name, self.name, AuthorizationEndpoint.url)
auth_endpoint = join_paths(
self.base_url,
backend_name,
self.name,
AuthorizationEndpoint.url,
)
self.provider.configuration_information["authorization_endpoint"] = auth_endpoint
auth_path = urlparse(auth_endpoint).path.lstrip("/")
else:
auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url)
auth_path = join_paths(self.endpoint_basepath, AuthorizationRequest.url)

authentication = ("^{}$".format(auth_path), self.handle_authn_request)
url_map = [provider_config, jwks_uri, authentication]

if any("code" in v for v in self.provider.configuration_information["response_types_supported"]):
self.provider.configuration_information["token_endpoint"] = "{}/{}".format(
self.endpoint_baseurl, TokenEndpoint.url
self.provider.configuration_information["token_endpoint"] = join_paths(
self.endpoint_baseurl,
TokenEndpoint.url,
)
token_endpoint = (
"^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint
"^{}".format(join_paths(self.endpoint_basepath, TokenEndpoint.url)),
self.token_endpoint,
)
url_map.append(token_endpoint)

self.provider.configuration_information["userinfo_endpoint"] = (
"{}/{}".format(self.endpoint_baseurl, UserinfoEndpoint.url)
join_paths(self.endpoint_baseurl, UserinfoEndpoint.url)
)
userinfo_endpoint = (
"^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint
"^{}".format(
join_paths(self.endpoint_basepath, UserinfoEndpoint.url)
),
self.userinfo_endpoint,
)
url_map.append(userinfo_endpoint)

if "registration_endpoint" in self.provider.configuration_information:
client_registration = (
"^{}/{}".format(self.name, RegistrationEndpoint.url),
"^{}".format(
join_paths(self.endpoint_basepath, RegistrationEndpoint.url)
),
self.client_registration,
)
url_map.append(client_registration)
Expand Down
3 changes: 2 additions & 1 deletion src/satosa/frontends/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import satosa.logging_util as lu
from satosa.frontends.base import FrontendModule
from satosa.response import Response
from satosa.util import join_paths


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,7 +44,7 @@ def register_endpoints(self, backend_names):
:rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
:raise ValueError: if more than one backend is configured
"""
url_map = [("^{}".format(self.name), self.ping_endpoint)]
url_map = [("^{}".format(join_paths(self.endpoint_basepath, self.name)), self.ping_endpoint)]

return url_map

Expand Down
49 changes: 34 additions & 15 deletions src/satosa/frontends/saml2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..response import Response
from ..response import ServiceError
from ..saml_util import make_saml_response
from ..util import join_paths
from satosa.exception import SATOSAError
from satosa.exception import SATOSABadRequestError
from satosa.exception import SATOSAMissingStateError
Expand Down Expand Up @@ -119,7 +120,7 @@ def register_endpoints(self, backend_names):

if self.enable_metadata_reload():
url_map.append(
("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata))
("^%s/%s$" % (self.endpoint_basepath, "reload-metadata"), self._reload_metadata))

self.idp_config = self._build_idp_config_endpoints(
self.config[self.KEY_IDP_CONFIG], backend_names)
Expand Down Expand Up @@ -553,15 +554,20 @@ def _register_endpoints(self, providers):
"""
url_map = []

backend_providers = "(" + "|".join(providers) + ")"
for endp_category in self.endpoints:
for binding, endp in self.endpoints[endp_category].items():
valid_providers = ""
for provider in providers:
valid_providers = "{}|^{}".format(valid_providers, provider)
valid_providers = valid_providers.lstrip("|")
parsed_endp = urlparse(endp)
url_map.append(("(%s)/%s$" % (valid_providers, parsed_endp.path),
functools.partial(self.handle_authn_request, binding_in=binding)))
endp_path = urlparse(endp).path
url_map.append(
(
"^{}$".format(
join_paths(self.base_path, backend_providers, endp_path)
),
functools.partial(
self.handle_authn_request, binding_in=binding
),
)
)

if self.expose_entityid_endpoint():
logger.debug("Exposing frontend entity endpoint = {}".format(self.idp.config.entityid))
Expand Down Expand Up @@ -717,11 +723,18 @@ def _load_idp_dynamic_endpoints(self, context):
:param context:
:return: An idp server
"""
target_entity_id = context.target_entity_id_from_path()
target_entity_id = self._target_entity_id_from_path(context.path)
idp_conf_file = self._load_endpoints_to_config(context.target_backend, target_entity_id)
idp_config = IdPConfig().load(idp_conf_file)
return Server(config=idp_config)

def _target_entity_id_from_path(self, request_path):
path = request_path.lstrip("/")
base_path = urlparse(self.base_url).path.lstrip("/")
if base_path and path.startswith(base_path):
path = path[len(base_path):].lstrip("/")
return path.split("/")[1]

def _load_idp_dynamic_entity_id(self, state):
"""
Loads an idp server with the entity id saved in state
Expand All @@ -747,7 +760,7 @@ def handle_authn_request(self, context, binding_in):
:type binding_in: str
:rtype: satosa.response.Response
"""
target_entity_id = context.target_entity_id_from_path()
target_entity_id = self._target_entity_id_from_path(context.path)
target_entity_id = urlsafe_b64decode(target_entity_id).decode()
context.decorate(Context.KEY_TARGET_ENTITYID, target_entity_id)

Expand All @@ -765,7 +778,7 @@ def _create_state_data(self, context, resp_args, relay_state):
:rtype: dict[str, dict[str, str] | str]
"""
state = super()._create_state_data(context, resp_args, relay_state)
state["target_entity_id"] = context.target_entity_id_from_path()
state["target_entity_id"] = self._target_entity_id_from_path(context.path)
return state

def handle_backend_error(self, exception):
Expand Down Expand Up @@ -800,14 +813,20 @@ def _register_endpoints(self, providers):
"""
url_map = []

backend_providers = "(" + "|".join(providers) + ")"
for endp_category in self.endpoints:
for binding, endp in self.endpoints[endp_category].items():
valid_providers = "|^".join(providers)
parsed_endp = urlparse(endp)
endp_path = urlparse(endp).path
url_map.append(
(
r"(^{})/\S+/{}".format(valid_providers, parsed_endp.path),
functools.partial(self.handle_authn_request, binding_in=binding)
"^{}$".format(
join_paths(
self.base_path, backend_providers, "\S+", endp_path
)
),
functools.partial(
self.handle_authn_request, binding_in=binding
),
)
)

Expand Down
10 changes: 9 additions & 1 deletion src/satosa/micro_services/account_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..exception import SATOSAAuthenticationError
from ..micro_services.base import ResponseMicroService
from ..response import Redirect
from ..util import join_paths

import satosa.logging_util as lu
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -161,4 +162,11 @@ def register_endpoints(self):

:return: A list of endpoints bound to a function
"""
return [("^account_linking%s$" % self.endpoint, self._handle_al_response)]
return [
(
"^{}$".format(
join_paths(self.base_path, "account_linking", self.endpoint)
),
self._handle_al_response,
)
]
6 changes: 6 additions & 0 deletions src/satosa/micro_services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
Micro service for SATOSA
"""
import logging
from urllib.parse import urlparse

from ..util import join_paths

logger = logging.getLogger(__name__)

Expand All @@ -14,6 +17,9 @@ class MicroService(object):
def __init__(self, name, base_url, **kwargs):
self.name = name
self.base_url = base_url
self.base_path = urlparse(base_url).path.lstrip("/")
self.endpoint_baseurl = join_paths(self.base_url, self.name)
self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/")
self.next = None

def process(self, context, data):
Expand Down
Loading