diff --git a/polyfile/__main__.py b/polyfile/__main__.py index 41e80455..55dc9338 100644 --- a/polyfile/__main__.py +++ b/polyfile/__main__.py @@ -1,4 +1,5 @@ import argparse +from contextlib import ExitStack import base64 import hashlib import json @@ -14,6 +15,7 @@ from . import polyfile from .fileutils import PathOrStdin from .magic import MagicMatcher, MatchContext +from .magic_debugger import Debugger from .polyfile import __version__ @@ -54,6 +56,11 @@ def main(argv=None): help='stop scanning after having found this many matches') parser.add_argument('--debug', '-d', action='store_true', help='print debug information') parser.add_argument('--trace', '-dd', action='store_true', help='print extra verbose debug information') + parser.add_argument('--debugger', '-db', action='store_true', help='drop into an interactive debugger for libmagic ' + 'file definition matching') + parser.add_argument('--no-debug-python', action='store_true', help='by default, the `--debugger` option will break ' + 'on custom matchers and prompt to debug using ' + 'PDB. 
This option will suppress those prompts.') parser.add_argument('--quiet', '-q', action='store_true', help='suppress all log output (overrides --debug)') parser.add_argument('--version', '-v', action='store_true', help='print PolyFile\'s version information to STDERR') parser.add_argument('-dumpversion', action='store_true', @@ -110,7 +117,11 @@ def main(argv=None): exit(1) return # this is here because linters are dumb and will complain about the next line without it - with path_or_stdin as file_path: + with path_or_stdin as file_path, ExitStack() as stack: + if args.debugger: + stack.enter_context(Debugger(break_on_submatching=not args.no_debug_python)) + elif args.no_debug_python: + log.warning("Ignoring `--no-debug-python`; it can only be used with the --debugger option.") matches = [] try: if args.only_match_mime: diff --git a/polyfile/iterators.py b/polyfile/iterators.py index 0bee3f68..9dac45e0 100644 --- a/polyfile/iterators.py +++ b/polyfile/iterators.py @@ -66,4 +66,7 @@ def __init__(self, source: Iterable[T]): super().__init__(unique(iter(source), elements=self._set)) def __contains__(self, x: object) -> bool: + if x in self._set: + return True + self._complete() return x in self._set diff --git a/polyfile/magic.py b/polyfile/magic.py index f00ed8d8..67dbed80 100644 --- a/polyfile/magic.py +++ b/polyfile/magic.py @@ -21,7 +21,7 @@ import sys from time import gmtime, localtime, strftime from typing import ( - Any, BinaryIO, Callable, Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union + Any, BinaryIO, Callable, Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union ) from uuid import UUID @@ -120,20 +120,12 @@ def unescape(to_unescape: Union[str, bytes]) -> bytes: return bytes(b) -class TestResult: - def __init__( - self, test: "MagicTest", - value: Any, - offset: int, - length: int, - parent: Optional["TestResult"] = None - ): +class TestResult(ABC): + def __init__(self, test: "MagicTest", offset: 
int, parent: Optional["TestResult"] = None): self.test: MagicTest = test - self.value: Any = value self.offset: int = offset - self.length: int = length - self.parent: Optional[TestResult] = parent - if parent is not None: + self.parent: Optional["TestResult"] = parent + if parent is not None and bool(self): assert self.test.named_test is self.test or parent.test.level == self.test.level - 1 if not isinstance(self.test, UseTest): parent.child_matched = True @@ -152,13 +144,53 @@ def child_matched(self, did_match: bool): self.parent.parent.child_matched = True self._child_matched = did_match + def __hash__(self): + return hash((self.test, self.offset)) + + def __eq__(self, other): + return isinstance(other, TestResult) and other.test == self.test and other.offset == self.offset + + @abstractmethod + def __bool__(self): + raise NotImplementedError() + + def __repr__(self): + return f"{self.__class__.__name__}(test={self.test!r}, offset={self.offset}, parent={self.parent!r})" + + def __str__(self): + if self.test.message is not None: + # TODO: Fix pasting our value in + return str(self.test.message) + #if self.value is not None and "%" in self.test.message: + # return self.test.message % (self.value,) + #else: + # return self.test.message + else: + return f"Match[{self.offset}]" + + +class MatchedTest(TestResult): + def __init__( + self, test: "MagicTest", + value: Any, + offset: int, + length: int, + parent: Optional["TestResult"] = None + ): + super().__init__(test=test, offset=offset, parent=parent) + self.value: Any = value + self.length: int = length + def __hash__(self): return hash((self.test, self.offset, self.length)) def __eq__(self, other): - return isinstance(other, TestResult) and other.test == self.test and other.offset == self.offset \ + return isinstance(other, MatchedTest) and other.test == self.test and other.offset == self.offset \ and other.length == self.length + def __bool__(self): + return True + def __repr__(self): return 
f"{self.__class__.__name__}(test={self.test!r}, offset={self.offset}, length={self.length}, " \ f"parent={self.parent!r})" @@ -166,7 +198,7 @@ def __repr__(self): def __str__(self): if self.test.message is not None: # TODO: Fix pasting our value in - return self.test.message + return str(self.test.message) #if self.value is not None and "%" in self.test.message: # return self.test.message % (self.value,) #else: @@ -175,6 +207,15 @@ def __str__(self): return f"Match[{self.offset}:{self.offset + self.length}]" +class FailedTest(TestResult): + def __init__(self, test: "MagicTest", offset: int, message: str, parent: Optional["TestResult"] = None): + super().__init__(test=test, offset=offset, parent=parent) + self.message: str = message + + def __bool__(self): + return False + + class Endianness(Enum): NATIVE = "=" LITTLE = "<" @@ -543,6 +584,18 @@ def __str__(self): return f"${{x?{self.true_value}:{self.false_value}}}" +TEST_TYPES: Set[Type["MagicTest"]] = set() + + +class Comment: + def __init__(self, message: str, source_info: Optional[SourceInfo] = None): + self.message: str = message + self.source_info: Optional[SourceInfo] = source_info + + def __str__(self): + return self.message + + class MagicTest(ABC): def __init__( self, @@ -550,7 +603,8 @@ def __init__( mime: Optional[Union[str, TernaryExecutableMessage]] = None, extensions: Iterable[str] = (), message: Union[str, Message] = "", - parent: Optional["MagicTest"] = None + parent: Optional["MagicTest"] = None, + comments: Iterable[Comment] = () ): self.offset: Offset = offset self._mime: Optional[Message] = None @@ -588,6 +642,11 @@ def __init__( """ self.mime = mime self.source_info: Optional[SourceInfo] = None + self.comments: Tuple[Comment, ...] 
= tuple(comments) + + def __init_subclass__(cls, **kwargs): + TEST_TYPES.add(cls) + return super().__init_subclass__(**kwargs) @property def parent(self) -> Optional["MagicTest"]: @@ -685,23 +744,23 @@ def all_extensions(self) -> LazyIterableSet[str]: return LazyIterableSet(self._all_extensions()) @abstractmethod - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: raise NotImplementedError() - def _match(self, context: MatchContext, parent_match: Optional[TestResult] = None) -> Iterator[TestResult]: + def _match(self, context: MatchContext, parent_match: Optional[TestResult] = None) -> Iterator[MatchedTest]: if context.only_match_mime and not self.can_match_mime: return try: absolute_offset = self.offset.to_absolute(context.data, parent_match) except InvalidOffsetError: - return None + return m = self.test(context.data, absolute_offset, parent_match) - if logging.root.level <= TRACE and (m is not None or self.level > 0): + if logging.root.level <= TRACE and (bool(m) or self.level > 0): log.trace( - f"{self.source_info!s}\t{m is not None}\t{absolute_offset}\t" + f"{self.source_info!s}\t{bool(m)}\t{absolute_offset}\t" f"{context.data[absolute_offset:absolute_offset + 20]!r}" ) - if m is not None: + if bool(m): if not context.only_match_mime or self.mime is not None: yield m for child in self.children: @@ -718,7 +777,7 @@ def match(self, to_match: Union[bytes, BinaryIO, str, Path, MatchContext]) -> It def __str__(self): if self.source_info is not None and self.source_info.original_line is not None: - s = f"{self.source_info.path.name}:{self.source_info.line} {self.source_info.original_line}" + s = f"{self.source_info.path.name}:{self.source_info.line} {self.source_info.original_line.strip()}" else: s = f"{'>' * self.level}{self.offset!s}\t{self.message}" if self.mime is not None: @@ -938,6 +997,9 @@ def 
matches(self, data: bytes) -> DataTypeMatch: else: return self.post_process(data) + def __str__(self): + return "null-terminated string" + class NegatedStringTest(StringWildcard): def __init__(self, parent_test: StringTest): @@ -951,6 +1013,9 @@ def matches(self, data: bytes) -> DataTypeMatch: else: return DataTypeMatch.INVALID + def __str__(self): + return f"something other than {self.parent!s}" + class StringLengthTest(StringWildcard): def __init__(self, to_match: bytes, test_smaller: bool, trim: bool = False, compact_whitespace: bool = False): @@ -967,6 +1032,9 @@ def matches(self, data: bytes) -> DataTypeMatch: else: return DataTypeMatch.INVALID + def __str__(self): + return repr(self.to_match) + class StringMatch(StringTest): def __init__(self, @@ -1017,6 +1085,9 @@ def matches(self, data: bytes) -> DataTypeMatch: return DataTypeMatch.INVALID return self.post_process(bytes(matched)) + def __str__(self): + return repr(self.string) + class StringType(DataType[StringTest]): def __init__( @@ -1629,13 +1700,18 @@ def __init__( self.data_type: DataType[T] = data_type self.constant: T = constant - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: match = self.data_type.match(data[absolute_offset:], self.constant) if match: - return TestResult(self, offset=absolute_offset + match.initial_offset, length=len(match.raw_match), - value=match.value, parent=parent_match) + return MatchedTest(self, offset=absolute_offset + match.initial_offset, length=len(match.raw_match), + value=match.value, parent=parent_match) else: - return None + return FailedTest( + self, + offset=absolute_offset, + parent=parent_match, + message=f"expected {self.constant!s}" + ) class OffsetMatchTest(MagicTest): @@ -1651,14 +1727,19 @@ def __init__( super().__init__(offset=offset, mime=mime, extensions=extensions, message=message, 
parent=parent) self.value: IntegerValue = value - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: if self.value.test(absolute_offset, unsigned=True, num_bytes=8): - return TestResult(self, offset=0, length=absolute_offset, value=absolute_offset, parent=parent_match) + return MatchedTest(self, offset=0, length=absolute_offset, value=absolute_offset, parent=parent_match) else: - return None + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message=f"expected {self.value!r}" + ) -class IndirectResult(TestResult): +class IndirectResult(MatchedTest): def __init__(self, test: "IndirectTest", offset: int, parent: Optional[TestResult] = None): super().__init__(test, value=None, offset=offset, length=0, parent=parent) @@ -1685,10 +1766,16 @@ def __init__( p.can_match_mime = True p = p.parent - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[IndirectResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: if self.relative: if parent_match is None: - return None + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message="the test is relative but it does not have a parent test (this is likely a bug in the magic" + " definition file)" + ) absolute_offset += parent_match.offset return IndirectResult(self, absolute_offset, parent_match) @@ -1717,10 +1804,10 @@ def to_absolute(self, data: bytes, last_match: Optional[TestResult]) -> int: self.named_test = self self.used_by: Set[UseTest] = set() - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> MatchedTest: if parent_match is not None: - return 
TestResult(self, offset=parent_match.offset + parent_match.length, length=0, value=self.name, - parent=parent_match) + return MatchedTest(self, offset=parent_match.offset + parent_match.length, length=0, value=self.name, + parent=parent_match) else: raise ValueError("A named test must always be called from a `use` test.") @@ -1760,7 +1847,7 @@ def _match(self, context: MatchContext, parent_match: Optional[TestResult] = Non log.trace( f"{self.source_info!s}\tTrue\t{absolute_offset}\t{context.data[absolute_offset:absolute_offset + 20]!r}" ) - use_match = TestResult(self, None, absolute_offset, 0, parent=parent_match) + use_match = MatchedTest(self, None, absolute_offset, 0, parent=parent_match) yielded = False for named_result in self.referenced_test._match(context, use_match): if not yielded: @@ -1777,7 +1864,7 @@ def _match(self, context: MatchContext, parent_match: Optional[TestResult] = Non if not context.only_match_mime or child.can_match_mime: yield from child._match(context=context, parent_match=use_match) - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: raise NotImplementedError("This function should never be called") @@ -1785,18 +1872,23 @@ class JSONTest(MagicTest): def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: try: parsed = json.loads(data[absolute_offset:]) - return TestResult(self, offset=absolute_offset, length=len(data) - absolute_offset, value=parsed, - parent=parent_match) - except (json.JSONDecodeError, UnicodeDecodeError): - return None + return MatchedTest(self, offset=absolute_offset, length=len(data) - absolute_offset, value=parsed, + parent=parent_match) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message=str(e) + ) 
class CSVTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: try: text = data[absolute_offset:].decode("utf-8") - except UnicodeDecodeError: - return None + except UnicodeDecodeError as e: + return FailedTest(test=self, offset=absolute_offset, parent=parent_match, message=str(e)) for dialect in csv.list_dialects(): string_data = StringIO(text, newline="") reader = csv.reader(string_data, dialect=dialect) @@ -1817,30 +1909,36 @@ def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestRes continue if valid: # every row was valid, and we had at least one row - return TestResult(self, offset=absolute_offset, length=len(data) - absolute_offset, value=dialect, - parent=parent_match) - return None + return MatchedTest(self, offset=absolute_offset, length=len(data) - absolute_offset, value=dialect, + parent=parent_match) + return FailedTest( + test=self, + offset=absolute_offset, + parent=parent_match, + message=f"the input did not match a known CSV dialect ({', '.join(csv.list_dialects())})" + ) class DefaultTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: if parent_match is None or not parent_match.child_matched: - return TestResult(self, offset=absolute_offset, length=0, value=True, parent=parent_match) + return MatchedTest(self, offset=absolute_offset, length=0, value=True, parent=parent_match) else: - return None + return FailedTest(self, offset=absolute_offset, parent=parent_match, message="the parent test already " + "has a child that matched") class ClearTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + 
def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> MatchedTest: if parent_match is None: - return TestResult(self, offset=absolute_offset, length=0, value=None) + return MatchedTest(self, offset=absolute_offset, length=0, value=None) else: parent_match.child_matched = False - return TestResult(self, offset=absolute_offset, length=0, parent=parent_match, value=None) + return MatchedTest(self, offset=absolute_offset, length=0, parent=parent_match, value=None) class DERTest(MagicTest): - def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> Optional[TestResult]: + def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult: raise NotImplementedError( "TODO: Implement support for the DER test (e.g., using the Kaitai asn1_der.py parser)" ) @@ -1903,7 +2001,8 @@ def __bool__(self): def __len__(self): if self._result_iter is not None: # we have not yet finished collecting the results - for _ in self: pass + for _ in self: + pass assert self._result_iter is None return len(self._results) @@ -2059,12 +2158,24 @@ def _parse_file( level_zero_tests: List[MagicTest] = [] tests_with_mime: Set[MagicTest] = set() indirect_tests: Set[IndirectTest] = set() + comments: List[Comment] = [] with open(def_file, "rb") as f: for line_number, raw_line in enumerate(f.readlines()): line_number += 1 raw_line = raw_line.lstrip() - if not raw_line or raw_line.startswith(b"#"): - # skip empty lines and comments + if not raw_line: + # skip empty lines + comments = [] + continue + elif raw_line.startswith(b"#"): + # this is a comment + try: + comments.append(Comment( + message=raw_line[1:].strip().decode("utf-8"), + source_info=SourceInfo(def_file, line_number, raw_line.decode("utf-8")) + )) + except UnicodeDecodeError: + pass continue elif raw_line.startswith(b"!:apple") or raw_line.startswith(b"!:strength"): # ignore these directives for now @@ -2173,6 +2284,8 @@ def __init__(self): if 
test.level == 0: level_zero_tests.append(test) test.source_info = SourceInfo(def_file, line_number, line) + test.comments = tuple(comments) + comments = [] current_test = test continue m = MIME_PATTERN.match(line) diff --git a/polyfile/magic_debugger.py b/polyfile/magic_debugger.py new file mode 100644 index 00000000..241095a1 --- /dev/null +++ b/polyfile/magic_debugger.py @@ -0,0 +1,629 @@ +from abc import ABC, abstractmethod +from enum import Enum +from pdb import Pdb +import sys +from typing import Any, Callable, ContextManager, Iterator, List, Optional, Type, TypeVar, Union + +from .polyfile import __copyright__, __license__, __version__, CUSTOM_MATCHERS, Match, Submatch +from .logger import getStatusLogger +from .magic import FailedTest, MagicTest, TestResult, TEST_TYPES +from .wildcards import Wildcard + + +log = getStatusLogger("polyfile") + + +class ANSIColor(Enum): + BLACK = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + MAGENTA = 35 + CYAN = 36 + WHITE = 37 + + def to_code(self) -> str: + return f"\u001b[{self.value}m" + + +B = TypeVar("B", bound="Breakpoint") + + +BREAKPOINT_TYPES: List[Type["Breakpoint"]] = [] + + +class Breakpoint(ABC): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + BREAKPOINT_TYPES.append(cls) + + @abstractmethod + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + raise NotImplementedError() + + @classmethod + @abstractmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + raise NotImplementedError() + + @classmethod + @abstractmethod + def usage(cls) -> str: + raise NotImplementedError() + + @abstractmethod + def __str__(self): + raise NotImplementedError() + + +class MimeBreakpoint(Breakpoint): + def __init__(self, mimetype: str): + self.mimetype: str = mimetype + self.pattern: Wildcard = Wildcard.parse(mimetype) + + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: 
int, + parent_match: Optional[TestResult] + ) -> bool: + return self.pattern.is_contained_in(test.mimetypes()) + + @classmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + if command.lower().startswith("mime:"): + return MimeBreakpoint(command[len("mime:"):]) + return None + + @classmethod + def usage(cls) -> str: + return "`b MIME:MIMETYPE` to break when a test is capable of matching that mimetype. "\ + "The MIMETYPE can include the '*' and '?' wildcards. For example, " \ + "\"b MIME:application/pdf\" and \"b MIME:*pdf\"." + + def __str__(self): + return f"Breakpoint: Matching for MIME {self.mimetype}" + + +class ExtensionBreakpoint(Breakpoint): + def __init__(self, ext: str): + self.ext: str = ext + + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + return self.ext in test.all_extensions() + + @classmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + if command.lower().startswith("ext:"): + return ExtensionBreakpoint(command[len("ext:"):]) + return None + + @classmethod + def usage(cls) -> str: + return "`b EXT:EXTENSION` to break when a test is capable of matching that extension. For example, " \ + "\"b EXT:pdf\"." 
+ + def __str__(self): + return f"Breakpoint: Matching for extension {self.ext}" + + +class FileBreakpoint(Breakpoint): + def __init__(self, filename: str, line: int): + self.filename: str = filename + self.line: int = line + + def should_break( + self, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> bool: + if test.source_info is None or test.source_info.line != self.line: + return False + if "/" in self.filename: + # it is a file path + return str(test.source_info.path) == self.filename + else: + # treat it like a filename + return test.source_info.path.name == self.filename + + @classmethod + def parse(cls: Type[B], command: str) -> Optional[B]: + filename, *remainder = command.split(":") + if not remainder: + return None + try: + line = int("".join(remainder)) + except ValueError: + return None + if line <= 0: + return None + return FileBreakpoint(filename, line) + + @classmethod + def usage(cls) -> str: + return "`b FILENAME:LINE_NO` to break when the line of the given magic file is reached. For example, " \ + "\"b archive:525\"." 
+ + def __str__(self): + return f"Breakpoint: {self.filename} line {self.line}" + + +class InstrumentedTest: + def __init__(self, test: Type[MagicTest], debugger: "Debugger"): + self.test: Type[MagicTest] = test + self.debugger: Debugger = debugger + if "test" in test.__dict__: + self.original_test: Optional[Callable[[...], Optional[TestResult]]] = test.test + + def wrapper(test_instance, *args, **kwargs) -> Optional[TestResult]: + # if self.original_test is None: + # # this is a NOOP + # return self.test.test(test_instance, *args, **kwargs) + return self.debugger.debug(self, test_instance, *args, **kwargs) + + test.test = wrapper + else: + self.original_test = None + + @property + def enabled(self) -> bool: + return self.original_test is not None + + def uninstrument(self): + if self.original_test is not None: + # we are still assigned to the test function, so reset it + self.test.test = self.original_test + self.original_test = None + + +class InstrumentedMatch: + def __init__(self, match: Type[Match], debugger: "Debugger"): + self.match: Type[Match] = match + self.debugger: Debugger = debugger + if hasattr(match, "submatch"): + self.original_submatch: Optional[Callable[...], Iterator[Submatch]] = match.submatch + + def wrapper(match_instance, *args, **kwargs) -> Iterator[Submatch]: + yield from self.debugger.debug_submatch(self, match_instance, *args, **kwargs) + + match.submatch = wrapper + else: + self.original_submatch = None + + @property + def enalbed(self) -> bool: + return self.original_submatch is not None + + def uninstrument(self): + if self.original_submatch is not None: + self.match.submatch = self.original_submatch + self.original_submatch = None + + +def string_escape(data: Union[bytes, int]) -> str: + if not isinstance(data, int): + return "".join(string_escape(d) for d in data) + elif data == ord('\n'): + return "\\n" + elif data == ord('\t'): + return "\\t" + elif data == ord('\r'): + return "\\r" + elif data == 0: + return "\\0" + elif data == 
ord('\\'): + return "\\\\" + elif 32 <= data <= 126: + return chr(data) + else: + return f"\\x{data:02X}" + + +class StepMode(Enum): + RUNNING = 0 + SINGLE_STEPPING = 1 + NEXT = 2 + + +class Debugger(ContextManager["Debugger"]): + def __init__(self, break_on_submatching: bool = True): + self.instrumented_tests: List[InstrumentedTest] = [] + self.breakpoints: List[Breakpoint] = [] + self._entries: int = 0 + self.step_mode: StepMode = StepMode.RUNNING + self.last_command: Optional[str] = None + self.last_test: Optional[MagicTest] = None + self.last_parent_match: Optional[MagicTest] = None + self.data: bytes = b"" + self.last_offset: int = 0 + self.last_result: Optional[TestResult] = None + self.instrumented_matches: List[InstrumentedMatch] =[] + self.break_on_submatching: bool = break_on_submatching + self._pdb: Optional[Pdb] = None + + @property + def enabled(self) -> bool: + return any(t.enabled for t in self.instrumented_tests) + + @enabled.setter + def enabled(self, is_enabled: bool): + # Uninstrument any existing instrumentation: + for t in self.instrumented_tests: + t.uninstrument() + self.instrumented_tests = [] + for m in self.instrumented_matches: + m.uninstrument() + self.instrumented_matches = [] + if is_enabled: + # Instrument all of the MagicTest.test functions: + for test in TEST_TYPES: + if "test" in test.__dict__: + # this class actually implements the test() function + self.instrumented_tests.append(InstrumentedTest(test, self)) + if self.break_on_submatching: + for match in CUSTOM_MATCHERS.values(): + if hasattr(match, "submatch"): + self.instrumented_matches.append(InstrumentedMatch(match, self)) + self.write(f"PolyFile {__version__}\n", color=ANSIColor.MAGENTA, bold=True) + self.write(f"{__copyright__}\n{__license__}\n\nFor help, type \"help\".\n") + self.repl() + + def __enter__(self) -> "Debugger": + self._entries += 1 + if self._entries == 1: + self.enabled = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._entries 
-= 1 + if self._entries == 0: + self.enabled = False + + def should_break(self) -> bool: + return self.step_mode == StepMode.SINGLE_STEPPING or ( + self.step_mode == StepMode.NEXT and self.last_result + ) or any( + b.should_break(self.last_test, self.data, self.last_offset, self.last_parent_match) + for b in self.breakpoints + ) + + def write_test(self, test: MagicTest, is_current_test: bool = False): + for comment in test.comments: + if comment.source_info is not None and comment.source_info.original_line is not None: + self.write(f" {comment.source_info.path.name}", dim=True, color=ANSIColor.CYAN) + self.write(":", dim=True) + self.write(f"{comment.source_info.line}\t", dim=True, color=ANSIColor.CYAN) + self.write(comment.source_info.original_line.strip(), dim=True) + self.write("\n") + else: + self.write(f" # {comment!s}\n", dim=True) + if is_current_test: + self.write("→ ", bold=True) + else: + self.write(" ") + if test.source_info is not None and test.source_info.original_line is not None: + source_prefix = f"{test.source_info.path.name}:{test.source_info.line}" + indent = f"{' ' * len(source_prefix)}\t" + self.write(test.source_info.path.name, dim=True, color=ANSIColor.CYAN) + self.write(":", dim=True) + self.write(test.source_info.line, dim=True, color=ANSIColor.CYAN) + self.write("\t") + self.write(test.source_info.original_line.strip(), color=ANSIColor.BLUE, bold=True) + else: + indent = "" + self.write(f"{'>' * test.level}{test.offset!s}\t") + self.write(test.message, color=ANSIColor.BLUE, bold=True) + if test.mime is not None: + self.write(f"\n {indent}!:mime ", dim=True) + self.write(test.mime, color=ANSIColor.BLUE) + for e in test.extensions: + self.write(f"\n {indent}!:ext ", dim=True) + self.write(str(e), color=ANSIColor.BLUE) + self.write("\n") + + def write(self, message: Any, bold: bool = False, dim: bool = False, color: Optional[ANSIColor] = None): + if sys.stdout.isatty(): + if isinstance(message, MagicTest): + self.write_test(message) + return 
+ prefixes: List[str] = [] + if bold and not dim: + prefixes.append("\u001b[1m") + elif dim and not bold: + prefixes.append("\u001b[2m") + if color is not None: + prefixes.append(color.to_code()) + if prefixes: + sys.stdout.write(f"{''.join(prefixes)}{message!s}\u001b[0m") + return + sys.stdout.write(str(message)) + + def prompt(self, message: str, default: bool = True) -> bool: + while True: + self.write(f"{message} ", bold=True) + self.write("[", dim=True) + if default: + self.write("Y", bold=True, color=ANSIColor.GREEN) + self.write("n", dim=True, color=ANSIColor.RED) + else: + self.write("y", dim=True, color=ANSIColor.GREEN) + self.write("N", bold=True, color=ANSIColor.RED) + self.write("] ", dim=True) + answer = input().strip().lower() + if not answer: + return default + elif answer == "n": + return False + elif answer == "y": + return True + + def print_context(self, data: bytes, offset: int, context_bytes: int = 32): + bytes_before = min(offset, context_bytes) + context_before = string_escape(data[:bytes_before]) + if 0 <= offset < len(data): + current_byte = string_escape(data[offset]) + else: + current_byte = "" + context_after = string_escape(data[offset + 1:offset + context_bytes]) + self.write(context_before) + self.write(current_byte, bold=True) + self.write(context_after) + self.write("\n") + self.write(f"{' ' * len(context_before)}") + self.write(f"{'^' * len(current_byte)}", bold=True) + self.write(f"{' ' * len(context_after)}\n") + + def debug( + self, + instrumented_test: InstrumentedTest, + test: MagicTest, + data: bytes, + absolute_offset: int, + parent_match: Optional[TestResult] + ) -> Optional[TestResult]: + self.last_test = test + self.data = data + self.last_offset = absolute_offset + self.last_parent_match = parent_match + if instrumented_test.original_test is None: + self.last_result = instrumented_test.test.test(test, data, absolute_offset, parent_match) + else: + self.last_result = instrumented_test.original_test(test, data, 
absolute_offset, parent_match) + if self.should_break(): + self.repl() + return self.last_result + + def print_match(self, match: Match): + obj = match.to_obj() + self.write("{\n", bold=True) + for key, value in obj.items(): + if isinstance(value, list): + # TODO: Maybe implement list printing later. + # I don't think there will be lists here currently, though. + continue + self.write(f" {key!r}", color=ANSIColor.BLUE) + self.write(": ", bold=True) + if isinstance(value, int) or isinstance(value, float): + self.write(str(value)) + else: + self.write(repr(value), color=ANSIColor.GREEN) + self.write(",\n", bold=True) + self.write("}\n", bold=True) + + def debug_submatch(self, instrumented_match: InstrumentedMatch, match: Match, file_stream) -> Iterator[Submatch]: + log.clear_status() + + if instrumented_match.original_submatch is None: + submatch = instrumented_match.match.submatch + else: + submatch = instrumented_match.original_submatch + + def print_location(): + self.write(f"{file_stream.name}", dim=True, color=ANSIColor.CYAN) + self.write(":", dim=True) + self.write(f"{file_stream.tell()} ", dim=True, color=ANSIColor.CYAN) + + if self._pdb is not None: + # We are already debugging! 
+ print_location() + self.write(f"Parsing for submatches using {instrumented_match.match.__name__}.\n") + yield from submatch(match, file_stream) + return + self.print_match(match) + print_location() + self.write(f"About to parse for submatches using {instrumented_match.match.__name__}.\n") + if not self.prompt("Debug using PDB?", default=False): + yield from submatch(match, file_stream) + return + try: + self._pdb = Pdb(skip=["polyfile.magic_debugger", "polyfile.magic"]) + self._pdb.prompt = "\u001b[1m(polyfile-Pdb)\u001b[0m " + generator = submatch(match, file_stream) + while True: + try: + result = self._pdb.runcall(next, generator) + self.write(f"Got a submatch:\n", dim=True) + self.print_match(result) + yield result + except StopIteration: + self.write(f"Yielded all submatches from {match.__class__.__name__} at offset {match.offset}.\n") + break + print_location() + if not self.prompt("Continue debugging the next submatch?", default=True): + if self.prompt("Print the remaining submatches?", default=False): + for result in generator: + self.print_match(result) + yield result + else: + yield from generator + break + finally: + self._pdb = None + + def print_where(self): + if self.last_test is None: + self.write("The first test has not yet been run.\n", color=ANSIColor.RED) + self.write("Use `step`, `next`, or `run` to start testing.\n") + return + wrote_breakpoints = False + for b in self.breakpoints: + if b.should_break(self.last_test, self.data, self.last_offset, self.last_parent_match): + self.write(b, color=ANSIColor.MAGENTA) + self.write("\n") + wrote_breakpoints = True + if wrote_breakpoints: + self.write("\n") + test_stack = [self.last_test] + while test_stack[-1].parent is not None: + test_stack.append(test_stack[-1].parent) + for i, t in enumerate(reversed(test_stack)): + if i == len(test_stack) - 1: + self.write_test(t, is_current_test=True) + else: + self.write_test(t) + test_stack = list(reversed(self.last_test.children)) + descendants = [] + while 
def repl(self):
    """Run the interactive `(polyfile)` prompt until the user resumes or quits.

    Commands are matched by prefix, in the order of the `elif` chain below, so
    `b` resolves to `breakpoint` (not `backtrace`) and `s` to `step`.  An empty
    line repeats the previous command name.  `continue`/`run`, `step` and `next`
    set `self.step_mode` and return to the caller; `quit` and end-of-file (^D)
    exit the process; every other command loops back to the prompt.
    """
    log.clear_status()
    if self.last_test is not None:
        self.print_where()
    while True:
        try:
            self.write("(polyfile) ", bold=True)
            sys.stdout.flush()
            line = input()
        except EOFError:
            # the user pressed ^D to quit
            sys.exit(0)
        # Strip *before* the emptiness check: a whitespace-only line used to
        # survive the check, become "" after lstrip(), and then match "help"
        # because every string starts with the empty string.
        line = line.strip()
        if not line:
            if self.last_command is None:
                continue
            line = self.last_command
        # Split on any whitespace so that "b<TAB>MIME:..." is parsed like "b MIME:...".
        parts = line.split(None, 1)
        command = parts[0]
        args = parts[1].strip() if len(parts) > 1 else ""
        if "help".startswith(command):
            # The help table lives in its own method so that its loop variables
            # cannot overwrite `command`.  Previously the loop rebound `command`
            # to the last usage entry ("quit"), which was then stored as
            # `last_command` -- pressing Enter after `help` exited the debugger.
            self._print_help()
        elif "continue".startswith(command) or "run".startswith(command):
            self.step_mode = StepMode.RUNNING
            self.last_command = command
            return
        elif "step".startswith(command):
            self.step_mode = StepMode.SINGLE_STEPPING
            self.last_command = command
            return
        elif "next".startswith(command):
            self.step_mode = StepMode.NEXT
            self.last_command = command
            return
        elif "quit".startswith(command):
            sys.exit(0)
        elif "delete".startswith(command):
            if not args:
                # Used to be a silent no-op.
                self.write("Error: `delete` requires a breakpoint number\n", color=ANSIColor.RED)
            else:
                try:
                    breakpoint_num = int(args)
                except ValueError:
                    breakpoint_num = -1
                if not (0 <= breakpoint_num < len(self.breakpoints)):
                    self.write(f"Error: Invalid breakpoint \"{args}\"\n", color=ANSIColor.RED)
                    continue
                b = self.breakpoints[breakpoint_num]
                # Rebind rather than mutate in place, as the original code did.
                self.breakpoints = self.breakpoints[:breakpoint_num] + self.breakpoints[breakpoint_num + 1:]
                self.write(f"Deleted {b!s}\n")
        elif "breakpoint".startswith(command):
            if args:
                # The first breakpoint type able to parse the pattern wins.
                for b_type in BREAKPOINT_TYPES:
                    parsed = b_type.parse(args)
                    if parsed is not None:
                        self.write(parsed, color=ANSIColor.MAGENTA)
                        self.write("\n")
                        self.breakpoints.append(parsed)
                        break
                else:
                    self.write("Error: Invalid breakpoint pattern\n", color=ANSIColor.RED)
            elif self.breakpoints:
                for i, b in enumerate(self.breakpoints):
                    self.write(f"{i}:\t", dim=True)
                    self.write(b, color=ANSIColor.MAGENTA)
                    self.write("\n")
            else:
                self.write("No breakpoints set.\n", color=ANSIColor.RED)
                for b_type in BREAKPOINT_TYPES:
                    self.write(b_type.usage())
                    self.write("\n")
        elif "where".startswith(command) or "info stack".startswith(command) or \
                "backtrace".startswith(command):
            self.print_where()
        else:
            self.write(f"Undefined command: {command!r}. Try \"help\".\n", color=ANSIColor.RED)
            self.last_command = None
            continue
        # Remember only the command name; an empty line re-runs it without arguments.
        self.last_command = command

def _print_help(self):
    """Print one line per debugger command: name, dotted leader, description, aliases."""
    usage = [
        ("help", "print this message"),
        ("continue", "continue execution until the next breakpoint is hit"),
        ("step", "step through a single magic test"),
        ("next", "continue execution until the next test that matches"),
        ("where", "print the context of the current magic test"),
        ("breakpoint", "list the current breakpoints or add a new one"),
        ("delete", "delete a breakpoint"),
        ("quit", "exit the debugger"),
    ]
    aliases = {
        # `run` has always been accepted by the prompt; list it so users can find it.
        "continue": ("run",),
        "where": ("info stack", "backtrace"),
    }
    left_col_width = max(len(name) for name, _ in usage)
    left_col_width = max(left_col_width, max(len(alt) for alts in aliases.values() for alt in alts))
    left_col_width += 3
    for name, msg in usage:
        self.write(name, bold=True, color=ANSIColor.BLUE)
        self.write(f" {'.' * (left_col_width - len(name) - 2)} ")
        self.write(msg)
        if name in aliases:
            self.write(" (aliases: ", dim=True)
            alternatives = aliases[name]
            for i, alt in enumerate(alternatives):
                # "a", "a and b", or "a, b, and c"
                if i > 0 and len(alternatives) > 2:
                    self.write(", ", dim=True)
                if i == len(alternatives) - 1 and len(alternatives) > 1:
                    self.write(" and ", dim=True)
                self.write(alt, bold=True, color=ANSIColor.BLUE)
            self.write(")", dim=True)
        self.write("\n")
class Wildcard(ABC):
    """A pattern that can be tested against whole strings."""

    @abstractmethod
    def match(self, to_test: str) -> bool:
        """Return True if `to_test`, taken as a whole, satisfies this pattern."""
        raise NotImplementedError()

    def is_contained_in(self, items: Iterable[str]) -> bool:
        """Return True if at least one element of `items` matches this pattern."""
        return any(self.match(item) for item in items)

    @staticmethod
    def parse(to_test: str) -> "Wildcard":
        """Build the cheapest matcher able to represent the pattern `to_test`.

        Patterns containing `*` or `?` become a `SimpleWildcard`; anything else is
        compared verbatim by a `ConstantMatch`.
        """
        if "*" in to_test or "?" in to_test:
            return SimpleWildcard(to_test)
        return ConstantMatch(to_test)


class ConstantMatch(Wildcard):
    """A pattern without wildcards: matches exactly one string."""

    def __init__(self, to_match: str):
        self.to_match: str = to_match

    def match(self, to_test: str) -> bool:
        return self.to_match == to_test

    def is_contained_in(self, items: Iterable[str]) -> bool:
        # Use the container's own membership test when there is one.  `str` is
        # excluded because `in` on a string is a *substring* test, which would
        # disagree with the element-wise semantics of the base class.
        if isinstance(items, Container) and not isinstance(items, str):
            return self.to_match in items
        return super().is_contained_in(items)


class SimpleWildcard(Wildcard):
    """A shell-style pattern: `*` matches any run of characters, `?` exactly one.

    A backslash directly before `*` or `?` makes that character literal; every
    other backslash is an ordinary character.
    """

    def __init__(self, pattern: str):
        self.raw_pattern: str = pattern
        self.pattern = re.compile(self.escaped_pattern)

    @property
    def escaped_pattern(self) -> str:
        """The regular-expression source equivalent to `raw_pattern`."""
        raw = self.raw_pattern
        parts: List[str] = []
        i = 0
        while i < len(raw):
            c = raw[i]
            if c == "\\" and i + 1 < len(raw) and raw[i + 1] in "*?":
                # Escaped wildcard: emit the wildcard character itself and drop
                # the backslash.  (Keeping the backslash, as before, produced a
                # regex that could only match a literal backslash + character.)
                parts.append(re.escape(raw[i + 1]))
                i += 2
                continue
            if c == "*":
                parts.append(".*")
            elif c == "?":
                # Exactly one character; ".?" would also accept zero.
                parts.append(".")
            else:
                parts.append(re.escape(c))
            i += 1
        return "".join(parts)

    def match(self, to_test: str) -> bool:
        # fullmatch, not match: the pattern must cover the whole string, the same
        # way ConstantMatch compares whole strings.
        return self.pattern.fullmatch(to_test) is not None