Commit

feature: return list of (un)suppressed findings instead of just a success boolean
mlflr committed Oct 1, 2024
1 parent d31c1a9 commit 4a59532
Showing 3 changed files with 136 additions and 80 deletions.
135 changes: 90 additions & 45 deletions awsfindingsmanagerlib/awsfindingsmanagerlib.py
@@ -82,7 +82,8 @@ class Finding:

def __init__(self, data: Dict) -> None:
self._data = self._validate_data(data)
- self._logger = logging.getLogger(f'{LOGGER_BASENAME}.{self.__class__.__name__}')
+ self._logger = logging.getLogger(
+     f'{LOGGER_BASENAME}.{self.__class__.__name__}')
self._matched_rule = None

def __hash__(self) -> int:
@@ -104,7 +105,8 @@ def __ne__(self, other: Finding) -> bool:
def _validate_data(data: Dict) -> Dict:
missing = set(Finding.required_fields) - set(data.keys())
if missing:
- raise InvalidFindingData(f'Missing required keys: "{missing}" for data with ID "{data.get("Id")}"')
+ raise InvalidFindingData(f'Missing required keys: "{
+     missing}" for data with ID "{data.get("Id")}"')
return data

@property
@@ -116,7 +118,8 @@ def matched_rule(self) -> Rule:
def matched_rule(self, rule) -> None:
"""The matched rule setter that is registered in the finding."""
if not isinstance(rule, Rule):
- raise InvalidRuleType(f'The argument provided is not a valid rule object. Received: "{rule}"')
+ raise InvalidRuleType(
+     f'The argument provided is not a valid rule object. Received: "{rule}"')
self._matched_rule = rule

@property
@@ -277,7 +280,8 @@ def _parse_date_time(self, datetime_string) -> Optional[datetime]:
try:
return parse(datetime_string)
except ValueError:
- self._logger.warning(f'Could not automatically parse datetime string: "{datetime_string}"')
+ self._logger.warning(f'Could not automatically parse datetime string: "{
+     datetime_string}"')
return None

@property
@@ -350,16 +354,20 @@ def is_matching_rule(self, rule: Rule) -> bool:
if not isinstance(rule, Rule):
raise InvalidRuleType(rule)
if any([
- self.match_if_set(self.security_control_id, rule.security_control_id),
+ self.match_if_set(self.security_control_id,
+     rule.security_control_id),
self.match_if_set(self.control_id, rule.rule_or_control_id),
self.match_if_set(self.rule_id, rule.rule_or_control_id)
]):
- self._logger.debug(f'Matched with rule "{rule.note}" on one of "control_id, security_control_id"')
+ self._logger.debug(f'Matched with rule "{
+     rule.note}" on one of "control_id, security_control_id"')
if not any([rule.tags, rule.resource_id_regexps]):
- self._logger.debug(f'Rule "{rule.note}" does not seem to have filters for resources or tags.')
+ self._logger.debug(
+     f'Rule "{rule.note}" does not seem to have filters for resources or tags.')
return True
if any([self.is_matching_tags(rule.tags), self.is_matching_resource_ids(rule.resource_id_regexps)]):
- self._logger.debug(f'Matched with rule "{rule.note}" either on resources or tags.')
+ self._logger.debug(f'Matched with rule "{
+     rule.note}" either on resources or tags.')
return True
return False

@@ -368,7 +376,8 @@ class Rule:
"""Models a suppression rule."""

def __init__(self, note: str, action: str, match_on: Dict) -> None:
- self._data = validate_rule_data({'note': note, 'action': action, 'match_on': match_on})
+ self._data = validate_rule_data(
+     {'note': note, 'action': action, 'match_on': match_on})

def __hash__(self) -> int:
return hash(self.note)
@@ -508,15 +517,17 @@ def __init__(self,
denied_account_ids: Optional[List[str]] = None,
strict_mode: bool = True,
suppress_label: str = None):
- self._logger = logging.getLogger(f'{LOGGER_BASENAME}.{self.__class__.__name__}')
+ self._logger = logging.getLogger(
+     f'{LOGGER_BASENAME}.{self.__class__.__name__}')
self.allowed_regions, self.denied_regions = validate_allowed_denied_regions(allowed_regions,
denied_regions)
self.allowed_account_ids, self.denied_account_ids = validate_allowed_denied_account_ids(allowed_account_ids,
denied_account_ids)
self.sts = self._get_sts_client()
self.ec2 = self._get_ec2_client(region)
self._aws_regions = None
- self.aws_region = self._validate_region(region) or self._sts_client_config_region
+ self.aws_region = self._validate_region(
+     region) or self._sts_client_config_region
self._rules = set()
self._strict_mode = strict_mode
self._rules_errors = []
@@ -611,7 +622,8 @@ def _get_security_hub_client(region: str):
client = boto3.client('securityhub', **kwargs)
except (botocore.exceptions.NoRegionError,
botocore.exceptions.InvalidRegionError) as msg:
- raise NoRegion(f'Security Hub client requires a valid region set to connect, message was:{msg}') from None
+ raise NoRegion(f'Security Hub client requires a valid region set to connect, message was: {
+     msg}') from None
return client

def _get_security_hub_paginator_iterator(self, region: str, operation_name: str, query_filter: dict):
@@ -631,7 +643,8 @@ def _get_ec2_client(region: str):
except (botocore.exceptions.NoRegionError,
botocore.exceptions.InvalidRegionError,
botocore.exceptions.EndpointConnectionError) as msg:
- raise NoRegion(f'Ec2 client requires a valid region set to connect, message was:{msg}') from None
+ raise NoRegion(f'Ec2 client requires a valid region set to connect, message was: {
+     msg}') from None
except (botocore.exceptions.ClientError, botocore.exceptions.NoCredentialsError) as msg:
raise InvalidOrNoCredentials(msg) from None
return client
@@ -646,14 +659,20 @@ def regions(self):
self._aws_regions = [region.get('RegionName')
for region in self._describe_ec2_regions()
if region.get('OptInStatus', '') != 'not-opted-in']
- self._logger.debug(f'Regions in EC2 that were opted in are : {self._aws_regions}')
+ self._logger.debug(f'Regions in EC2 that were opted in are: {
+     self._aws_regions}')
if self.allowed_regions:
- self._aws_regions = set(self._aws_regions).intersection(set(self.allowed_regions))
- self._logger.debug(f'Working on allowed regions {self._aws_regions}')
+ self._aws_regions = set(self._aws_regions).intersection(
+     set(self.allowed_regions))
+ self._logger.debug(f'Working on allowed regions {
+     self._aws_regions}')
elif self.denied_regions:
- self._logger.debug(f'Excluding denied regions {self.denied_regions}')
- self._aws_regions = set(self._aws_regions) - set(self.denied_regions)
- self._logger.debug(f'Working on non-denied regions {self._aws_regions}')
+ self._logger.debug(f'Excluding denied regions {
+     self.denied_regions}')
+ self._aws_regions = set(self._aws_regions) - \
+     set(self.denied_regions)
+ self._logger.debug(
+     f'Working on non-denied regions {self._aws_regions}')
else:
self._logger.debug('Working on all regions')
return self._aws_regions
@@ -663,10 +682,12 @@ def _get_aggregating_region(self):
try:
client = self._get_security_hub_client(self.aws_region)
data = client.list_finding_aggregators()
- aggregating_region = data.get('FindingAggregators')[0].get('FindingAggregatorArn').split(':')[3]
+ aggregating_region = data.get('FindingAggregators')[0].get(
+     'FindingAggregatorArn').split(':')[3]
self._logger.info(f'Found aggregating region {aggregating_region}')
except (IndexError, botocore.exceptions.ClientError):
- self._logger.debug('Could not get aggregating region, either not set, or a client error')
+ self._logger.debug(
+     'Could not get aggregating region, either not set, or a client error')
return aggregating_region

@staticmethod
@@ -688,7 +709,8 @@ def _calculate_account_id_filter(allowed_account_ids: Optional[List[str]],
if any([allowed_account_ids, denied_account_ids]):
comparison = 'EQUALS' if allowed_account_ids else 'NOT_EQUALS'
iterator = allowed_account_ids if allowed_account_ids else denied_account_ids
- aws_account_ids = [{'Comparison': comparison, 'Value': account} for account in iterator]
+ aws_account_ids = [{'Comparison': comparison,
+     'Value': account} for account in iterator]
return aws_account_ids

# pylint: disable=dangerous-default-value
@@ -711,7 +733,8 @@ def update_query_for_account_ids(query_filter: Dict = DEFAULT_SECURITY_HUB_FILTE
"""
query_filter = deepcopy(query_filter)
- aws_account_ids = FindingsManager._calculate_account_id_filter(allowed_account_ids, denied_account_ids)
+ aws_account_ids = FindingsManager._calculate_account_id_filter(
+     allowed_account_ids, denied_account_ids)
if aws_account_ids:
query_filter.update({'AwsAccountId': aws_account_ids})
return query_filter
@@ -720,7 +743,8 @@ def _get_findings(self, query_filter: Dict):
def _get_findings(self, query_filter: Dict):
findings = set()
aggregating_region = self._get_aggregating_region()
- regions_to_retrieve = [aggregating_region] if aggregating_region else self.regions
+ regions_to_retrieve = [
+     aggregating_region] if aggregating_region else self.regions
for region in regions_to_retrieve:
self._logger.debug(f'Trying to get findings for region {region}')
iterator = self._get_security_hub_paginator_iterator(
@@ -732,11 +756,13 @@ def _get_findings(self, query_filter: Dict):
for page in iterator:
for finding_data in page['Findings']:
finding = Finding(finding_data)
- self._logger.debug(f'Adding finding with id {finding.id}')
+ self._logger.debug(
+     f'Adding finding with id {finding.id}')
findings.add(finding)
except botocore.exceptions.ClientError as error:
if error.response['Error']['Code'] in ['AccessDeniedException', 'InvalidAccessException']:
- self._logger.debug(f'No access for Security Hub for region {region}.')
+ self._logger.debug(
+     f'No access for Security Hub for region {region}.')
continue
raise error
return list(findings)
@@ -750,7 +776,8 @@ def _get_matching_findings(rule: Rule, findings: List[Finding], logger: logging.
logger.debug(f'Following findings matched with rule with note: "{rule.note}", '
f'{[finding.id for finding in matching_findings]}')
else:
- logger.debug('No resource id patterns or tags are provided in the rule, all findings used.')
+ logger.debug(
+     'No resource id patterns or tags are provided in the rule, all findings used.')
matching_findings = findings
for finding in matching_findings:
finding.matched_rule = rule
@@ -771,7 +798,8 @@ def get_findings(self) -> List[Finding]:
findings = list(set(all_findings))
diff = initial_size - len(findings)
if diff:
- self._logger.warning(f'Missmatch of finding numbers, there seems to be an overlap of {diff}')
+ self._logger.warning(
+     f'Missmatch of finding numbers, there seems to be an overlap of {diff}')
return findings

def get_findings_by_matching_rule(self, rule: Rule) -> List[Finding]:
Expand Down Expand Up @@ -823,9 +851,11 @@ def _validate_rule_in_findings(self, findings: List[Finding]):
NoRuleFindings if strict mode is enabled and any findings do not have matching rules.
"""
- no_rule_matches = [finding.id for finding in findings if not finding.matched_rule]
+ no_rule_matches = [
+     finding.id for finding in findings if not finding.matched_rule]
if no_rule_matches:
- message = f'Findings with the following ids "{no_rule_matches}" do not have matching rules'
+ message = f'Findings with the following ids "{
+     no_rule_matches}" do not have matching rules'
if self._strict_mode:
raise NoRuleFindings(message)
self._logger.warning(message)
@@ -844,7 +874,8 @@ def _get_suppressing_payload(self, findings: List[Finding]):
A generator with suppressing payloads per common note chunked at MAX_SUPPRESSION_PAYLOAD_SIZE
"""
- findings = findings if isinstance(findings, (list, tuple, set)) else [findings]
+ findings = findings if isinstance(
+     findings, (list, tuple, set)) else [findings]
findings = self._validate_rule_in_findings(findings)
rule_findings_mapping = defaultdict(list)
for finding in findings:
@@ -871,7 +902,8 @@ def _get_unsuppressing_payload(self, findings: List[Finding]):
A generator with unsuppressing payloads chunked at MAX_SUPPRESSION_PAYLOAD_SIZE
"""
- findings = findings if isinstance(findings, (list, tuple, set)) else [findings]
+ findings = findings if isinstance(
+     findings, (list, tuple, set)) else [findings]
for chunk in FindingsManager._chunk([{'Id': finding.id,
'ProductArn': finding.product_arn}
for finding in findings], MAX_SUPPRESSION_PAYLOAD_SIZE):
@@ -892,15 +924,19 @@ def _workflow_state_change_on_findings(self, findings: List[Finding], suppress=T
message_state = 'suppression' if suppress else 'unsuppression'
method = self._get_suppressing_payload if suppress else self._get_unsuppressing_payload
security_hub = self._get_security_hub_client(self.aws_region)
- return all((result for result in self._batch_apply_payloads(security_hub,
+ successes, payloads = zip(*(result for result in self._batch_apply_payloads(security_hub,
method(findings), # noqa
message_state)))
+ success = all(successes)
+ return (success, list(payloads))

def _batch_apply_payloads(self, security_hub, payloads, message_state):
for payload in payloads:
- self._logger.debug(f'Sending payload {payload} for {message_state} to Security Hub.')
+ self._logger.debug(f'Sending payload {payload} for {
+     message_state} to Security Hub.')
if os.environ.get('FINDINGS_MANAGER_DRY_RUN_MODE'):
- self._logger.debug(f'Dry run mode is on, skipping the actual {message_state}.')
+ self._logger.debug(
+     f'Dry run mode is on, skipping the actual {message_state}.')
continue
yield self._batch_update_findings(security_hub, payload)
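The heart of the commit is the zip(*...) unpacking a few lines up: each result yielded by _batch_apply_payloads is now a (status, payload) tuple, and zip transposes those tuples into parallel sequences of statuses and payloads. A minimal sketch of the pattern with illustrative data (not library code):

# Two batch results: (status, payload) tuples as yielded by the generator.
results = [(True, {'Id': 'finding-1'}), (False, {'Id': 'finding-2'})]
successes, payloads = zip(*results)  # transpose into two parallel tuples
print(all(successes))   # False, because the second batch failed
print(list(payloads))   # [{'Id': 'finding-1'}, {'Id': 'finding-2'}]

One property of the pattern worth knowing: the unpacking raises ValueError when the generator yields nothing at all, for example when dry run mode skips every payload.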

@@ -935,7 +971,9 @@ def _batch_update_findings(self, security_hub, payload):
security_hub: Security hub client
payload: The payload to send to the service
- Returns: True on success False otherwise
+ Returns:
+     tuple: A tuple containing a boolean status and the payload.
+         The status is True on success and False otherwise.
Raises:
FailedToBatchUpdate: if strict mode is set and there are failures to update.
@@ -951,8 +989,9 @@ def _batch_update_findings(self, security_hub, payload):
for fail in failed:
id_ = fail.get('FindingIdentifier', '').get('Id')
error = fail.get('ErrorMessage')
- self._logger.error(f'Failed to update finding with ID: "{id_}" with error: "{error}"')
- return status
+ self._logger.error(f'Failed to update finding with ID: "{
+     id_}" with error: "{error}"')
+ return (status, payload)
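For context, the failed entries iterated above come from the UnprocessedFindings key of the Security Hub BatchUpdateFindings response. A hedged sketch of that response shape (field names follow the boto3 securityhub API; all concrete values are illustrative):

result = {
    'ProcessedFindings': [
        {'Id': 'arn:aws:securityhub:eu-west-1:111122223333:finding/example-1',
         'ProductArn': 'arn:aws:securityhub:eu-west-1::product/aws/securityhub'},
    ],
    'UnprocessedFindings': [
        {'FindingIdentifier': {'Id': 'arn:aws:securityhub:eu-west-1:111122223333:finding/example-2',
                               'ProductArn': 'arn:aws:securityhub:eu-west-1::product/aws/securityhub'},
         'ErrorCode': 'FindingNotFound',  # illustrative error code
         'ErrorMessage': 'Finding does not exist'},
    ],
}
failed = result.get('UnprocessedFindings')
status = not failed  # False as soon as any finding went unprocessed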

def validate_finding_on_matching_rules(self, finding_data: Dict):
"""Validates that the provided data is correct data for a finding.
@@ -982,12 +1021,14 @@ def _construct_findings_on_matching_rules(self, finding_data: Union[List[Dict],
if isinstance(finding_data, dict):
finding_data = [finding_data]
if self._strict_mode:
- findings = [self.validate_finding_on_matching_rules(payload) for payload in finding_data]
+ findings = [self.validate_finding_on_matching_rules(
+     payload) for payload in finding_data]
else:
findings = []
for payload in finding_data:
try:
- findings.append(self.validate_finding_on_matching_rules(payload))
+ findings.append(
+     self.validate_finding_on_matching_rules(payload))
except InvalidFindingData:
self._logger.error(f'Data {payload} seems to be invalid.')
return [finding for finding in findings if finding]
@@ -1002,7 +1043,8 @@ def suppress_finding_on_matching_rules(self, finding_data: Dict):
finding_data: The data of a finding as provided by Security Hub.
Returns:
- True on success False otherwise.
+ tuple: A tuple containing a boolean status and the payload.
+     The status is True on success and False otherwise.
Raises:
InvalidFindingData: If the data is not valid finding data.
@@ -1020,13 +1062,15 @@ def suppress_findings_on_matching_rules(self, finding_data: Union[List[Dict], Di
finding_data: The data of a finding as provided by Security Hub.
Returns:
- True on success False otherwise.
+ tuple: A tuple containing a boolean status and the payload.
+     The status is True on success and False otherwise.
Raises:
InvalidFindingData: If any data is not valid finding data.
"""
- matching_findings = self._construct_findings_on_matching_rules(finding_data)
+ matching_findings = self._construct_findings_on_matching_rules(
+     finding_data)
return self._workflow_state_change_on_findings(matching_findings)

def get_unmanaged_suppressed_findings(self) -> List[Finding]:
@@ -1042,11 +1086,12 @@ def get_unmanaged_suppressed_findings(self) -> List[Finding]:
'Comparison': 'EQUALS'}]}
return self._get_findings(query)

- def unsuppress_unmanaged_findings(self) -> bool:
+ def unsuppress_unmanaged_findings(self) -> tuple[bool, list]:
"""Unsuppresses findings that have not been suppressed by this library.
Returns:
- True on full success, False otherwise.
+ tuple: A tuple containing a boolean status and the payload.
+     The status is True on success and False otherwise.
"""
return self._workflow_state_change_on_findings(self.get_unmanaged_suppressed_findings(), suppress=False)
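Taken together, the public entry points (suppress_finding_on_matching_rules, suppress_findings_on_matching_rules and unsuppress_unmanaged_findings) now hand back a (status, payloads) tuple instead of a bare boolean, so callers can see exactly which batch payloads were sent. A short usage sketch, assuming the package-level import and the region shown here (both illustrative):

from awsfindingsmanagerlib import FindingsManager  # import path assumed

manager = FindingsManager(region='eu-west-1')  # region is illustrative
success, payloads = manager.unsuppress_unmanaged_findings()
if not success:
    # payloads holds the BatchUpdateFindings payloads that were sent,
    # useful for logging or retrying the failed batches.
    for payload in payloads:
        print('Sent payload:', payload)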
