diff --git a/src/swamid_plugins/attributes/__init__.py b/src/swamid_plugins/attributes/__init__.py index cd82200..0a8bc83 100644 --- a/src/swamid_plugins/attributes/__init__.py +++ b/src/swamid_plugins/attributes/__init__.py @@ -1,3 +1,4 @@ +from collections import defaultdict from logging import getLogger as get_logger from jinja2 import Environment @@ -22,15 +23,49 @@ def __init__(self, config, internal_attributes, *args, **kwargs): self.state_result = config["state_result"] self.attributes_strategy = {'all': all, 'any': any}[config.get("attributes_strategy", "all")] self.attribute_values_strategy = {'all': all, 'any': any}[config.get("attribute_values_strategy", "any")] - self.allowed_attributes = config.get("allowed_attributes") self.user_id_attribute = config["user_id_attribute"] + required_attributes_per_service = {} + for item in config.get("required_attributes_per_service", []): + services = item.get("services") + allowed_attributes = item.get("allowed_attributes") + if services is None or allowed_attributes is None: + logger.warning["services or allowed_attributes missing"] + continue + for service in services: + required_attributes = required_attributes_per_service.get( + service, defaultdict(set) + ) + for k, v in allowed_attributes.items(): + required_attributes[k].update(v) + required_attributes_per_service[service] = required_attributes + + # the attributes for the "default" service are added for all services + self.default_attributes = dict( + required_attributes_per_service.get("default", {}) + ) + for k, v in self.default_attributes.items(): + for required_attributes in required_attributes_per_service.values(): + required_attributes[k].update(v) + + self.required_attributes_per_service = { + k: dict(v) for k, v in required_attributes_per_service.items() + if k != "default" + } + templates_dir_path = config["templates_dir_path"] self.tpl_env = Environment(loader=FileSystemLoader(templates_dir_path), autoescape=select_autoescape()) def process(self, context, internal_data): try: - return self._process(context, internal_data) + allowed_attributes = self.required_attributes_per_service.get( + internal_data.requester, self.default_attributes + ) + if not allowed_attributes: + raise AttributeCheckerError( + "No allowed attributes configured for %s", internal_data.requester + ) + return self._process(context, internal_data, allowed_attributes) except AttributeCheckerError as e: context.state[self.state_result] = False context.state.delete = True @@ -53,20 +88,20 @@ def process(self, context, internal_data): ) return UnauthorizedResponse(content) - def _process(self, context, internal_data): + def _process(self, context, internal_data, allowed_attributes): context.state[self.state_result] = False is_authorized = self.attributes_strategy( self.attribute_values_strategy( value in values for value in internal_data.attributes.get(attr, []) ) - for attr, values in self.allowed_attributes.items() + for attr, values in allowed_attributes.items() ) if not is_authorized: error_context = { 'message': 'User is not authorized to access this service.', - 'allowed_attributes': list(self.allowed_attributes.keys()), + 'allowed_attributes': list(allowed_attributes.keys()), 'attributes': internal_data.attributes.keys(), 'attributes_strategy': self.attributes_strategy.__name__, 'attribute_values_strategy': self.attribute_values_strategy.__name__,