From e9f153bdc7ef79aa98a9ebd4edb60410b702fb3d Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 3 Dec 2021 14:22:34 -0500 Subject: [PATCH 01/18] Bugfix: Expand the entire set before checking for existence --- polyfile/iterators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/polyfile/iterators.py b/polyfile/iterators.py index 0bee3f68..bd77bded 100644 --- a/polyfile/iterators.py +++ b/polyfile/iterators.py @@ -66,4 +66,5 @@ def __init__(self, source: Iterable[T]): super().__init__(unique(iter(source), elements=self._set)) def __contains__(self, x: object) -> bool: + self._complete() return x in self._set From 7378fa35dead86f9971afe1083bce4381f0edc4f Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 3 Dec 2021 14:23:17 -0500 Subject: [PATCH 02/18] Optimization: Only expand the entire set if the match doesn't exist yet --- polyfile/iterators.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/polyfile/iterators.py b/polyfile/iterators.py index bd77bded..9dac45e0 100644 --- a/polyfile/iterators.py +++ b/polyfile/iterators.py @@ -66,5 +66,7 @@ def __init__(self, source: Iterable[T]): super().__init__(unique(iter(source), elements=self._set)) def __contains__(self, x: object) -> bool: + if x in self._set: + return True self._complete() return x in self._set From e0e04a729696783827707311432b8ad6832a7549 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 3 Dec 2021 15:48:49 -0500 Subject: [PATCH 03/18] Started writing a libmagic DSL debugger --- polyfile/__main__.py | 8 +- polyfile/magic.py | 13 +- polyfile/magic_debugger.py | 245 +++++++++++++++++++++++++++++++++++++ 3 files changed, 262 insertions(+), 4 deletions(-) create mode 100644 polyfile/magic_debugger.py diff --git a/polyfile/__main__.py b/polyfile/__main__.py index 41e80455..171be146 100644 --- a/polyfile/__main__.py +++ b/polyfile/__main__.py @@ -1,4 +1,5 @@ import argparse +from contextlib import ExitStack import base64 import hashlib import json @@ -14,6 +15,7 @@ from . import polyfile from .fileutils import PathOrStdin from .magic import MagicMatcher, MatchContext +from .magic_debugger import Debugger from .polyfile import __version__ @@ -54,6 +56,8 @@ def main(argv=None): help='stop scanning after having found this many matches') parser.add_argument('--debug', '-d', action='store_true', help='print debug information') parser.add_argument('--trace', '-dd', action='store_true', help='print extra verbose debug information') + parser.add_argument('--debugger', '-db', action='store_true', help='drop into an interactive debugger for libmagic ' + 'file definition matching') parser.add_argument('--quiet', '-q', action='store_true', help='suppress all log output (overrides --debug)') parser.add_argument('--version', '-v', action='store_true', help='print PolyFile\'s version information to STDERR') parser.add_argument('-dumpversion', action='store_true', @@ -110,7 +114,9 @@ def main(argv=None): exit(1) return # this is here because linters are dumb and will complain about the next line without it - with path_or_stdin as file_path: + with path_or_stdin as file_path, ExitStack() as stack: + if args.debugger: + stack.enter_context(Debugger()) matches = [] try: if args.only_match_mime: diff --git a/polyfile/magic.py b/polyfile/magic.py index f00ed8d8..2fe5775f 100644 --- a/polyfile/magic.py +++ b/polyfile/magic.py @@ -21,7 +21,7 @@ import sys from time import gmtime, localtime, strftime from typing import ( - Any, BinaryIO, Callable, Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union + Any, BinaryIO, Callable, Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union ) from uuid import UUID @@ -166,7 +166,7 @@ def __repr__(self): def __str__(self): if self.test.message is not None: # TODO: Fix pasting our value in - return self.test.message + return str(self.test.message) #if self.value is not None and "%" in self.test.message: # return self.test.message % (self.value,) #else: @@ -543,6 +543,9 @@ def __str__(self): return f"${{x?{self.true_value}:{self.false_value}}}" +TEST_TYPES: Set[Type["MagicTest"]] = set() + + class MagicTest(ABC): def __init__( self, @@ -589,6 +592,10 @@ def __init__( self.mime = mime self.source_info: Optional[SourceInfo] = None + def __init_subclass__(cls, **kwargs): + TEST_TYPES.add(cls) + return super().__init_subclass__(**kwargs) + @property def parent(self) -> Optional["MagicTest"]: return self._parent @@ -718,7 +725,7 @@ def match(self, to_match: Union[bytes, BinaryIO, str, Path, MatchContext]) -> It def __str__(self): if self.source_info is not None and self.source_info.original_line is not None: - s = f"{self.source_info.path.name}:{self.source_info.line} {self.source_info.original_line}" + s = f"{self.source_info.path.name}:{self.source_info.line} {self.source_info.original_line.strip()}" else: s = f"{'>' * self.level}{self.offset!s}\t{self.message}" if self.mime is not None: diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py new file mode 100644 index 00000000..97f1f728 --- /dev/null +++ b/polyfile/magic_debugger.py @@ -0,0 +1,245 @@ +from abc import ABC, abstractmethod +from typing import Callable, List, Optional, Type, TypeVar + +from .magic import MagicTest, TestResult, TEST_TYPES + + +B = TypeVar("B", bound="Breakpoint") + + +BREAKPOINT_TYPES: List[Type["Breakpoint"]] = [] + + +class Breakpoint(ABC): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + BREAKPOINT_TYPES.append(cls) + + @abstractmethod + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + raise NotImplementedError() + + @classmethod + @abstractmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + raise NotImplementedError() + + @classmethod + @abstractmethod + def usage(cls) -> str: + raise NotImplementedError() + + @abstractmethod + def __str__(self): + raise NotImplementedError() + + +class MimeBreakpoint(Breakpoint): + def __init__(self, mimetype: str): + self.mimetype: str = mimetype + + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + return self.mimetype in test.mimetypes() + + @classmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + if command.lower().startswith("mime:"): + return MimeBreakpoint(command[len("mime:"):]) + return None + + @classmethod + def usage(cls) -> str: + return "`b MIME:MIMETYPE` to break when a test is capable of matching that mimetype. For example, " \ + "\"b MIME:application/pdf\"." + + def __str__(self): + return f"Breakpoint: Matching for MIME {self.mimetype}" + + +class InstrumentedTest: + def __init__(self, test: Type[MagicTest], debugger: "Debugger"): + self.test: Type[MagicTest] = test + self.debugger: Debugger = debugger + if "test" in test.__dict__: + self.original_test: Optional[Callable[[...], Optional[TestResult]]] = test.test + + def wrapper(test_instance, *args, **kwargs) -> Optional[TestResult]: + # if self.original_test is None: + # # this is a NOOP + # return self.test.test(test_instance, *args, **kwargs) + return self.debugger.debug(self, test_instance, *args, **kwargs) + + test.test = wrapper + else: + self.original_test = None + + @property + def enabled(self) -> bool: + return self.original_test is not None + + def uninstrument(self): + if self.original_test is not None and self.test.test is self: + # we are still assigned to the test function, so reset it + self.test.test = self.original_test + self.original_test = None + + +def print_context(data: bytes, offset: int): + pass + + +class Debugger: + def __init__(self): + self.instrumented_tests: List[InstrumentedTest] = [] + self.breakpoints: List[Breakpoint] = [] + self._entries: int = 0 + self.single_stepping: bool = True + self.last_command: Optional[str] = None + + @property + def enabled(self) -> bool: + return any(t.enabled for t in self.instrumented_tests) + + @enabled.setter + def enabled(self, is_enabled: bool): + # Uninstrument any existing instrumentation: + for t in self.instrumented_tests: + t.uninstrument() + self.instrumented_tests = [] + if is_enabled: + # Instrument all of the MagicTest.test functions: + for test in TEST_TYPES: + if "test" in test.__dict__: + # this class actually implements the test() function + self.instrumented_tests.append(InstrumentedTest(test, self)) + + def __enter__(self) -> "Debugger": + self._entries += 1 + if self._entries == 1: + self.enabled = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._entries -= 1 + if self._entries == 0: + self.enabled = False + + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + return self.single_stepping or any( + b.should_break(test, data, absolute_offset, parent_match) for b in self.breakpoints + ) + + def debug( + self, + instrumented_test: InstrumentedTest, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> Optional[TestResult]: + if "image/png" in test.mimetypes() or self.should_break(test, data, absolute_offset, parent_match): + self.repl(test, data, absolute_offset, parent_match) + if instrumented_test.original_test is None: + result = instrumented_test.test.test(test, data, absolute_offset, parent_match) + else: + result = instrumented_test.original_test(test, data, absolute_offset, parent_match) + return result + + def repl( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ): + for b in self.breakpoints: + if b.should_break(test, data, absolute_offset, parent_match): + print(str(b)) + print(test) + while True: + command = input("(polyfile) ") + if not command: + if self.last_command is None: + continue + command = self.last_command + command = command.lstrip() + space_index = command.find(" ") + if space_index > 0: + command, args = command[:space_index], command[space_index+1:].strip() + else: + args = "" + if "help".startswith(command): + print("TODO: Usage") + elif "continue".startswith(command) or "run".startswith(command): + self.single_stepping = False + self.last_command = command + return + elif "step".startswith(command) or "next".startswith(command): + self.single_stepping = True + self.last_command = command + return + elif "quit".startswith(command): + exit(0) + elif "breakpoint".startswith(command): + if args: + for b_type in BREAKPOINT_TYPES: + parsed = b_type.parse(args) + if parsed is not None: + print(str(parsed)) + self.breakpoints.append(parsed) + break + else: + print("Error: Invalid breakpoint pattern") + else: + if self.breakpoints: + for i, b in enumerate(self.breakpoints): + print(f"{i}: {b!s}") + else: + print("No breakpoints set.") + for b_type in BREAKPOINT_TYPES: + print(str(b_type.usage())) + elif "where".startswith(command) or "info stack".startswith(command) or "backtrace".startswith(command): + test_stack = [test] + while test_stack[-1].parent is not None: + test_stack.append(test_stack[-1].parent) + for i, t in enumerate(reversed(test_stack)): + if i == len(test_stack) - 1: + cmd = str(t).replace("\n", "\n ") + print(f"> {cmd}") + else: + print(f" {cmd}") + test_stack = list(reversed(test.children)) + descendants = [] + while test_stack: + descendant = test_stack.pop() + if descendant.can_match_mime: + descendants.append(descendant) + test_stack.extend(reversed(descendant.children)) + for t in descendants: + cmd = str(t).replace("\n", "\n ") + print(f" {cmd}") + print("") + print_context(data, absolute_offset) + else: + print(f"Undefined command: {command!r}. Try \"help\".") + self.last_command = None + continue + self.last_command = command From 386b48e37e658aa1cf0767741bb1bf34a377307b Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 3 Dec 2021 16:00:42 -0500 Subject: [PATCH 04/18] Print data context while stepping --- polyfile/magic_debugger.py | 42 ++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 97f1f728..79bd0dfd 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Type, TypeVar +from typing import Callable, List, Optional, Type, TypeVar, Union from .magic import MagicTest, TestResult, TEST_TYPES @@ -96,8 +96,35 @@ def uninstrument(self): self.original_test = None -def print_context(data: bytes, offset: int): - pass +def string_escape(data: Union[bytes, int]) -> str: + if not isinstance(data, int): + return "".join(string_escape(d) for d in data) + elif data == ord('\n'): + return "\\n" + elif data == ord('\t'): + return "\\t" + elif data == ord('\r'): + return "\\r" + elif data == 0: + return "\\0" + elif data == ord('\\'): + return "\\\\" + elif 32 <= data <= 126: + return chr(data) + else: + return f"\\x{data:02X}" + + +def print_context(data: bytes, offset: int, context_bytes: int = 32): + bytes_before = min(offset, context_bytes) + context_before = string_escape(data[:bytes_before]) + if 0 <= offset < len(data): + current_byte = string_escape(data[offset]) + else: + current_byte = "" + context_after = string_escape(data[offset + 1:offset + context_bytes]) + print(f"{context_before}{current_byte}{context_after}") + print(f"{' ' * len(context_before)}{'^' * len(current_byte)}{' ' * len(context_after)}") class Debugger: @@ -155,12 +182,17 @@ def debug( absolute_offset: int, parent_match: Optional[TestResult] ) -> Optional[TestResult]: - if "image/png" in test.mimetypes() or self.should_break(test, data, absolute_offset, parent_match): + if self.should_break(test, data, absolute_offset, parent_match): self.repl(test, data, absolute_offset, parent_match) if instrumented_test.original_test is None: result = instrumented_test.test.test(test, data, absolute_offset, parent_match) else: result = instrumented_test.original_test(test, data, absolute_offset, parent_match) + if self.single_stepping: + if result is None: + print("Test failed.\n") + else: + print("Test succeeded.\n") return result def repl( @@ -174,6 +206,8 @@ def repl( if b.should_break(test, data, absolute_offset, parent_match): print(str(b)) print(test) + print() + print_context(data, absolute_offset) while True: command = input("(polyfile) ") if not command: From 554ce06e4320b17cd6867eeab815839b39475702 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Fri, 3 Dec 2021 16:07:10 -0500 Subject: [PATCH 05/18] Adds the ability to delete breakpoints --- polyfile/magic_debugger.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 79bd0dfd..980243b1 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -209,7 +209,11 @@ def repl( print() print_context(data, absolute_offset) while True: - command = input("(polyfile) ") + try: + command = input("(polyfile) ") + except EOFError: + # the user pressed ^D to quit + exit(0) if not command: if self.last_command is None: continue @@ -232,6 +236,18 @@ def repl( return elif "quit".startswith(command): exit(0) + elif "delete".startswith(command): + if args: + try: + breakpoint_num = int(args) + except ValueError: + breakpoint_num = -1 + if not (0 <= breakpoint_num < len(self.breakpoints)): + print(f"Error: Invalid breakpoint \"{args}\"") + continue + b = self.breakpoints[breakpoint_num] + self.breakpoints = self.breakpoints[:breakpoint_num] + self.breakpoints[breakpoint_num + 1:] + print(f"Deleted {b!s}") elif "breakpoint".startswith(command): if args: for b_type in BREAKPOINT_TYPES: @@ -255,8 +271,8 @@ def repl( while test_stack[-1].parent is not None: test_stack.append(test_stack[-1].parent) for i, t in enumerate(reversed(test_stack)): + cmd = str(t).replace("\n", "\n ") if i == len(test_stack) - 1: - cmd = str(t).replace("\n", "\n ") print(f"> {cmd}") else: print(f" {cmd}") From b7830af2103efdc31ace2cb6efd460aee2eb0d1d Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Sun, 5 Dec 2021 14:55:00 -0500 Subject: [PATCH 06/18] Adds some ANSI color --- polyfile/magic_debugger.py | 114 ++++++++++++++++++++++++++++--------- 1 file changed, 86 insertions(+), 28 deletions(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 980243b1..0037dcd0 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -1,9 +1,25 @@ from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Type, TypeVar, Union +from enum import Enum +import sys +from typing import Any, Callable, List, Optional, Type, TypeVar, Union from .magic import MagicTest, TestResult, TEST_TYPES +class ANSIColor(Enum): + BLACK = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + MAGENTA = 35 + CYAN = 36 + WHITE = 37 + + def to_code(self) -> str: + return f"\u001b[{self.value}m" + + B = TypeVar("B", bound="Breakpoint") @@ -115,18 +131,6 @@ def string_escape(data: Union[bytes, int]) -> str: return f"\\x{data:02X}" -def print_context(data: bytes, offset: int, context_bytes: int = 32): - bytes_before = min(offset, context_bytes) - context_before = string_escape(data[:bytes_before]) - if 0 <= offset < len(data): - current_byte = string_escape(data[offset]) - else: - current_byte = "" - context_after = string_escape(data[offset + 1:offset + context_bytes]) - print(f"{context_before}{current_byte}{context_after}") - print(f"{' ' * len(context_before)}{'^' * len(current_byte)}{' ' * len(context_after)}") - - class Debugger: def __init__(self): self.instrumented_tests: List[InstrumentedTest] = [] @@ -174,6 +178,53 @@ def should_break( b.should_break(test, data, absolute_offset, parent_match) for b in self.breakpoints ) + def write_test(self, test: MagicTest): + if test.source_info is not None and test.source_info.original_line is not None: + self.write(f"{test.source_info.path.name}:{test.source_info.line} ", dim=True) + self.write(test.source_info.original_line.strip(), color=ANSIColor.BLUE) + else: + self.write(f"{'>' * test.level}{test.offset!s}\t") + self.write(test.message, color=ANSIColor.BLUE) + if test.mime is not None: + self.write("\n!:mime\t", dim=True) + self.write(test.mime, color=ANSIColor.BLUE) + for e in test.extensions: + self.write("\n!:ext\t", dim=True) + self.write(str(e), color=ANSIColor.BLUE) + + def write(self, message: Any, bold: bool = False, dim: bool = False, color: Optional[ANSIColor] = None): + if sys.stdout.isatty(): + if isinstance(message, MagicTest): + self.write_test(message) + return + prefixes: List[str] = [] + if bold and not dim: + prefixes.append("\u001b[1m") + elif dim and not bold: + prefixes.append("\u001b[2m") + if color is not None: + prefixes.append(color.to_code()) + if prefixes: + sys.stdout.write(f"{''.join(prefixes)}{message!s}\u001b[0m") + return + sys.stdout.write(str(message)) + + def print_context(self, data: bytes, offset: int, context_bytes: int = 32): + bytes_before = min(offset, context_bytes) + context_before = string_escape(data[:bytes_before]) + if 0 <= offset < len(data): + current_byte = string_escape(data[offset]) + else: + current_byte = "" + context_after = string_escape(data[offset + 1:offset + context_bytes]) + self.write(context_before) + self.write(current_byte, bold=True) + self.write(context_after) + self.write("\n") + self.write(f"{' ' * len(context_before)}") + self.write(f"{'^' * len(current_byte)}", bold=True) + self.write(f"{' ' * len(context_after)}\n") + def debug( self, instrumented_test: InstrumentedTest, @@ -190,9 +241,9 @@ def debug( result = instrumented_test.original_test(test, data, absolute_offset, parent_match) if self.single_stepping: if result is None: - print("Test failed.\n") + self.write("Test failed.\n\n", color=ANSIColor.RED) else: - print("Test succeeded.\n") + self.write("Test succeeded.\n\n", color=ANSIColor.GREEN) return result def repl( @@ -204,13 +255,16 @@ def repl( ): for b in self.breakpoints: if b.should_break(test, data, absolute_offset, parent_match): - print(str(b)) - print(test) - print() - print_context(data, absolute_offset) + self.write(b, color=ANSIColor.MAGENTA) + self.write("\n") + self.write(test) + self.write("\n\n") + self.print_context(data, absolute_offset) while True: try: - command = input("(polyfile) ") + self.write("(polyfile) ", bold=True) + sys.stdout.flush() + command = input() except EOFError: # the user pressed ^D to quit exit(0) @@ -247,25 +301,29 @@ def repl( continue b = self.breakpoints[breakpoint_num] self.breakpoints = self.breakpoints[:breakpoint_num] + self.breakpoints[breakpoint_num + 1:] - print(f"Deleted {b!s}") + self.write(f"Deleted {b!s}\n") elif "breakpoint".startswith(command): if args: for b_type in BREAKPOINT_TYPES: parsed = b_type.parse(args) if parsed is not None: - print(str(parsed)) + self.write(parsed, color=ANSIColor.MAGENTA) + self.write("\n") self.breakpoints.append(parsed) break else: - print("Error: Invalid breakpoint pattern") + self.write("Error: Invalid breakpoint pattern\n", color=ANSIColor.RED) else: if self.breakpoints: for i, b in enumerate(self.breakpoints): - print(f"{i}: {b!s}") + self.write(f"{i}:\t", dim=True) + self.write(b, color=ANSIColor.MAGENTA) + self.write("\n") else: - print("No breakpoints set.") + self.write("No breakpoints set.\n", color=ANSIColor.RED) for b_type in BREAKPOINT_TYPES: - print(str(b_type.usage())) + self.write(b_type.usage()) + self.write("\n") elif "where".startswith(command) or "info stack".startswith(command) or "backtrace".startswith(command): test_stack = [test] while test_stack[-1].parent is not None: @@ -287,9 +345,9 @@ def repl( cmd = str(t).replace("\n", "\n ") print(f" {cmd}") print("") - print_context(data, absolute_offset) + self.print_context(data, absolute_offset) else: - print(f"Undefined command: {command!r}. Try \"help\".") + self.write(f"Undefined command: {command!r}. Try \"help\".\n", color=ANSIColor.RED) self.last_command = None continue self.last_command = command From 9d23a38116c91d3e62beb7d535a3d22e6515b55a Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Sun, 5 Dec 2021 14:59:32 -0500 Subject: [PATCH 07/18] Suppress matching status before REPL --- polyfile/magic_debugger.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 0037dcd0..9a76adab 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -3,9 +3,13 @@ import sys from typing import Any, Callable, List, Optional, Type, TypeVar, Union +from .logger import getStatusLogger from .magic import MagicTest, TestResult, TEST_TYPES +log = getStatusLogger("polyfile") + + class ANSIColor(Enum): BLACK = 30 RED = 31 @@ -253,6 +257,7 @@ def repl( absolute_offset: int, parent_match: Optional[TestResult] ): + log.clear_status() for b in self.breakpoints: if b.should_break(test, data, absolute_offset, parent_match): self.write(b, color=ANSIColor.MAGENTA) From e585c2ff0d0d259ac556ef8fbd14c14c9c07dbae Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Sun, 5 Dec 2021 15:04:56 -0500 Subject: [PATCH 08/18] Added an extension matching breakpoint --- polyfile/magic_debugger.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 9a76adab..a2064318 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -88,6 +88,34 @@ def __str__(self): return f"Breakpoint: Matching for MIME {self.mimetype}" +class ExtensionBreakpoint(Breakpoint): + def __init__(self, ext: str): + self.ext: str = ext + + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + return self.ext in test.all_extensions() + + @classmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + if command.lower().startswith("ext:"): + return MimeBreakpoint(command[len("ext:"):]) + return None + + @classmethod + def usage(cls) -> str: + return "`b EXT:EXTENSION` to break when a test is capable of matching that extension. For example, " \ + "\"b EXT:pdf\"." + + def __str__(self): + return f"Breakpoint: Matching for extension {self.ext}" + + class InstrumentedTest: def __init__(self, test: Type[MagicTest], debugger: "Debugger"): self.test: Type[MagicTest] = test From b707f56631ca253b51e30efee497dd5d3f60c926 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Sun, 5 Dec 2021 15:06:24 -0500 Subject: [PATCH 09/18] Bugfix: return the correct type of breakpoint --- polyfile/magic_debugger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index a2064318..e9c16a9c 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -104,7 +104,7 @@ def should_break( @classmethod def parse(cls: Type[B], command: str) -> Optional[B]: if command.lower().startswith("ext:"): - return MimeBreakpoint(command[len("ext:"):]) + return ExtensionBreakpoint(command[len("ext:"):]) return None @classmethod From 203d9ad62c6de8a6edb770739ed7600a8223c069 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Sun, 5 Dec 2021 21:51:33 -0500 Subject: [PATCH 10/18] Adds breakpoints for line numbers --- polyfile/magic_debugger.py | 43 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index e9c16a9c..cedb4f37 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -116,6 +116,49 @@ def __str__(self): return f"Breakpoint: Matching for extension {self.ext}" +class FileBreakpoint(Breakpoint): + def __init__(self, filename: str, line: int): + self.filename: str = filename + self.line: int = line + + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + if test.source_info is None or test.source_info.line != self.line: + return False + if "/" in self.filename: + # it is a file path + return str(test.source_info.path) == self.filename + else: + # treat it like a filename + return test.source_info.path.name == self.filename + + @classmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + filename, *remainder = command.split(":") + if not remainder: + return None + try: + line = int("".join(remainder)) + except ValueError: + return None + if line <= 0: + return None + return FileBreakpoint(filename, line) + + @classmethod + def usage(cls) -> str: + return "`b FILENAME:LINE_NO` to break when the line of the given magic file is reached. For example, " \ + "\"b archive:525\"." + + def __str__(self): + return f"Breakpoint: {self.filename} line {self.line}" + + class InstrumentedTest: def __init__(self, test: Type[MagicTest], debugger: "Debugger"): self.test: Type[MagicTest] = test From f89df289d83baad5d47fbf50237aafa974f39b72 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Mon, 6 Dec 2021 10:48:34 -0500 Subject: [PATCH 11/18] Added wildcard support for MIME matching --- polyfile/magic_debugger.py | 9 ++++-- polyfile/wildcards.py | 62 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 polyfile/wildcards.py diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index cedb4f37..1eddd370 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -5,6 +5,7 @@ from .logger import getStatusLogger from .magic import MagicTest, TestResult, TEST_TYPES +from .wildcards import Wildcard log = getStatusLogger("polyfile") @@ -63,6 +64,7 @@ def __str__(self): class MimeBreakpoint(Breakpoint): def __init__(self, mimetype: str): self.mimetype: str = mimetype + self.pattern: Wildcard = Wildcard.parse(mimetype) def should_break( self, @@ -71,7 +73,7 @@ def should_break( absolute_offset: int, parent_match: Optional[TestResult] ) -> bool: - return self.mimetype in test.mimetypes() + return self.pattern.is_contained_in(test.mimetypes()) @classmethod def parse(cls: Type[B], command: str) -> Optional[B]: @@ -81,8 +83,9 @@ def parse(cls: Type[B], command: str) -> Optional[B]: @classmethod def usage(cls) -> str: - return "`b MIME:MIMETYPE` to break when a test is capable of matching that mimetype. For example, " \ - "\"b MIME:application/pdf\"." + return "`b MIME:MIMETYPE` to break when a test is capable of matching that mimetype. "\ + "The MIMETYPE can include the '*' and '?' wildcards. For example, " \ + "\"b MIME:application/pdf\" and \"b MIME:*pdf\"." def __str__(self): return f"Breakpoint: Matching for MIME {self.mimetype}" diff --git a/polyfile/wildcards.py b/polyfile/wildcards.py new file mode 100644 index 00000000..91f411e3 --- /dev/null +++ b/polyfile/wildcards.py @@ -0,0 +1,62 @@ +from abc import ABC, abstractmethod +from collections.abc import Container +import re +from typing import Iterable, List + + +class Wildcard(ABC): + @abstractmethod + def match(self, to_test: str) -> bool: + raise NotImplementedError() + + def is_contained_in(self, items: Iterable[str]) -> bool: + for item in items: + if self.match(item): + return True + return False + + @staticmethod + def parse(to_test: str) -> "Wildcard": + if "*" in to_test or "?" in to_test: + return SimpleWildcard(to_test) + else: + return ConstantMatch(to_test) + + +class ConstantMatch(Wildcard): + def __init__(self, to_match: str): + self.to_match: str = to_match + + def match(self, to_test: str) -> bool: + return self.to_match == to_test + + def is_contained_in(self, items: Iterable[str]) -> bool: + if isinstance(items, Container): + return self.to_match in items + return super().is_contained_in(items) + + +class SimpleWildcard(Wildcard): + def __init__(self, pattern: str): + self.raw_pattern: str = pattern + self.pattern = re.compile(self.escaped_pattern) + + @property + def escaped_pattern(self) -> str: + components: List[str] = [] + component: str = "" + for last_c, c in zip((None,) + tuple(self.raw_pattern), self.raw_pattern): + escaped = last_c is not None and last_c == "\\" + if not escaped and (c == "*" or c == "?"): + if component: + components.append(re.escape(component)) + component = "" + components.append(f".{c}") + else: + component = f"{component}{c}" + if component: + components.append(re.escape(component)) + return "".join(components) + + def match(self, to_test: str) -> bool: + return bool(self.pattern.match(to_test)) From ead7ef8b1a385736f2ee4dd56c2ab4e8a4b401a0 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Tue, 7 Dec 2021 09:57:56 -0500 Subject: [PATCH 12/18] Improved output and coloring for `where` --- polyfile/magic_debugger.py | 170 ++++++++++++++++++++++++------------- polyfile/polyfile.py | 4 + 2 files changed, 113 insertions(+), 61 deletions(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 1eddd370..befe2a67 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -3,6 +3,7 @@ import sys from typing import Any, Callable, List, Optional, Type, TypeVar, Union +from .polyfile import __copyright__, __license__, __version__ from .logger import getStatusLogger from .magic import MagicTest, TestResult, TEST_TYPES from .wildcards import Wildcard @@ -209,13 +210,24 @@ def string_escape(data: Union[bytes, int]) -> str: return f"\\x{data:02X}" +class StepMode(Enum): + RUNNING = 0 + SINGLE_STEPPING = 1 + NEXT = 2 + + class Debugger: def __init__(self): self.instrumented_tests: List[InstrumentedTest] = [] self.breakpoints: List[Breakpoint] = [] self._entries: int = 0 - self.single_stepping: bool = True + self.step_mode: StepMode = StepMode.RUNNING self.last_command: Optional[str] = None + self.last_test: Optional[MagicTest] = None + self.last_parent_match: Optional[MagicTest] = None + self.data: bytes = b"" + self.last_offset: int = 0 + self.last_result: Optional[TestResult] = None @property def enabled(self) -> bool: @@ -233,6 +245,9 @@ def enabled(self, is_enabled: bool): if "test" in test.__dict__: # this class actually implements the test() function self.instrumented_tests.append(InstrumentedTest(test, self)) + self.write(f"PolyFile {__version__}\n", color=ANSIColor.MAGENTA, bold=True) + self.write(f"{__copyright__}\n{__license__}\n\nFor help, type \"help\".\n") + self.repl() def __enter__(self) -> "Debugger": self._entries += 1 @@ -245,18 +260,19 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self._entries == 0: self.enabled = False - def should_break( - self, - test: MagicTest, - data: bytes, - absolute_offset: int, - parent_match: Optional[TestResult] - ) -> bool: - return self.single_stepping or any( - b.should_break(test, data, absolute_offset, parent_match) for b in self.breakpoints + def should_break(self) -> bool: + return self.step_mode == StepMode.SINGLE_STEPPING or ( + self.step_mode == StepMode.NEXT and self.last_result is not None + ) or any( + b.should_break(self.last_test, self.data, self.last_offset, self.last_parent_match) + for b in self.breakpoints ) - def write_test(self, test: MagicTest): + def write_test(self, test: MagicTest, is_current_test: bool = False): + if is_current_test: + self.write("→ ", bold=True) + else: + self.write(" ") if test.source_info is not None and test.source_info.original_line is not None: self.write(f"{test.source_info.path.name}:{test.source_info.line} ", dim=True) self.write(test.source_info.original_line.strip(), color=ANSIColor.BLUE) @@ -264,11 +280,12 @@ def write_test(self, test: MagicTest): self.write(f"{'>' * test.level}{test.offset!s}\t") self.write(test.message, color=ANSIColor.BLUE) if test.mime is not None: - self.write("\n!:mime\t", dim=True) + self.write("\n !:mime ", dim=True) self.write(test.mime, color=ANSIColor.BLUE) for e in test.extensions: - self.write("\n!:ext\t", dim=True) + self.write("\n !:ext ", dim=True) self.write(str(e), color=ANSIColor.BLUE) + self.write("\n") def write(self, message: Any, bold: bool = False, dim: bool = False, color: Optional[ANSIColor] = None): if sys.stdout.isatty(): @@ -311,34 +328,58 @@ def debug( absolute_offset: int, parent_match: Optional[TestResult] ) -> Optional[TestResult]: - if self.should_break(test, data, absolute_offset, parent_match): - self.repl(test, data, absolute_offset, parent_match) + self.last_test = test + self.data = data + self.last_offset = absolute_offset + self.last_parent_match = parent_match if instrumented_test.original_test is None: - result = instrumented_test.test.test(test, data, absolute_offset, parent_match) + self.last_result = instrumented_test.test.test(test, data, absolute_offset, parent_match) else: - result = instrumented_test.original_test(test, data, absolute_offset, parent_match) - if self.single_stepping: - if result is None: - self.write("Test failed.\n\n", color=ANSIColor.RED) - else: - self.write("Test succeeded.\n\n", color=ANSIColor.GREEN) - return result - - def repl( - self, - test: MagicTest, - data: bytes, - absolute_offset: int, - parent_match: Optional[TestResult] - ): - log.clear_status() + self.last_result = instrumented_test.original_test(test, data, absolute_offset, parent_match) + if self.should_break(): + self.repl() + return self.last_result + + def print_where(self): + if self.last_test is None: + self.write("The first test has not yet been run.\n", color=ANSIColor.RED) + self.write("Use `step`, `next`, or `run` to start testing.\n") + return + wrote_breakpoints = False for b in self.breakpoints: - if b.should_break(test, data, absolute_offset, parent_match): + if b.should_break(self.last_test, self.data, self.last_offset, self.last_parent_match): self.write(b, color=ANSIColor.MAGENTA) self.write("\n") - self.write(test) - self.write("\n\n") - self.print_context(data, absolute_offset) + wrote_breakpoints = True + if wrote_breakpoints: + self.write("\n") + test_stack = [self.last_test] + while test_stack[-1].parent is not None: + test_stack.append(test_stack[-1].parent) + for i, t in enumerate(reversed(test_stack)): + if i == len(test_stack) - 1: + self.write_test(t, is_current_test=True) + else: + self.write_test(t) + test_stack = list(reversed(self.last_test.children)) + descendants = [] + while test_stack: + descendant = test_stack.pop() + if descendant.can_match_mime: + descendants.append(descendant) + test_stack.extend(reversed(descendant.children)) + for t in descendants: + self.write_test(t) + self.write("\n") + self.print_context(self.data, self.last_offset) + if self.last_result is None: + self.write("Test failed.\n", color=ANSIColor.RED) + else: + self.write("Test succeeded.\n", color=ANSIColor.GREEN) + + def repl(self): + log.clear_status() + self.print_where() while True: try: self.write("(polyfile) ", bold=True) @@ -358,13 +399,40 @@ def repl( else: args = "" if "help".startswith(command): - print("TODO: Usage") + usage = [ + ("help", "print this message"), + ("continue", "continue execution until the next breakpoint is hit"), + ("step", "step through a single magic test"), + ("next", "continue execution until the next test that matches"), + ("where", "print the context of the current magic test"), + ("breakpoint", "list the current breakpoints or add a new one"), + ("delete", "delete a breakpoint"), + ("quit", "exit the debugger"), + ] + aliases = { + "step": ("next",), + "where": ("info stack", "backtrace") + } + left_col_width = max(len(u[0]) for u in usage) + left_col_width = max(left_col_width, max(len(c) for a in aliases.values() for c in a)) + left_col_width += 3 + for command, msg in usage: + self.write(command, bold=True, color=ANSIColor.BLUE) + self.write(f" {'.' * (left_col_width - len(command) - 2)} ") + self.write(msg) + self.write("\n") + self.write("\nAliases:\n", dim=True) + elif "continue".startswith(command) or "run".startswith(command): - self.single_stepping = False + self.step_mode = StepMode.RUNNING self.last_command = command return - elif "step".startswith(command) or "next".startswith(command): - self.single_stepping = True + elif "step".startswith(command): + self.step_mode = StepMode.SINGLE_STEPPING + self.last_command = command + return + elif "next".startswith(command): + self.step_mode = StepMode.NEXT self.last_command = command return elif "quit".startswith(command): @@ -404,27 +472,7 @@ def repl( self.write(b_type.usage()) self.write("\n") elif "where".startswith(command) or "info stack".startswith(command) or "backtrace".startswith(command): - test_stack = [test] - while test_stack[-1].parent is not None: - test_stack.append(test_stack[-1].parent) - for i, t in enumerate(reversed(test_stack)): - cmd = str(t).replace("\n", "\n ") - if i == len(test_stack) - 1: - print(f"> {cmd}") - else: - print(f" {cmd}") - test_stack = list(reversed(test.children)) - descendants = [] - while test_stack: - descendant = test_stack.pop() - if descendant.can_match_mime: - descendants.append(descendant) - test_stack.extend(reversed(descendant.children)) - for t in descendants: - cmd = str(t).replace("\n", "\n ") - print(f" {cmd}") - print("") - self.print_context(data, absolute_offset) + self.print_where() else: self.write(f"Undefined command: {command!r}. Try \"help\".\n", color=ANSIColor.RED) self.last_command = None diff --git a/polyfile/polyfile.py b/polyfile/polyfile.py index 64c1137a..226cf69e 100644 --- a/polyfile/polyfile.py +++ b/polyfile/polyfile.py @@ -2,6 +2,7 @@ from json import dumps from pathlib import Path import pkg_resources +from time import localtime from typing import Any, Dict, IO, Iterator, List, Optional, Set, Tuple, Type, Union from .fileutils import FileStream @@ -9,6 +10,9 @@ from .magic import MagicMatcher, MatchContext __version__: str = pkg_resources.require("polyfile")[0].version +mod_year = localtime(Path(__file__).stat().st_mtime).tm_year +__copyright__: str = f"Copyright ©{mod_year} Trail of Bits" +__license__: str = "Apache License Version 2.0 https://www.apache.org/licenses/" CUSTOM_MATCHERS: Dict[str, Type["Match"]] = {} From 5af5d7b76b5bf9321d0133280421fa65010024f3 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Tue, 7 Dec 2021 10:05:51 -0500 Subject: [PATCH 13/18] Improved debugger usage --- polyfile/magic_debugger.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index befe2a67..c38fa699 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -379,7 +379,8 @@ def print_where(self): def repl(self): log.clear_status() - self.print_where() + if self.last_test is not None: + self.print_where() while True: try: self.write("(polyfile) ", bold=True) @@ -410,7 +411,6 @@ def repl(self): ("quit", "exit the debugger"), ] aliases = { - "step": ("next",), "where": ("info stack", "backtrace") } left_col_width = max(len(u[0]) for u in usage) @@ -420,8 +420,17 @@ def repl(self): self.write(command, bold=True, color=ANSIColor.BLUE) self.write(f" {'.' * (left_col_width - len(command) - 2)} ") self.write(msg) + if command in aliases: + self.write(" (aliases: ", dim=True) + alternatives = aliases[command] + for i, alt in enumerate(alternatives): + if i > 0 and len(alternatives) > 2: + self.write(", ", dim=True) + if i == len(alternatives) - 1 and len(alternatives) > 1: + self.write(" and ", dim=True) + self.write(alt, bold=True, color=ANSIColor.BLUE) + self.write(")", dim=True) self.write("\n") - self.write("\nAliases:\n", dim=True) elif "continue".startswith(command) or "run".startswith(command): self.step_mode = StepMode.RUNNING From d42f562eb1c1b5bb51aeca5fc44cffd010d76f10 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Tue, 7 Dec 2021 10:42:16 -0500 Subject: [PATCH 14/18] Include source code comments in the magic tests --- polyfile/magic.py | 31 ++++++++++++++++++++++++++++--- polyfile/magic_debugger.py | 17 ++++++++++++++--- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/polyfile/magic.py b/polyfile/magic.py index 2fe5775f..89e798a6 100644 --- a/polyfile/magic.py +++ b/polyfile/magic.py @@ -546,6 +546,15 @@ def __str__(self): TEST_TYPES: Set[Type["MagicTest"]] = set() +class Comment: + def __init__(self, message: str, source_info: Optional[SourceInfo] = None): + self.message: str = message + self.source_info: Optional[SourceInfo] = source_info + + def __str__(self): + return self.message + + class MagicTest(ABC): def __init__( self, @@ -553,7 +562,8 @@ def __init__( mime: Optional[Union[str, TernaryExecutableMessage]] = None, extensions: Iterable[str] = (), message: Union[str, Message] = "", - parent: Optional["MagicTest"] = None + parent: Optional["MagicTest"] = None, + comments: Iterable[Comment] = () ): self.offset: Offset = offset self._mime: Optional[Message] = None @@ -591,6 +601,7 @@ def __init__( """ self.mime = mime self.source_info: Optional[SourceInfo] = None + self.comments: Tuple[Comment, ...] = tuple(comments) def __init_subclass__(cls, **kwargs): TEST_TYPES.add(cls) @@ -2066,12 +2077,24 @@ def _parse_file( level_zero_tests: List[MagicTest] = [] tests_with_mime: Set[MagicTest] = set() indirect_tests: Set[IndirectTest] = set() + comments: List[Comment] = [] with open(def_file, "rb") as f: for line_number, raw_line in enumerate(f.readlines()): line_number += 1 raw_line = raw_line.lstrip() - if not raw_line or raw_line.startswith(b"#"): - # skip empty lines and comments + if not raw_line: + # skip empty lines + comments = [] + continue + elif raw_line.startswith(b"#"): + # this is a comment + try: + comments.append(Comment( + message=raw_line[1:].strip().decode("utf-8"), + source_info=SourceInfo(def_file, line_number, raw_line.decode("utf-8")) + )) + except UnicodeDecodeError: + pass continue elif raw_line.startswith(b"!:apple") or raw_line.startswith(b"!:strength"): # ignore these directives for now @@ -2180,6 +2203,8 @@ def __init__(self): if test.level == 0: level_zero_tests.append(test) test.source_info = SourceInfo(def_file, line_number, line) + test.comments = tuple(comments) + comments = [] current_test = test continue m = MIME_PATTERN.match(line) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index c38fa699..093715f4 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -269,21 +269,32 @@ def should_break(self) -> bool: ) def write_test(self, test: MagicTest, is_current_test: bool = False): + for comment in test.comments: + if comment.source_info is not None and comment.source_info.original_line is not None: + self.write(f" {comment.source_info.path.name}:{comment.source_info.line}\t", dim=True) + self.write(comment.source_info.original_line.strip(), dim=True) + self.write("\n") + else: + self.write(f" # {comment!s}\n", dim=True) if is_current_test: self.write("→ ", bold=True) else: self.write(" ") if test.source_info is not None and test.source_info.original_line is not None: - self.write(f"{test.source_info.path.name}:{test.source_info.line} ", dim=True) + source_prefix = f"{test.source_info.path.name}:{test.source_info.line}" + indent = f"{' ' * len(source_prefix)}\t" + self.write(source_prefix, dim=True) + self.write("\t") self.write(test.source_info.original_line.strip(), color=ANSIColor.BLUE) else: + indent = "" self.write(f"{'>' * test.level}{test.offset!s}\t") self.write(test.message, color=ANSIColor.BLUE) if test.mime is not None: - self.write("\n !:mime ", dim=True) + self.write(f"\n {indent}!:mime ", dim=True) self.write(test.mime, color=ANSIColor.BLUE) for e in test.extensions: - self.write("\n !:ext ", dim=True) + self.write(f"\n {indent}!:ext ", dim=True) self.write(str(e), color=ANSIColor.BLUE) self.write("\n") From 680a6bf3649cf38a0f8d139b0b41c22426c8f9a4 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Tue, 7 Dec 2021 10:51:22 -0500 Subject: [PATCH 15/18] Nicer text highlighting --- polyfile/magic_debugger.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 093715f4..07f83a8d 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -5,7 +5,7 @@ from .polyfile import __copyright__, __license__, __version__ from .logger import getStatusLogger -from .magic import MagicTest, TestResult, TEST_TYPES +from .magic import MagicTest, SourceInfo, TestResult, TEST_TYPES from .wildcards import Wildcard @@ -271,7 +271,9 @@ def should_break(self) -> bool: def write_test(self, test: MagicTest, is_current_test: bool = False): for comment in test.comments: if comment.source_info is not None and comment.source_info.original_line is not None: - self.write(f" {comment.source_info.path.name}:{comment.source_info.line}\t", dim=True) + self.write(f" {comment.source_info.path.name}", dim=True, color=ANSIColor.CYAN) + self.write(":", dim=True) + self.write(f"{comment.source_info.line}\t", dim=True, color=ANSIColor.CYAN) self.write(comment.source_info.original_line.strip(), dim=True) self.write("\n") else: @@ -283,13 +285,15 @@ def write_test(self, test: MagicTest, is_current_test: bool = False): if test.source_info is not None and test.source_info.original_line is not None: source_prefix = f"{test.source_info.path.name}:{test.source_info.line}" indent = f"{' ' * len(source_prefix)}\t" - self.write(source_prefix, dim=True) + self.write(test.source_info.path.name, dim=True, color=ANSIColor.CYAN) + self.write(":", dim=True) + self.write(test.source_info.line, dim=True, color=ANSIColor.CYAN) self.write("\t") - self.write(test.source_info.original_line.strip(), color=ANSIColor.BLUE) + self.write(test.source_info.original_line.strip(), color=ANSIColor.BLUE, bold=True) else: indent = "" self.write(f"{'>' * test.level}{test.offset!s}\t") - self.write(test.message, color=ANSIColor.BLUE) + self.write(test.message, color=ANSIColor.BLUE, bold=True) if test.mime is not None: self.write(f"\n {indent}!:mime ", dim=True) self.write(test.mime, color=ANSIColor.BLUE) From 49857dcf1018b5c465f1e0c9a82bf671e0a146d7 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Tue, 7 Dec 2021 12:04:37 -0500 Subject: [PATCH 16/18] Propagate test failure messages to the debugger --- polyfile/magic.py | 185 ++++++++++++++++++++++++++----------- polyfile/magic_debugger.py | 16 ++-- 2 files changed, 143 insertions(+), 58 deletions(-) diff --git a/polyfile/magic.py b/polyfile/magic.py index 89e798a6..67dbed80 100644 --- a/polyfile/magic.py +++ b/polyfile/magic.py @@ -120,20 +120,12 @@ def unescape(to_unescape: Union[str, bytes]) -> bytes: return bytes(b) -class TestResult: - def __init__( - self, test: "MagicTest", - value: Any, - offset: int, - length: int, - parent: Optional["TestResult"] = None - ): +class TestResult(ABC): + def __init__(self, test: "MagicTest", offset: int, parent: Optional["TestResult"] = None): self.test: MagicTest = test - self.value: Any = value self.offset: int = offset - self.length: int = length - self.parent: Optional[TestResult] = parent - if parent is not None: + self.parent: Optional["TestResult"] = parent + if parent is not None and bool(self): assert self.test.named_test is self.test or parent.test.level == self.test.level - 1 if not isinstance(self.test, UseTest): parent.child_matched = True @@ -152,13 +144,53 @@ def child_matched(self, did_match: bool): self.parent.parent.child_matched = True self._child_matched = did_match + def __hash__(self): + return hash((self.test, self.offset)) + + def __eq__(self, other): + return isinstance(other, TestResult) and other.test == self.test and other.offset == self.offset + + @abstractmethod + def __bool__(self): + raise NotImplementedError() + + def __repr__(self): + return f"{self.__class__.__name__}(test={self.test!r}, offset={self.offset}, parent={self.parent!r})" + + def __str__(self): + if self.test.message is not None: + # TODO: Fix pasting our value in + return str(self.test.message) + #if self.value is not None and "%" in self.test.message: + # return self.test.message % (self.value,) + #else: + # return self.test.message + else: + return f"Match[{self.offset}]" + + +class MatchedTest(TestResult): + def __init__( + self, test: "MagicTest", + value: Any, + offset: int, + length: int, + parent: Optional["TestResult"] = None + ): + super().__init__(test=test, offset=offset, parent=parent) + self.value: Any = value + self.length: int = length + def __hash__(self): return hash((self.test, self.offset, self.length)) def __eq__(self, other): - return isinstance(other, TestResult) and other.test == self.test and other.offset == self.offset \ + return isinstance(other, MatchedTest) and other.test == self.test and other.offset == self.offset \ and other.length == self.length + def __bool__(self): + return True + def __repr__(self): return f"{self.__class__.__name__}(test={self.test!r}, offset={self.offset}, length={self.length}, " \ f"parent={self.parent!r})" @@ -175,6 +207,15 @@ def __str__(self): return f"Match[{self.offset}:{self.offset + self.length}]" +class FailedTest(TestResult): + def __init__(self, test: "MagicTest", offset: int, message: str, parent: Optional["TestResult"] = None): + super().__init__(test=test, offset=offset, parent=parent) + self.message: str = message + + def __bool__(self): + return False + + class Endianness(Enum): NATIVE = "=" LITTLE = "<" @@ -703,23 +744,23 @@ def all_extensions(self) -> LazyIterableSet[str]: return LazyIterableSet(self._all_extensions()) @abstractmethod - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: raise NotImplementedError() - def _match(self, context: MatchContext, parent_match: Optional[TestResult] = None) -> Iterator[TestResult]: + def _match(self, context: MatchContext, parent_match: Optional[TestResult] = None) -> Iterator[MatchedTest]: if context.only_match_mime and not self.can_match_mime: return try: absolute_offset = self.offset.to_absolute(context.data, parent_match) except InvalidOffsetError: - return None + return m = self.test(context.data, absolute_offset, parent_match) - if logging.root.level <= TRACE and (m is not None or self.level > 0): + if logging.root.level <= TRACE and (bool(m) or self.level > 0): log.trace( - f"{self.source_info!s}\t{m is not None}\t{absolute_offset}\t" + f"{self.source_info!s}\t{bool(m)}\t{absolute_offset}\t" f"{context.data[absolute_offset:absolute_offset + 20]!r}" ) - if m is not None: + if bool(m): if not context.only_match_mime or self.mime is not None: yield m for child in self.children: @@ -956,6 +997,9 @@ def matches(self, data: bytes) -> DataTypeMatch: else: return self.post_process(data) + def __str__(self): + return "null-terminated string" + class NegatedStringTest(StringWildcard): def __init__(self, parent_test: StringTest): @@ -969,6 +1013,9 @@ def matches(self, data: bytes) -> DataTypeMatch: else: return DataTypeMatch.INVALID + def __str__(self): + return f"something other than {self.parent!s}" + class StringLengthTest(StringWildcard): def __init__(self, to_match: bytes, test_smaller: bool, trim: bool = False, compact_whitespace: bool = False): @@ -985,6 +1032,9 @@ def matches(self, data: bytes) -> DataTypeMatch: else: return DataTypeMatch.INVALID + def __str__(self): + return repr(self.to_match) + class StringMatch(StringTest): def __init__(self, @@ -1035,6 +1085,9 @@ def matches(self, data: bytes) -> DataTypeMatch: return DataTypeMatch.INVALID return self.post_process(bytes(matched)) + def __str__(self): + return repr(self.string) + class StringType(DataType[StringTest]): def __init__( @@ -1647,13 +1700,18 @@ def __init__( self.data_type: DataType[T] = data_type self.constant: T = constant - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: match = self.data_type.match(data[absolute_offset:], self.constant) if match: - return TestResult(self, offset=absolute_offset + match.initial_offset, length=len(match.raw_match), - value=match.value, parent=parent_match) + return MatchedTest(self, offset=absolute_offset + match.initial_offset, length=len(match.raw_match), + value=match.value, parent=parent_match) else: - return None + return FailedTest( + self, + offset=absolute_offset, + parent=parent_match, + message=f"expected {self.constant!s}" + ) class OffsetMatchTest(MagicTest): @@ -1669,14 +1727,19 @@ def __init__( super().__init__(offset=offset, mime=mime, extensions=extensions, message=message, parent=parent) self.value: IntegerValue = value - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: if self.value.test(absolute_offset, unsigned=True, num_bytes=8): - return TestResult(self, offset=0, length=absolute_offset, value=absolute_offset, parent=parent_match) + return MatchedTest(self, offset=0, length=absolute_offset, value=absolute_offset, parent=parent_match) else: - return None + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message=f"expected {self.value!r}" + ) -class IndirectResult(TestResult): +class IndirectResult(MatchedTest): def __init__(self, test: "IndirectTest", offset: int, parent: Optional[TestResult] = None): super().__init__(test, value=None, offset=offset, length=0, parent=parent) @@ -1703,10 +1766,16 @@ def __init__( p.can_match_mime = True p = p.parent - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[IndirectResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: if self.relative: if parent_match is None: - return None + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message="the test is relative but it does not have a parent test (this is likely a bug in the magic" + " definition file)" + ) absolute_offset += parent_match.offset return IndirectResult(self, absolute_offset, parent_match) @@ -1735,10 +1804,10 @@ def to_absolute(self, data: bytes, last_match: Optional[TestResult]) -> int: self.named_test = self self.used_by: Set[UseTest] = set() - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> MatchedTest: if parent_match is not None: - return TestResult(self, offset=parent_match.offset + parent_match.length, length=0, value=self.name, - parent=parent_match) + return MatchedTest(self, offset=parent_match.offset + parent_match.length, length=0, value=self.name, + parent=parent_match) else: raise ValueError("A named test must always be called from a `use` test.") @@ -1778,7 +1847,7 @@ def _match(self, context: MatchContext, parent_match: Optional[TestResult] = Non log.trace( f"{self.source_info!s}\tTrue\t{absolute_offset}\t{context.data[absolute_offset:absolute_offset + 20]!r}" ) - use_match = TestResult(self, None, absolute_offset, 0, parent=parent_match) + use_match = MatchedTest(self, None, absolute_offset, 0, parent=parent_match) yielded = False for named_result in self.referenced_test._match(context, use_match): if not yielded: @@ -1795,7 +1864,7 @@ def _match(self, context: MatchContext, parent_match: Optional[TestResult] = Non if not context.only_match_mime or child.can_match_mime: yield from child._match(context=context, parent_match=use_match) - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: raise NotImplementedError("This function should never be called") @@ -1803,18 +1872,23 @@ class JSONTest(MagicTest): def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: try: parsed = json.loads(data[absolute_offset:]) - return TestResult(self, offset=absolute_offset, length=len(data) - absolute_offset, value=parsed, - parent=parent_match) - except (json.JSONDecodeError, UnicodeDecodeError): - return None + return MatchedTest(self, offset=absolute_offset, length=len(data) - absolute_offset, value=parsed, + parent=parent_match) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message=str(e) + ) class CSVTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: try: text = data[absolute_offset:].decode("utf-8") - except UnicodeDecodeError: - return None + except UnicodeDecodeError as e: + return FailedTest(test=self, offset=absolute_offset, parent=parent_match, message=str(e)) for dialect in csv.list_dialects(): string_data = StringIO(text, newline="") reader = csv.reader(string_data, dialect=dialect) @@ -1835,30 +1909,36 @@ def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestRes continue if valid: # every row was valid, and we had at least one row - return TestResult(self, offset=absolute_offset, length=len(data) - absolute_offset, value=dialect, - parent=parent_match) - return None + return MatchedTest(self, offset=absolute_offset, length=len(data) - absolute_offset, value=dialect, + parent=parent_match) + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message=f"the input did not match a known CSV dialect ({', '.join(csv.list_dialects())})" + ) class DefaultTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: if parent_match is None or not parent_match.child_matched: - return TestResult(self, offset=absolute_offset, length=0, value=True, parent=parent_match) + return MatchedTest(self, offset=absolute_offset, length=0, value=True, parent=parent_match) else: - return None + return FailedTest(self, offset=absolute_offset, parent=parent_match, message="the parent test already " + "has a child that matched") class ClearTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> MatchedTest: if parent_match is None: - return TestResult(self, offset=absolute_offset, length=0, value=None) + return MatchedTest(self, offset=absolute_offset, length=0, value=None) else: parent_match.child_matched = False - return TestResult(self, offset=absolute_offset, length=0, parent=parent_match, value=None) + return MatchedTest(self, offset=absolute_offset, length=0, parent=parent_match, value=None) class DERTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: raise NotImplementedError( "TODO: Implement support for the DER test (e.g., using the Kaitai asn1_der.py parser)" ) @@ -1921,7 +2001,8 @@ def __bool__(self): def __len__(self): if self._result_iter is not None: # we have not yet finished collecting the results - for _ in self: pass + for _ in self: + pass assert self._result_iter is None return len(self._results) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 07f83a8d..4b27453c 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -5,7 +5,7 @@ from .polyfile import __copyright__, __license__, __version__ from .logger import getStatusLogger -from .magic import MagicTest, SourceInfo, TestResult, TEST_TYPES +from .magic import FailedTest, MagicTest, TestResult, TEST_TYPES from .wildcards import Wildcard @@ -262,7 +262,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def should_break(self) -> bool: return self.step_mode == StepMode.SINGLE_STEPPING or ( - self.step_mode == StepMode.NEXT and self.last_result is not None + self.step_mode == StepMode.NEXT and self.last_result ) or any( b.should_break(self.last_test, self.data, self.last_offset, self.last_parent_match) for b in self.breakpoints @@ -387,10 +387,14 @@ def print_where(self): self.write_test(t) self.write("\n") self.print_context(self.data, self.last_offset) - if self.last_result is None: - self.write("Test failed.\n", color=ANSIColor.RED) - else: - self.write("Test succeeded.\n", color=ANSIColor.GREEN) + if self.last_result is not None: + if not self.last_result: + self.write("Test failed.\n", color=ANSIColor.RED) + if isinstance(self.last_result, FailedTest): + self.write(self.last_result.message) + self.write("\n") + else: + self.write("Test succeeded.\n", color=ANSIColor.GREEN) def repl(self): log.clear_status() From 0850a3eeba75dfd9f951b0efeaac148a78b616c6 Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Tue, 11 Jan 2022 17:41:16 -0500 Subject: [PATCH 17/18] Adds the ability to use PDB to debug custom matchers --- polyfile/magic_debugger.py | 129 +++++++++++++++++++++++++++++++++++-- 1 file changed, 125 insertions(+), 4 deletions(-) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 4b27453c..73c17af5 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod from enum import Enum +from pdb import Pdb import sys -from typing import Any, Callable, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Iterator, List, Optional, Type, TypeVar, Union -from .polyfile import __copyright__, __license__, __version__ +from .polyfile import __copyright__, __license__, __version__, CUSTOM_MATCHERS, Match, Submatch from .logger import getStatusLogger from .magic import FailedTest, MagicTest, TestResult, TEST_TYPES from .wildcards import Wildcard @@ -185,12 +186,36 @@ def enabled(self) -> bool: return self.original_test is not None def uninstrument(self): - if self.original_test is not None and self.test.test is self: + if self.original_test is not None: # we are still assigned to the test function, so reset it self.test.test = self.original_test self.original_test = None +class InstrumentedMatch: + def __init__(self, match: Type[Match], debugger: "Debugger"): + self.match: Type[Match] = match + self.debugger: Debugger = debugger + if hasattr(match, "submatch"): + self.original_submatch: Optional[Callable[...], Iterator[Submatch]] = match.submatch + + def wrapper(match_instance, *args, **kwargs) -> Iterator[Submatch]: + yield from self.debugger.debug_submatch(self, match_instance, *args, **kwargs) + + match.submatch = wrapper + else: + self.original_submatch = None + + @property + def enalbed(self) -> bool: + return self.original_submatch is not None + + def uninstrument(self): + if self.original_submatch is not None: + self.match.submatch = self.original_submatch + self.original_submatch = None + + def string_escape(data: Union[bytes, int]) -> str: if not isinstance(data, int): return "".join(string_escape(d) for d in data) @@ -217,7 +242,7 @@ class StepMode(Enum): class Debugger: - def __init__(self): + def __init__(self, break_on_submatching: bool = True): self.instrumented_tests: List[InstrumentedTest] = [] self.breakpoints: List[Breakpoint] = [] self._entries: int = 0 @@ -228,6 +253,9 @@ def __init__(self): self.data: bytes = b"" self.last_offset: int = 0 self.last_result: Optional[TestResult] = None + self.instrumented_matches: List[InstrumentedMatch] =[] + self.break_on_submatching: bool = break_on_submatching + self._pdb: Optional[Pdb] = None @property def enabled(self) -> bool: @@ -239,12 +267,19 @@ def enabled(self, is_enabled: bool): for t in self.instrumented_tests: t.uninstrument() self.instrumented_tests = [] + for m in self.instrumented_matches: + m.uninstrument() + self.instrumented_matches = [] if is_enabled: # Instrument all of the MagicTest.test functions: for test in TEST_TYPES: if "test" in test.__dict__: # this class actually implements the test() function self.instrumented_tests.append(InstrumentedTest(test, self)) + if self.break_on_submatching: + for match in CUSTOM_MATCHERS.values(): + if hasattr(match, "submatch"): + self.instrumented_matches.append(InstrumentedMatch(match, self)) self.write(f"PolyFile {__version__}\n", color=ANSIColor.MAGENTA, bold=True) self.write(f"{__copyright__}\n{__license__}\n\nFor help, type \"help\".\n") self.repl() @@ -319,6 +354,25 @@ def write(self, message: Any, bold: bool = False, dim: bool = False, color: Opti return sys.stdout.write(str(message)) + def prompt(self, message: str, default: bool = True) -> bool: + while True: + self.write(f"{message} ", bold=True) + self.write("[", dim=True) + if default: + self.write("Y", bold=True, color=ANSIColor.GREEN) + self.write("n", dim=True, color=ANSIColor.RED) + else: + self.write("y", dim=True, color=ANSIColor.GREEN) + self.write("N", bold=True, color=ANSIColor.RED) + self.write("] ", dim=True) + answer = input().strip().lower() + if not answer: + return default + elif answer == "n": + return False + elif answer == "y": + return True + def print_context(self, data: bytes, offset: int, context_bytes: int = 32): bytes_before = min(offset, context_bytes) context_before = string_escape(data[:bytes_before]) @@ -355,6 +409,73 @@ def debug( self.repl() return self.last_result + def print_match(self, match: Match): + obj = match.to_obj() + self.write("{\n", bold=True) + for key, value in obj.items(): + if isinstance(value, list): + # TODO: Maybe implement list printing later. + # I don't think there will be lists here currently, thouh. + continue + self.write(f" {key!r}", color=ANSIColor.BLUE) + self.write(": ", bold=True) + if isinstance(value, int) or isinstance(value, float): + self.write(str(value)) + else: + self.write(repr(value), color=ANSIColor.GREEN) + self.write(",\n", bold=True) + self.write("}\n", bold=True) + + def debug_submatch(self, instrumented_match: InstrumentedMatch, match: Match, file_stream) -> Iterator[Submatch]: + log.clear_status() + + if instrumented_match.original_submatch is None: + submatch = instrumented_match.match.submatch + else: + submatch = instrumented_match.original_submatch + + def print_location(): + self.write(f"{file_stream.name}", dim=True, color=ANSIColor.CYAN) + self.write(":", dim=True) + self.write(f"{file_stream.tell()} ", dim=True, color=ANSIColor.CYAN) + + if self._pdb is not None: + # We are already debugging! + print_location() + self.write(f"Parsing for submatches using {instrumented_match.match.__name__}.\n") + yield from submatch(match, file_stream) + return + self.print_match(match) + print_location() + self.write(f"About to parse for submatches using {instrumented_match.match.__name__}.\n") + if not self.prompt("Debug using PDB?", default=False): + yield from submatch(match, file_stream) + return + try: + self._pdb = Pdb(skip=["polyfile.magic_debugger", "polyfile.magic"]) + self._pdb.prompt = "\u001b[1m(polyfile-Pdb)\u001b[0m " + generator = submatch(match, file_stream) + while True: + try: + result = self._pdb.runcall(next, generator) + self.write(f"Got a submatch:\n", dim=True) + self.print_match(result) + yield result + except StopIteration: + self.write(f"Yielded all submatches from {match.__class__.__name__} at offset {match.offset}.\n") + break + print_location() + if not self.prompt("Continue debugging the next submatch?", default=True): + if self.prompt("Print the remaining submatches?", default=False): + for result in generator: + self.print_match(result) + yield result + else: + yield from generator + break + finally: + self._pdb = None + def print_where(self): if self.last_test is None: self.write("The first test has not yet been run.\n", color=ANSIColor.RED) From 6408ef5a98c402e390adb736adc2b824883a5d0a Mon Sep 17 00:00:00 2001 From: Evan Sultanik Date: Thu, 13 Jan 2022 10:33:05 -0500 Subject: [PATCH 18/18] Adds a command line option to not debug with PDB --- polyfile/__main__.py | 7 ++++++- polyfile/magic_debugger.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/polyfile/__main__.py b/polyfile/__main__.py index 171be146..55dc9338 100644 --- a/polyfile/__main__.py +++ b/polyfile/__main__.py @@ -58,6 +58,9 @@ def main(argv=None): parser.add_argument('--trace', '-dd', action='store_true', help='print extra verbose debug information') parser.add_argument('--debugger', '-db', action='store_true', help='drop into an interactive debugger for libmagic ' 'file definition matching') + parser.add_argument('--no-debug-python', action='store_true', help='by default, the `--debugger` option will break ' + 'on custom matchers and prompt to debug using ' + 'PDB. This option will suppress those prompts.') parser.add_argument('--quiet', '-q', action='store_true', help='suppress all log output (overrides --debug)') parser.add_argument('--version', '-v', action='store_true', help='print PolyFile\'s version information to STDERR') parser.add_argument('-dumpversion', action='store_true', @@ -116,7 +119,9 @@ def main(argv=None): with path_or_stdin as file_path, ExitStack() as stack: if args.debugger: - stack.enter_context(Debugger()) + stack.enter_context(Debugger(break_on_submatching=not args.no_debug_python)) + elif args.no_debug_python: + log.warning("Ignoring `--no-debug-python`; it can only be used with the --debugger option.") matches = [] try: if args.only_match_mime: diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py index 73c17af5..241095a1 100644 --- a/polyfile/magic_debugger.py +++ b/polyfile/magic_debugger.py @@ -2,7 +2,7 @@ from enum import Enum from pdb import Pdb import sys -from typing import Any, Callable, Iterator, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, ContextManager, Iterator, List, Optional, Type, TypeVar, Union from .polyfile import __copyright__, __license__, __version__, CUSTOM_MATCHERS, Match, Submatch from .logger import getStatusLogger @@ -241,7 +241,7 @@ class StepMode(Enum): NEXT = 2 -class Debugger: +class Debugger(ContextManager["Debugger"]): def __init__(self, break_on_submatching: bool = True): self.instrumented_tests: List[InstrumentedTest] = [] self.breakpoints: List[Breakpoint] = []