Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inject TagIterator into BlockIterator #39

Merged
merged 8 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240123-161107.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Inject TagIterator into BlockIterator for greater flexibility.
time: 2024-01-23T16:11:07.24321-05:00
custom:
Author: peterallenwebb
Issue: "38"
107 changes: 54 additions & 53 deletions dbt_common/clients/_jinja_blocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from collections import namedtuple
from typing import Iterator, List, Optional, Set, Union

from dbt_common.exceptions import (
BlockDefinitionNotAtTopError,
Expand All @@ -12,40 +13,42 @@
)


def regex(pat):
def regex(pat: str) -> re.Pattern:
return re.compile(pat, re.DOTALL | re.MULTILINE)


class BlockData:
"""raw plaintext data from the top level of the file."""

def __init__(self, contents):
def __init__(self, contents: str) -> None:
self.block_type_name = "__dbt__data"
self.contents = contents
self.contents: str = contents
self.full_block = contents


class BlockTag:
def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw):
def __init__(
self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None
) -> None:
self.block_type_name = block_type_name
self.block_name = block_name
self.contents = contents
self.full_block = full_block

def __str__(self):
def __str__(self) -> str:
return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)

def __repr__(self):
def __repr__(self) -> str:
return str(self)

@property
def end_block_type_name(self):
def end_block_type_name(self) -> str:
return "end{}".format(self.block_type_name)

def end_pat(self):
def end_pat(self) -> re.Pattern:
# we don't want to use string formatting here because jinja uses most
# of the string formatting operators in its syntax...
pattern = "".join(
pattern: str = "".join(
(
r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
self.end_block_type_name,
Expand Down Expand Up @@ -98,44 +101,38 @@ def end_pat(self):


class TagIterator:
def __init__(self, data):
self.data = data
self.blocks = []
self._parenthesis_stack = []
self.pos = 0

def linepos(self, end=None) -> str:
"""Given an absolute position in the input data, return a pair of
def __init__(self, text: str) -> None:
self.text: str = text
self.pos: int = 0

def linepos(self, end: Optional[int] = None) -> str:
"""Given an absolute position in the input text, return a pair of
line number + relative position to the start of the line.
"""
end_val: int = self.pos if end is None else end
data = self.data[:end_val]
text = self.text[:end_val]
# if not found, rfind returns -1, and -1+1=0, which is perfect!
last_line_start = data.rfind("\n") + 1
last_line_start = text.rfind("\n") + 1
# it's easy to forget this, but line numbers are 1-indexed
line_number = data.count("\n") + 1
line_number = text.count("\n") + 1
return f"{line_number}:{end_val - last_line_start}"

def advance(self, new_position):
def advance(self, new_position: int) -> None:
self.pos = new_position

def rewind(self, amount=1):
def rewind(self, amount: int = 1) -> None:
self.pos -= amount

def _search(self, pattern):
return pattern.search(self.data, self.pos)
def _search(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.search(self.text, self.pos)

def _match(self, pattern):
return pattern.match(self.data, self.pos)
def _match(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.match(self.text, self.pos)

def _first_match(self, *patterns, **kwargs):
def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore
matches = []
for pattern in patterns:
# default to 'search', but sometimes we want to 'match'.
if kwargs.get("method", "search") == "search":
match = self._search(pattern)
else:
match = self._match(pattern)
match = self._search(pattern)
if match:
matches.append(match)
if not matches:
Expand All @@ -144,13 +141,13 @@ def _first_match(self, *patterns, **kwargs):
# TODO: do I need to account for m.start(), or is this ok?
return min(matches, key=lambda m: m.end())

def _expect_match(self, expected_name, *patterns, **kwargs):
match = self._first_match(*patterns, **kwargs)
def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore
match = self._first_match(*patterns)
if match is None:
raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :])
raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :])
return match

def handle_expr(self, match):
def handle_expr(self, match: re.Match) -> None:
"""Handle an expression. At this point we're at a string like:
{{ 1 + 2 }}
^ right here
Expand All @@ -176,12 +173,12 @@ def handle_expr(self, match):

self.advance(match.end())

def handle_comment(self, match):
def handle_comment(self, match: re.Match) -> None:
self.advance(match.end())
match = self._expect_match("#}", COMMENT_END_PATTERN)
self.advance(match.end())

def _expect_block_close(self):
def _expect_block_close(self) -> None:
"""Search for the tag close marker.
To the right of the type name, there are a few possibilities:
- a name (handled by the regex's 'block_name')
Expand All @@ -203,13 +200,13 @@ def _expect_block_close(self):
string_match = self._expect_match("string", STRING_PATTERN)
self.advance(string_match.end())

def handle_raw(self):
def handle_raw(self) -> int:
# raw blocks are super special, they are a single complete regex
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
return match.end()

def handle_tag(self, match):
def handle_tag(self, match: re.Match) -> Tag:
"""The tag could be one of a few things:

{% mytag %}
Expand All @@ -234,7 +231,7 @@ def handle_tag(self, match):
self._expect_block_close()
return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos)

def find_tags(self):
def find_tags(self) -> Iterator[Tag]:
while True:
match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN)
if match is None:
Expand All @@ -259,7 +256,7 @@ def find_tags(self):
"Invalid regex match in next_block, expected block start, " "expr start, or comment start"
)

def __iter__(self):
def __iter__(self) -> Iterator[Tag]:
return self.find_tags()


Expand All @@ -272,31 +269,33 @@ def __iter__(self):


class BlockIterator:
def __init__(self, data):
self.tag_parser = TagIterator(data)
self.current = None
self.stack = []
self.last_position = 0
def __init__(self, tag_iterator: TagIterator) -> None:
self.tag_parser = tag_iterator
self.current: Optional[Tag] = None
self.stack: List[str] = []
self.last_position: int = 0

@property
def current_end(self):
def current_end(self) -> int:
if self.current is None:
return 0
else:
return self.current.end

@property
def data(self):
return self.tag_parser.data
def data(self) -> str:
return self.tag_parser.text

def is_current_end(self, tag):
def is_current_end(self, tag: Tag) -> bool:
return (
tag.block_type_name.startswith("end")
and self.current is not None
and tag.block_type_name[3:] == self.current.block_type_name
)

def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
def find_blocks(
self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True
) -> Iterator[Union[BlockData, BlockTag]]:
"""Find all top-level blocks in the data."""
if allowed_blocks is None:
allowed_blocks = {"snapshot", "macro", "materialization", "docs"}
Expand Down Expand Up @@ -347,5 +346,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
if raw_data:
yield BlockData(raw_data)

def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
def lex_for_blocks(
self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True
) -> List[Union[BlockData, BlockTag]]:
return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))
7 changes: 4 additions & 3 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_materialization_macro_name,
get_test_macro_name,
)
from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag, TagIterator

from dbt_common.exceptions import (
CompilationError,
Expand Down Expand Up @@ -516,7 +516,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str:


def extract_toplevel_blocks(
data: str,
text: str,
allowed_blocks: Optional[Set[str]] = None,
collect_raw_data: bool = True,
) -> List[Union[BlockData, BlockTag]]:
Expand All @@ -534,4 +534,5 @@ def extract_toplevel_blocks(
:return: A list of `BlockTag`s matching the allowed block types and (if
`collect_raw_data` is `True`) `BlockData` objects.
"""
return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
tag_iterator = TagIterator(text)
return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)