Skip to content

Commit

Permalink
wip: capabilities: use dataclasses to represent complicated return types
Browse files Browse the repository at this point in the history
  • Loading branch information
williballenthin committed Dec 10, 2024
1 parent 6d05d3c commit 8831a76
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 72 deletions.
90 changes: 55 additions & 35 deletions capa/capabilities/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import itertools
import collections
from typing import Any
from dataclasses import dataclass

import capa.perf
import capa.features.freeze as frz
Expand All @@ -26,13 +27,17 @@
SEQUENCE_SIZE = 5


@dataclass
class CallCapabilities:
    """
    result of matching rules against a single call,
    as returned by `find_call_capabilities`.

    fields are constructed positionally, so their order is part of the interface.
    """

    # all features found for the call.
    features: FeatureSet
    # match results for the call, keyed by rule name.
    matches: MatchResults


def find_call_capabilities(
ruleset: RuleSet, extractor: DynamicFeatureExtractor, ph: ProcessHandle, th: ThreadHandle, ch: CallHandle
) -> tuple[FeatureSet, MatchResults]:
) -> CallCapabilities:
"""
find matches for the given rules for the given call.
returns: tuple containing (features for call, match results for call)
"""
# all features found for the call.
features: FeatureSet = collections.defaultdict(set)
Expand All @@ -50,16 +55,22 @@ def find_call_capabilities(
for addr, _ in res:
capa.engine.index_rule_matches(features, rule, [addr])

return features, matches
return CallCapabilities(features, matches)


@dataclass
class ThreadCapabilities:
    """
    result of matching rules within a single thread,
    as returned by `find_thread_capabilities`.

    fields are constructed positionally, so their order is part of the interface.
    """

    # all features found within the thread,
    # includes features found within its calls.
    features: FeatureSet
    # match results for the thread itself, keyed by rule name.
    thread_matches: MatchResults
    # match results for sequences (sliding windows of recent calls) seen in the thread.
    sequence_matches: MatchResults
    # match results for the individual calls, collected across every call in the thread.
    call_matches: MatchResults


def find_thread_capabilities(
ruleset: RuleSet, extractor: DynamicFeatureExtractor, ph: ProcessHandle, th: ThreadHandle
) -> tuple[FeatureSet, MatchResults, MatchResults, MatchResults]:
) -> ThreadCapabilities:
"""
find matches for the given rules within the given thread.
returns: tuple containing (features for thread, match results for thread, match results for sequences, match results for calls)
"""
# all features found within this thread,
# includes features found within calls.
Expand All @@ -75,20 +86,20 @@ def find_thread_capabilities(
sequence: collections.deque[FeatureSet] = collections.deque(maxlen=SEQUENCE_SIZE)

for ch in extractor.get_calls(ph, th):
cfeatures, cmatches = find_call_capabilities(ruleset, extractor, ph, th, ch)
for feature, vas in cfeatures.items():
call_capabilities = find_call_capabilities(ruleset, extractor, ph, th, ch)
for feature, vas in call_capabilities.features.items():
features[feature].update(vas)

for rule_name, res in cmatches.items():
for rule_name, res in call_capabilities.matches.items():
call_matches[rule_name].extend(res)

sequence.append(cfeatures)
sfeatures: FeatureSet = collections.defaultdict(set)
sequence.append(call_capabilities.features)
sequence_features: FeatureSet = collections.defaultdict(set)
for call in sequence:
for feature, vas in call.items():
sfeatures[feature].update(vas)
sequence_features[feature].update(vas)

_, smatches = ruleset.match(Scope.SEQUENCE, sfeatures, ch.address)
_, smatches = ruleset.match(Scope.SEQUENCE, sequence_features, ch.address)
for rule_name, res in smatches.items():
sequence_matches[rule_name].extend(res)

Expand All @@ -103,16 +114,23 @@ def find_thread_capabilities(
for va, _ in res:
capa.engine.index_rule_matches(features, rule, [va])

return features, matches, sequence_matches, call_matches
return ThreadCapabilities(features, matches, sequence_matches, call_matches)


@dataclass
class ProcessCapabilities:
    """
    result of matching rules within a single process,
    as returned by `find_process_capabilities`.

    fields are constructed positionally, so their order is part of the interface.
    """

    # match results for the process itself, keyed by rule name.
    process_matches: MatchResults
    # match results for threads, collected across every thread in the process.
    thread_matches: MatchResults
    # match results for sequences, collected across every thread in the process.
    sequence_matches: MatchResults
    # match results for calls, collected across every thread in the process.
    call_matches: MatchResults
    # number of distinct features in the process feature set
    # (which includes thread, call, and global features).
    feature_count: int


def find_process_capabilities(
ruleset: RuleSet, extractor: DynamicFeatureExtractor, ph: ProcessHandle
) -> tuple[MatchResults, MatchResults, MatchResults, MatchResults, int]:
) -> ProcessCapabilities:
"""
find matches for the given rules within the given process.
returns: tuple containing (match results for process, match results for threads, match results for calls, number of features)
"""
# all features found within this process,
# includes features found within threads (and calls).
Expand All @@ -131,24 +149,24 @@ def find_process_capabilities(
call_matches: MatchResults = collections.defaultdict(list)

for th in extractor.get_threads(ph):
features, tmatches, smatches, cmatches = find_thread_capabilities(ruleset, extractor, ph, th)
for feature, vas in features.items():
thread_capabilities = find_thread_capabilities(ruleset, extractor, ph, th)
for feature, vas in thread_capabilities.features.items():
process_features[feature].update(vas)

for rule_name, res in tmatches.items():
for rule_name, res in thread_capabilities.thread_matches.items():
thread_matches[rule_name].extend(res)

for rule_name, res in smatches.items():
for rule_name, res in thread_capabilities.sequence_matches.items():
sequence_matches[rule_name].extend(res)

for rule_name, res in cmatches.items():
for rule_name, res in thread_capabilities.call_matches.items():
call_matches[rule_name].extend(res)

for feature, va in itertools.chain(extractor.extract_process_features(ph), extractor.extract_global_features()):
process_features[feature].add(va)

_, process_matches = ruleset.match(Scope.PROCESS, process_features, ph.address)
return process_matches, thread_matches, sequence_matches, call_matches, len(process_features)
return ProcessCapabilities(process_matches, thread_matches, sequence_matches, call_matches, len(process_features))


def find_dynamic_capabilities(
Expand All @@ -170,21 +188,21 @@ def find_dynamic_capabilities(
) as pbar:
task = pbar.add_task("matching", total=n_processes, unit="processes")
for p in processes:
process_matches, thread_matches, sequence_matches, call_matches, feature_count = find_process_capabilities(
ruleset, extractor, p
)
process_capabilities = find_process_capabilities(ruleset, extractor, p)
feature_counts.processes += (
rdoc.ProcessFeatureCount(address=frz.Address.from_capa(p.address), count=feature_count),
rdoc.ProcessFeatureCount(
address=frz.Address.from_capa(p.address), count=process_capabilities.feature_count
),
)
logger.debug("analyzed %s and extracted %d features", p.address, feature_count)
logger.debug("analyzed %s and extracted %d features", p.address, process_capabilities.feature_count)

for rule_name, res in process_matches.items():
for rule_name, res in process_capabilities.process_matches.items():
all_process_matches[rule_name].extend(res)
for rule_name, res in thread_matches.items():
for rule_name, res in process_capabilities.thread_matches.items():
all_thread_matches[rule_name].extend(res)
for rule_name, res in sequence_matches.items():
for rule_name, res in process_capabilities.sequence_matches.items():
all_sequence_matches[rule_name].extend(res)
for rule_name, res in call_matches.items():
for rule_name, res in process_capabilities.call_matches.items():
all_call_matches[rule_name].extend(res)

pbar.advance(task)
Expand All @@ -199,8 +217,10 @@ def find_dynamic_capabilities(
rule = ruleset[rule_name]
capa.engine.index_rule_matches(process_and_lower_features, rule, locations)

all_file_matches, feature_count = find_file_capabilities(ruleset, extractor, process_and_lower_features)
feature_counts.file = feature_count
all_file_matches, process_capabilities.feature_count = find_file_capabilities(
ruleset, extractor, process_and_lower_features
)
feature_counts.file = process_capabilities.feature_count

matches = dict(
itertools.chain(
Expand Down
94 changes: 58 additions & 36 deletions capa/capabilities/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import itertools
import collections
from typing import Any
from dataclasses import dataclass

import capa.perf
import capa.helpers
Expand All @@ -24,13 +25,17 @@
logger = logging.getLogger(__name__)


@dataclass
class InstructionCapabilities:
    """
    result of matching rules against a single instruction,
    as returned by `find_instruction_capabilities`.

    fields are constructed positionally, so their order is part of the interface.
    """

    # all features found for the instruction.
    features: FeatureSet
    # match results for the instruction, keyed by rule name.
    matches: MatchResults


def find_instruction_capabilities(
ruleset: RuleSet, extractor: StaticFeatureExtractor, f: FunctionHandle, bb: BBHandle, insn: InsnHandle
) -> tuple[FeatureSet, MatchResults]:
) -> InstructionCapabilities:
"""
find matches for the given rules for the given instruction.
returns: tuple containing (features for instruction, match results for instruction)
"""
# all features found for the instruction.
features: FeatureSet = collections.defaultdict(set)
Expand All @@ -48,16 +53,21 @@ def find_instruction_capabilities(
for addr, _ in res:
capa.engine.index_rule_matches(features, rule, [addr])

return features, matches
return InstructionCapabilities(features, matches)


@dataclass
class BasicBlockCapabilities:
    """
    result of matching rules within a single basic block,
    as returned by `find_basic_block_capabilities`.

    fields are constructed positionally, so their order is part of the interface.
    """

    # all features found within the basic block,
    # includes features found within its instructions.
    features: FeatureSet
    # match results for the basic block itself, keyed by rule name.
    basic_block_matches: MatchResults
    # match results for instructions, collected across every instruction in the basic block.
    instruction_matches: MatchResults


def find_basic_block_capabilities(
ruleset: RuleSet, extractor: StaticFeatureExtractor, f: FunctionHandle, bb: BBHandle
) -> tuple[FeatureSet, MatchResults, MatchResults]:
) -> BasicBlockCapabilities:
"""
find matches for the given rules within the given basic block.
returns: tuple containing (features for basic block, match results for basic block, match results for instructions)
"""
# all features found within this basic block,
# includes features found within instructions.
Expand All @@ -68,11 +78,11 @@ def find_basic_block_capabilities(
insn_matches: MatchResults = collections.defaultdict(list)

for insn in extractor.get_instructions(f, bb):
ifeatures, imatches = find_instruction_capabilities(ruleset, extractor, f, bb, insn)
for feature, vas in ifeatures.items():
instruction_capabilities = find_instruction_capabilities(ruleset, extractor, f, bb, insn)
for feature, vas in instruction_capabilities.features.items():
features[feature].update(vas)

for rule_name, res in imatches.items():
for rule_name, res in instruction_capabilities.matches.items():
insn_matches[rule_name].extend(res)

for feature, va in itertools.chain(
Expand All @@ -88,16 +98,20 @@ def find_basic_block_capabilities(
for va, _ in res:
capa.engine.index_rule_matches(features, rule, [va])

return features, matches, insn_matches
return BasicBlockCapabilities(features, matches, insn_matches)


def find_code_capabilities(
ruleset: RuleSet, extractor: StaticFeatureExtractor, fh: FunctionHandle
) -> tuple[MatchResults, MatchResults, MatchResults, int]:
@dataclass
class CodeCapabilities:
    """
    result of matching rules within a single function,
    as returned by `find_code_capabilities`.

    fields are constructed positionally, so their order is part of the interface.
    """

    # match results for the function itself, keyed by rule name.
    function_matches: MatchResults
    # match results for basic blocks, collected across every basic block in the function.
    basic_block_matches: MatchResults
    # match results for instructions, collected across every basic block in the function.
    instruction_matches: MatchResults
    # number of distinct features in the function feature set
    # (which includes basic block, instruction, and global features).
    feature_count: int


def find_code_capabilities(ruleset: RuleSet, extractor: StaticFeatureExtractor, fh: FunctionHandle) -> CodeCapabilities:
"""
find matches for the given rules within the given function.
returns: tuple containing (match results for function, match results for basic blocks, match results for instructions, number of features)
"""
# all features found within this function,
# includes features found within basic blocks (and instructions).
Expand All @@ -112,21 +126,21 @@ def find_code_capabilities(
insn_matches: MatchResults = collections.defaultdict(list)

for bb in extractor.get_basic_blocks(fh):
features, bmatches, imatches = find_basic_block_capabilities(ruleset, extractor, fh, bb)
for feature, vas in features.items():
basic_block_capabilities = find_basic_block_capabilities(ruleset, extractor, fh, bb)
for feature, vas in basic_block_capabilities.features.items():
function_features[feature].update(vas)

for rule_name, res in bmatches.items():
for rule_name, res in basic_block_capabilities.basic_block_matches.items():
bb_matches[rule_name].extend(res)

for rule_name, res in imatches.items():
for rule_name, res in basic_block_capabilities.instruction_matches.items():
insn_matches[rule_name].extend(res)

for feature, va in itertools.chain(extractor.extract_function_features(fh), extractor.extract_global_features()):
function_features[feature].add(va)

_, function_matches = ruleset.match(Scope.FUNCTION, function_features, fh.address)
return function_matches, bb_matches, insn_matches, len(function_features)
return CodeCapabilities(function_matches, bb_matches, insn_matches, len(function_features))


def find_static_capabilities(
Expand Down Expand Up @@ -165,30 +179,36 @@ def find_static_capabilities(
pbar.advance(task)
continue

function_matches, bb_matches, insn_matches, feature_count = find_code_capabilities(ruleset, extractor, f)
code_capabilities = find_code_capabilities(ruleset, extractor, f)
feature_counts.functions += (
rdoc.FunctionFeatureCount(address=frz.Address.from_capa(f.address), count=feature_count),
rdoc.FunctionFeatureCount(
address=frz.Address.from_capa(f.address), count=code_capabilities.feature_count
),
)
t1 = time.time()

match_count = 0
for name, matches_ in itertools.chain(function_matches.items(), bb_matches.items(), insn_matches.items()):
for name, matches_ in itertools.chain(
code_capabilities.function_matches.items(),
code_capabilities.basic_block_matches.items(),
code_capabilities.instruction_matches.items(),
):
if not ruleset.rules[name].is_subscope_rule():
match_count += len(matches_)

logger.debug(
"analyzed function 0x%x and extracted %d features, %d matches in %0.02fs",
f.address,
feature_count,
match_count,
t1 - t0,
)
# logger.debug(
# "analyzed function 0x%x and extracted %d features, %d matches in %0.02fs",
# f.address,
# feature_count,
# match_count,
# t1 - t0,
# )

for rule_name, res in function_matches.items():
for rule_name, res in code_capabilities.function_matches.items():
all_function_matches[rule_name].extend(res)
for rule_name, res in bb_matches.items():
for rule_name, res in code_capabilities.basic_block_matches.items():
all_bb_matches[rule_name].extend(res)
for rule_name, res in insn_matches.items():
for rule_name, res in code_capabilities.instruction_matches.items():
all_insn_matches[rule_name].extend(res)

pbar.advance(task)
Expand All @@ -203,8 +223,10 @@ def find_static_capabilities(
rule = ruleset[rule_name]
capa.engine.index_rule_matches(function_and_lower_features, rule, locations)

all_file_matches, feature_count = find_file_capabilities(ruleset, extractor, function_and_lower_features)
feature_counts.file = feature_count
all_file_matches, code_capabilities.feature_count = find_file_capabilities(
ruleset, extractor, function_and_lower_features
)
feature_counts.file = code_capabilities.feature_count

matches: MatchResults = dict(
itertools.chain(
Expand Down
1 change: 0 additions & 1 deletion tests/test_dynamic_sequence_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,3 @@ def test_dynamic_sequence_multiple_sequences_overlapping_single_event():
matches, features = capa.capabilities.dynamic.find_dynamic_capabilities(ruleset, extractor, disable_progress=True)
assert r.name in matches
assert [11, 12, 13, 14, 15] == list(get_call_ids(matches[r.name]))

0 comments on commit 8831a76

Please sign in to comment.