From 8831a760b3f04aafa9e27d7fd274d420b3b7f8fb Mon Sep 17 00:00:00 2001 From: Willi Ballenthin Date: Tue, 10 Dec 2024 15:58:32 +0000 Subject: [PATCH] wip: capabilities: use dataclasses to represent complicated return types --- capa/capabilities/dynamic.py | 90 +++++++++++++++----------- capa/capabilities/static.py | 94 +++++++++++++++++----------- tests/test_dynamic_sequence_scope.py | 1 - 3 files changed, 113 insertions(+), 72 deletions(-) diff --git a/capa/capabilities/dynamic.py b/capa/capabilities/dynamic.py index c280a888b..b80be6449 100644 --- a/capa/capabilities/dynamic.py +++ b/capa/capabilities/dynamic.py @@ -10,6 +10,7 @@ import itertools import collections from typing import Any +from dataclasses import dataclass import capa.perf import capa.features.freeze as frz @@ -26,13 +27,17 @@ SEQUENCE_SIZE = 5 +@dataclass +class CallCapabilities: + features: FeatureSet + matches: MatchResults + + def find_call_capabilities( ruleset: RuleSet, extractor: DynamicFeatureExtractor, ph: ProcessHandle, th: ThreadHandle, ch: CallHandle -) -> tuple[FeatureSet, MatchResults]: +) -> CallCapabilities: """ find matches for the given rules for the given call. - - returns: tuple containing (features for call, match results for call) """ # all features found for the call. features: FeatureSet = collections.defaultdict(set) @@ -50,16 +55,22 @@ def find_call_capabilities( for addr, _ in res: capa.engine.index_rule_matches(features, rule, [addr]) - return features, matches + return CallCapabilities(features, matches) + + +@dataclass +class ThreadCapabilities: + features: FeatureSet + thread_matches: MatchResults + sequence_matches: MatchResults + call_matches: MatchResults def find_thread_capabilities( ruleset: RuleSet, extractor: DynamicFeatureExtractor, ph: ProcessHandle, th: ThreadHandle -) -> tuple[FeatureSet, MatchResults, MatchResults, MatchResults]: +) -> ThreadCapabilities: """ find matches for the given rules within the given thread. - - returns: tuple containing (features for thread, match results for thread, match results for sequences, match results for calls) """ # all features found within this thread, # includes features found within calls. @@ -75,20 +86,20 @@ def find_thread_capabilities( sequence: collections.deque[FeatureSet] = collections.deque(maxlen=SEQUENCE_SIZE) for ch in extractor.get_calls(ph, th): - cfeatures, cmatches = find_call_capabilities(ruleset, extractor, ph, th, ch) - for feature, vas in cfeatures.items(): + call_capabilities = find_call_capabilities(ruleset, extractor, ph, th, ch) + for feature, vas in call_capabilities.features.items(): features[feature].update(vas) - for rule_name, res in cmatches.items(): + for rule_name, res in call_capabilities.matches.items(): call_matches[rule_name].extend(res) - sequence.append(cfeatures) - sfeatures: FeatureSet = collections.defaultdict(set) + sequence.append(call_capabilities.features) + sequence_features: FeatureSet = collections.defaultdict(set) for call in sequence: for feature, vas in call.items(): - sfeatures[feature].update(vas) + sequence_features[feature].update(vas) - _, smatches = ruleset.match(Scope.SEQUENCE, sfeatures, ch.address) + _, smatches = ruleset.match(Scope.SEQUENCE, sequence_features, ch.address) for rule_name, res in smatches.items(): sequence_matches[rule_name].extend(res) @@ -103,16 +114,23 @@ def find_thread_capabilities( for va, _ in res: capa.engine.index_rule_matches(features, rule, [va]) - return features, matches, sequence_matches, call_matches + return ThreadCapabilities(features, matches, sequence_matches, call_matches) + + +@dataclass +class ProcessCapabilities: + process_matches: MatchResults + thread_matches: MatchResults + sequence_matches: MatchResults + call_matches: MatchResults + feature_count: int def find_process_capabilities( ruleset: RuleSet, extractor: DynamicFeatureExtractor, ph: ProcessHandle -) -> tuple[MatchResults, MatchResults, MatchResults, MatchResults, int]: +) -> ProcessCapabilities: """ find matches for the given rules within the given process. - - returns: tuple containing (match results for process, match results for threads, match results for calls, number of features) """ # all features found within this process, # includes features found within threads (and calls). @@ -131,24 +149,24 @@ def find_process_capabilities( call_matches: MatchResults = collections.defaultdict(list) for th in extractor.get_threads(ph): - features, tmatches, smatches, cmatches = find_thread_capabilities(ruleset, extractor, ph, th) - for feature, vas in features.items(): + thread_capabilities = find_thread_capabilities(ruleset, extractor, ph, th) + for feature, vas in thread_capabilities.features.items(): process_features[feature].update(vas) - for rule_name, res in tmatches.items(): + for rule_name, res in thread_capabilities.thread_matches.items(): thread_matches[rule_name].extend(res) - for rule_name, res in smatches.items(): + for rule_name, res in thread_capabilities.sequence_matches.items(): sequence_matches[rule_name].extend(res) - for rule_name, res in cmatches.items(): + for rule_name, res in thread_capabilities.call_matches.items(): call_matches[rule_name].extend(res) for feature, va in itertools.chain(extractor.extract_process_features(ph), extractor.extract_global_features()): process_features[feature].add(va) _, process_matches = ruleset.match(Scope.PROCESS, process_features, ph.address) - return process_matches, thread_matches, sequence_matches, call_matches, len(process_features) + return ProcessCapabilities(process_matches, thread_matches, sequence_matches, call_matches, len(process_features)) def find_dynamic_capabilities( @@ -170,21 +188,21 @@ def find_dynamic_capabilities( ) as pbar: task = pbar.add_task("matching", total=n_processes, unit="processes") for p in processes: - process_matches, thread_matches, sequence_matches, call_matches, feature_count = find_process_capabilities( - ruleset, extractor, p - ) + process_capabilities = find_process_capabilities(ruleset, extractor, p) feature_counts.processes += ( - rdoc.ProcessFeatureCount(address=frz.Address.from_capa(p.address), count=feature_count), + rdoc.ProcessFeatureCount( + address=frz.Address.from_capa(p.address), count=process_capabilities.feature_count + ), ) - logger.debug("analyzed %s and extracted %d features", p.address, feature_count) + logger.debug("analyzed %s and extracted %d features", p.address, process_capabilities.feature_count) - for rule_name, res in process_matches.items(): + for rule_name, res in process_capabilities.process_matches.items(): all_process_matches[rule_name].extend(res) - for rule_name, res in thread_matches.items(): + for rule_name, res in process_capabilities.thread_matches.items(): all_thread_matches[rule_name].extend(res) - for rule_name, res in sequence_matches.items(): + for rule_name, res in process_capabilities.sequence_matches.items(): all_sequence_matches[rule_name].extend(res) - for rule_name, res in call_matches.items(): + for rule_name, res in process_capabilities.call_matches.items(): all_call_matches[rule_name].extend(res) pbar.advance(task) @@ -199,8 +217,10 @@ def find_dynamic_capabilities( rule = ruleset[rule_name] capa.engine.index_rule_matches(process_and_lower_features, rule, locations) - all_file_matches, feature_count = find_file_capabilities(ruleset, extractor, process_and_lower_features) - feature_counts.file = feature_count + all_file_matches, process_capabilities.feature_count = find_file_capabilities( + ruleset, extractor, process_and_lower_features + ) + feature_counts.file = process_capabilities.feature_count matches = dict( itertools.chain( diff --git a/capa/capabilities/static.py b/capa/capabilities/static.py index df8cd7e78..d2f2c3db9 100644 --- a/capa/capabilities/static.py +++ b/capa/capabilities/static.py @@ -11,6 +11,7 @@ import itertools import collections from typing import Any +from dataclasses import dataclass import capa.perf import capa.helpers @@ -24,13 +25,17 @@ logger = logging.getLogger(__name__) +@dataclass +class InstructionCapabilities: + features: FeatureSet + matches: MatchResults + + def find_instruction_capabilities( ruleset: RuleSet, extractor: StaticFeatureExtractor, f: FunctionHandle, bb: BBHandle, insn: InsnHandle -) -> tuple[FeatureSet, MatchResults]: +) -> InstructionCapabilities: """ find matches for the given rules for the given instruction. - - returns: tuple containing (features for instruction, match results for instruction) """ # all features found for the instruction. features: FeatureSet = collections.defaultdict(set) @@ -48,16 +53,21 @@ def find_instruction_capabilities( for addr, _ in res: capa.engine.index_rule_matches(features, rule, [addr]) - return features, matches + return InstructionCapabilities(features, matches) + + +@dataclass +class BasicBlockCapabilities: + features: FeatureSet + basic_block_matches: MatchResults + instruction_matches: MatchResults def find_basic_block_capabilities( ruleset: RuleSet, extractor: StaticFeatureExtractor, f: FunctionHandle, bb: BBHandle -) -> tuple[FeatureSet, MatchResults, MatchResults]: +) -> BasicBlockCapabilities: """ find matches for the given rules within the given basic block. - - returns: tuple containing (features for basic block, match results for basic block, match results for instructions) """ # all features found within this basic block, # includes features found within instructions. @@ -68,11 +78,11 @@ def find_basic_block_capabilities( insn_matches: MatchResults = collections.defaultdict(list) for insn in extractor.get_instructions(f, bb): - ifeatures, imatches = find_instruction_capabilities(ruleset, extractor, f, bb, insn) - for feature, vas in ifeatures.items(): + instruction_capabilities = find_instruction_capabilities(ruleset, extractor, f, bb, insn) + for feature, vas in instruction_capabilities.features.items(): features[feature].update(vas) - for rule_name, res in imatches.items(): + for rule_name, res in instruction_capabilities.matches.items(): insn_matches[rule_name].extend(res) for feature, va in itertools.chain( @@ -88,16 +98,20 @@ def find_basic_block_capabilities( for va, _ in res: capa.engine.index_rule_matches(features, rule, [va]) - return features, matches, insn_matches + return BasicBlockCapabilities(features, matches, insn_matches) -def find_code_capabilities( - ruleset: RuleSet, extractor: StaticFeatureExtractor, fh: FunctionHandle -) -> tuple[MatchResults, MatchResults, MatchResults, int]: +@dataclass +class CodeCapabilities: + function_matches: MatchResults + basic_block_matches: MatchResults + instruction_matches: MatchResults + feature_count: int + + +def find_code_capabilities(ruleset: RuleSet, extractor: StaticFeatureExtractor, fh: FunctionHandle) -> CodeCapabilities: """ find matches for the given rules within the given function. - - returns: tuple containing (match results for function, match results for basic blocks, match results for instructions, number of features) """ # all features found within this function, # includes features found within basic blocks (and instructions). @@ -112,21 +126,21 @@ def find_code_capabilities( insn_matches: MatchResults = collections.defaultdict(list) for bb in extractor.get_basic_blocks(fh): - features, bmatches, imatches = find_basic_block_capabilities(ruleset, extractor, fh, bb) - for feature, vas in features.items(): + basic_block_capabilities = find_basic_block_capabilities(ruleset, extractor, fh, bb) + for feature, vas in basic_block_capabilities.features.items(): function_features[feature].update(vas) - for rule_name, res in bmatches.items(): + for rule_name, res in basic_block_capabilities.basic_block_matches.items(): bb_matches[rule_name].extend(res) - for rule_name, res in imatches.items(): + for rule_name, res in basic_block_capabilities.instruction_matches.items(): insn_matches[rule_name].extend(res) for feature, va in itertools.chain(extractor.extract_function_features(fh), extractor.extract_global_features()): function_features[feature].add(va) _, function_matches = ruleset.match(Scope.FUNCTION, function_features, fh.address) - return function_matches, bb_matches, insn_matches, len(function_features) + return CodeCapabilities(function_matches, bb_matches, insn_matches, len(function_features)) def find_static_capabilities( @@ -165,30 +179,36 @@ def find_static_capabilities( pbar.advance(task) continue - function_matches, bb_matches, insn_matches, feature_count = find_code_capabilities(ruleset, extractor, f) + code_capabilities = find_code_capabilities(ruleset, extractor, f) feature_counts.functions += ( - rdoc.FunctionFeatureCount(address=frz.Address.from_capa(f.address), count=feature_count), + rdoc.FunctionFeatureCount( + address=frz.Address.from_capa(f.address), count=code_capabilities.feature_count + ), ) t1 = time.time() match_count = 0 - for name, matches_ in itertools.chain(function_matches.items(), bb_matches.items(), insn_matches.items()): + for name, matches_ in itertools.chain( + code_capabilities.function_matches.items(), + code_capabilities.basic_block_matches.items(), + code_capabilities.instruction_matches.items(), + ): if not ruleset.rules[name].is_subscope_rule(): match_count += len(matches_) - logger.debug( - "analyzed function 0x%x and extracted %d features, %d matches in %0.02fs", - f.address, - feature_count, - match_count, - t1 - t0, - ) + # logger.debug( + # "analyzed function 0x%x and extracted %d features, %d matches in %0.02fs", + # f.address, + # feature_count, + # match_count, + # t1 - t0, + # ) - for rule_name, res in function_matches.items(): + for rule_name, res in code_capabilities.function_matches.items(): all_function_matches[rule_name].extend(res) - for rule_name, res in bb_matches.items(): + for rule_name, res in code_capabilities.basic_block_matches.items(): all_bb_matches[rule_name].extend(res) - for rule_name, res in insn_matches.items(): + for rule_name, res in code_capabilities.instruction_matches.items(): all_insn_matches[rule_name].extend(res) pbar.advance(task) @@ -203,8 +223,10 @@ def find_static_capabilities( rule = ruleset[rule_name] capa.engine.index_rule_matches(function_and_lower_features, rule, locations) - all_file_matches, feature_count = find_file_capabilities(ruleset, extractor, function_and_lower_features) - feature_counts.file = feature_count + all_file_matches, code_capabilities.feature_count = find_file_capabilities( + ruleset, extractor, function_and_lower_features + ) + feature_counts.file = code_capabilities.feature_count matches: MatchResults = dict( itertools.chain( diff --git a/tests/test_dynamic_sequence_scope.py b/tests/test_dynamic_sequence_scope.py index 810dc5b34..55fbd050b 100644 --- a/tests/test_dynamic_sequence_scope.py +++ b/tests/test_dynamic_sequence_scope.py @@ -253,4 +253,3 @@ def test_dynamic_sequence_multiple_sequences_overlapping_single_event(): matches, features = capa.capabilities.dynamic.find_dynamic_capabilities(ruleset, extractor, disable_progress=True) assert r.name in matches assert [11, 12, 13, 14, 15] == list(get_call_ids(matches[r.name])) -