Skip to content

Commit

Permalink
Merge pull request #1 from SUNET/per-requester-allowed-attributes
Browse files Browse the repository at this point in the history
Specify allowed attributes for given service names
  • Loading branch information
c00kiemon5ter authored Feb 2, 2023
2 parents 7607161 + 7178eb5 commit 2e6e06b
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions src/swamid_plugins/attributes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from logging import getLogger as get_logger

from jinja2 import Environment
Expand All @@ -22,15 +23,49 @@ def __init__(self, config, internal_attributes, *args, **kwargs):
self.state_result = config["state_result"]
self.attributes_strategy = {'all': all, 'any': any}[config.get("attributes_strategy", "all")]
self.attribute_values_strategy = {'all': all, 'any': any}[config.get("attribute_values_strategy", "any")]
self.allowed_attributes = config.get("allowed_attributes")
self.user_id_attribute = config["user_id_attribute"]

required_attributes_per_service = {}
for item in config.get("required_attributes_per_service", []):
services = item.get("services")
allowed_attributes = item.get("allowed_attributes")
if services is None or allowed_attributes is None:
logger.warning["services or allowed_attributes missing"]
continue
for service in services:
required_attributes = required_attributes_per_service.get(
service, defaultdict(set)
)
for k, v in allowed_attributes.items():
required_attributes[k].update(v)
required_attributes_per_service[service] = required_attributes

# the attributes for the "default" service are added for all services
self.default_attributes = dict(
required_attributes_per_service.get("default", {})
)
for k, v in self.default_attributes.items():
for required_attributes in required_attributes_per_service.values():
required_attributes[k].update(v)

self.required_attributes_per_service = {
k: dict(v) for k, v in required_attributes_per_service.items()
if k != "default"
}

templates_dir_path = config["templates_dir_path"]
self.tpl_env = Environment(loader=FileSystemLoader(templates_dir_path), autoescape=select_autoescape())

def process(self, context, internal_data):
try:
return self._process(context, internal_data)
allowed_attributes = self.required_attributes_per_service.get(
internal_data.requester, self.default_attributes
)
if not allowed_attributes:
raise AttributeCheckerError(
"No allowed attributes configured for %s", internal_data.requester
)
return self._process(context, internal_data, allowed_attributes)
except AttributeCheckerError as e:
context.state[self.state_result] = False
context.state.delete = True
Expand All @@ -53,20 +88,20 @@ def process(self, context, internal_data):
)
return UnauthorizedResponse(content)

def _process(self, context, internal_data):
def _process(self, context, internal_data, allowed_attributes):
context.state[self.state_result] = False
is_authorized = self.attributes_strategy(
self.attribute_values_strategy(
value in values
for value in internal_data.attributes.get(attr, [])
)
for attr, values in self.allowed_attributes.items()
for attr, values in allowed_attributes.items()
)

if not is_authorized:
error_context = {
'message': 'User is not authorized to access this service.',
'allowed_attributes': list(self.allowed_attributes.keys()),
'allowed_attributes': list(allowed_attributes.keys()),
'attributes': internal_data.attributes.keys(),
'attributes_strategy': self.attributes_strategy.__name__,
'attribute_values_strategy': self.attribute_values_strategy.__name__,
Expand Down

0 comments on commit 2e6e06b

Please sign in to comment.