diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index cdc069da..67daa3ff 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -1,5 +1,6 @@ import re from collections import namedtuple +from typing import Iterator, Optional, List from dbt_common.exceptions import ( BlockDefinitionNotAtTopError, @@ -12,40 +13,40 @@ ) -def regex(pat): +def regex(pat: str) -> re.Pattern: return re.compile(pat, re.DOTALL | re.MULTILINE) class BlockData: """raw plaintext data from the top level of the file.""" - def __init__(self, contents): + def __init__(self, contents: str) -> None: self.block_type_name = "__dbt__data" - self.contents = contents + self.contents: str = contents self.full_block = contents class BlockTag: - def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw): + def __init__(self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None) -> None: self.block_type_name = block_type_name self.block_name = block_name self.contents = contents self.full_block = full_block - def __str__(self): + def __str__(self) -> str: return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name) - def __repr__(self): + def __repr__(self) -> str: return str(self) @property - def end_block_type_name(self): + def end_block_type_name(self) -> str: return "end{}".format(self.block_type_name) - def end_pat(self): + def end_pat(self) -> re.Pattern: # we don't want to use string formatting here because jinja uses most # of the string formatting operators in its syntax... - pattern = "".join( + pattern: str = "".join( ( r"(?P((?:\s*\{\%\-|\{\%)\s*", self.end_block_type_name, @@ -98,13 +99,11 @@ def end_pat(self): class TagIterator: - def __init__(self, text): - self.text = text - self.blocks = [] - self._parenthesis_stack = [] - self.pos = 0 + def __init__(self, text: str) -> None: + self.text: str = text + self.pos: int = 0 - def linepos(self, end=None) -> str: + def linepos(self, end: Optional[int] = None) -> str: """Given an absolute position in the input text, return a pair of line number + relative position to the start of the line. """ @@ -116,26 +115,22 @@ def linepos(self, end=None) -> str: line_number = text.count("\n") + 1 return f"{line_number}:{end_val - last_line_start}" - def advance(self, new_position): + def advance(self, new_position: int) -> None: self.pos = new_position - def rewind(self, amount=1): + def rewind(self, amount: int = 1) -> None: self.pos -= amount - def _search(self, pattern): + def _search(self, pattern: re.Pattern) -> Optional[re.Match]: return pattern.search(self.text, self.pos) - def _match(self, pattern): + def _match(self, pattern: re.Pattern) -> Optional[re.Match]: return pattern.match(self.text, self.pos) - def _first_match(self, *patterns, **kwargs): + def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore matches = [] for pattern in patterns: - # default to 'search', but sometimes we want to 'match'. - if kwargs.get("method", "search") == "search": - match = self._search(pattern) - else: - match = self._match(pattern) + match = self._search(pattern) if match: matches.append(match) if not matches: @@ -144,13 +139,13 @@ def _first_match(self, *patterns, **kwargs): # TODO: do I need to account for m.start(), or is this ok? return min(matches, key=lambda m: m.end()) - def _expect_match(self, expected_name, *patterns, **kwargs): - match = self._first_match(*patterns, **kwargs) + def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore + match = self._first_match(*patterns) if match is None: - raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :]) + raise UnexpectedMacroEOFError(expected_name, self.text[self.pos:]) return match - def handle_expr(self, match): + def handle_expr(self, match: re.Match) -> None: """Handle an expression. At this point we're at a string like: {{ 1 + 2 }} ^ right here @@ -176,12 +171,12 @@ def handle_expr(self, match): self.advance(match.end()) - def handle_comment(self, match): + def handle_comment(self, match: re.Match) -> None: self.advance(match.end()) match = self._expect_match("#}", COMMENT_END_PATTERN) self.advance(match.end()) - def _expect_block_close(self): + def _expect_block_close(self) -> None: """Search for the tag close marker. To the right of the type name, there are a few possiblities: - a name (handled by the regex's 'block_name') @@ -203,13 +198,13 @@ def _expect_block_close(self): string_match = self._expect_match("string", STRING_PATTERN) self.advance(string_match.end()) - def handle_raw(self): + def handle_raw(self) -> int: # raw blocks are super special, they are a single complete regex match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) self.advance(match.end()) return match.end() - def handle_tag(self, match): + def handle_tag(self, match: re.Match) -> Tag: """The tag could be one of a few things: {% mytag %} @@ -234,7 +229,7 @@ def handle_tag(self, match): self._expect_block_close() return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos) - def find_tags(self): + def find_tags(self) -> Iterator[Tag]: while True: match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN) if match is None: @@ -259,7 +254,7 @@ def find_tags(self): "Invalid regex match in next_block, expected block start, " "expr start, or comment start" ) - def __iter__(self): + def __iter__(self) -> Iterator[Tag]: return self.find_tags() @@ -272,31 +267,31 @@ def __iter__(self): class BlockIterator: - def __init__(self, tag_iterator): + def __init__(self, tag_iterator: TagIterator) -> None: self.tag_parser = tag_iterator - self.current = None - self.stack = [] - self.last_position = 0 + self.current: Optional[Tag] = None + self.stack: List[str] = [] + self.last_position: int = 0 @property - def current_end(self): + def current_end(self) -> int: if self.current is None: return 0 else: return self.current.end @property - def data(self): + def data(self) -> str: return self.tag_parser.text - def is_current_end(self, tag): + def is_current_end(self, tag: Tag) -> bool: return ( tag.block_type_name.startswith("end") and self.current is not None and tag.block_type_name[3:] == self.current.block_type_name ) - def find_blocks(self, allowed_blocks=None, collect_raw_data=True): + def find_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> Iterator[BlockData | BlockTag]: """Find all top-level blocks in the data.""" if allowed_blocks is None: allowed_blocks = {"snapshot", "macro", "materialization", "docs"} @@ -347,5 +342,5 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True): if raw_data: yield BlockData(raw_data) - def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True): + def lex_for_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> List[BlockData | BlockTag]: return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))