Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Macro Type Checking #11147

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241213-171214.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add experimental macro type checking under the --type-check flag.
time: 2024-12-13T17:12:14.984956-05:00
custom:
Author: peterallenwebb
Issue: "11090"
1 change: 1 addition & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def global_flags(func):
@p.warn_error_options
@p.write_json
@p.use_fast_test_edges
@p.type_check
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
Expand Down
7 changes: 7 additions & 0 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,10 @@ def _version_callback(ctx, _param, value):
default=False,
hidden=True,
)

# Hidden, experimental on/off switch for macro type checking. It defaults to
# off and can also be set through the DBT_TYPE_CHECK environment variable.
type_check = click.option(
    "--type-check/--no-type-check",
    default=False,
    envvar="DBT_TYPE_CHECK",
    hidden=True,
)
Empty file.
43 changes: 21 additions & 22 deletions core/dbt/clients/jinja_static.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import typing
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import jinja2

from dbt.artifacts.resources import RefArgs
from dbt.exceptions import MacroNamespaceNotStringError, ParsingError
from dbt_common.clients.jinja import get_environment
from dbt_common.clients.jinja_macro_call import DbtMacroCall
from dbt_common.exceptions.macros import MacroNameNotStringError
from dbt_common.tests import test_caching_enabled
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from dbt.context.providers import ParseDatabaseWrapper


Expand All @@ -31,7 +31,7 @@ def statically_extract_has_name_this(source: str) -> bool:

def statically_extract_macro_calls(
source: str, ctx: Dict[str, Any], db_wrapper: Optional["ParseDatabaseWrapper"] = None
) -> List[str]:
) -> List[DbtMacroCall]:
# set 'capture_macros' to capture undefined
env = get_environment(None, capture_macros=True)

Expand All @@ -48,11 +48,11 @@ def statically_extract_macro_calls(
setattr(parsed, "_dbt_cached_calls", func_calls)

standard_calls = ["source", "ref", "config"]
possible_macro_calls = []
possible_macro_calls: List[DbtMacroCall] = []
for func_call in func_calls:
func_name = None
macro_call: Optional[DbtMacroCall] = None
if hasattr(func_call, "node") and hasattr(func_call.node, "name"):
func_name = func_call.node.name
macro_call = DbtMacroCall.from_call(func_call, func_call.node.name)
else:
if (
hasattr(func_call, "node")
Expand All @@ -72,34 +72,31 @@ def statically_extract_macro_calls(
# This skips calls such as adapter.parse_index
continue
else:
func_name = f"{package_name}.{macro_name}"
macro_call = DbtMacroCall.from_call(func_call, f"{package_name}.{macro_name}")
else:
continue
if not func_name:
continue
if func_name in standard_calls:
continue
elif ctx.get(func_name):

if not macro_call or macro_call.name in standard_calls or ctx.get(macro_call.name):
continue
else:
if func_name not in possible_macro_calls:
possible_macro_calls.append(func_name)

possible_macro_calls.append(macro_call)

return possible_macro_calls


def statically_parse_adapter_dispatch(
func_call, ctx: Dict[str, Any], db_wrapper: Optional["ParseDatabaseWrapper"]
) -> List[str]:
possible_macro_calls = []
) -> List[DbtMacroCall]:
possible_macro_calls: List[DbtMacroCall] = []
# This captures an adapter.dispatch('<macro_name>') call.

func_name = None
# macro_name positional argument
if len(func_call.args) > 0:
func_name = func_call.args[0].value

if func_name:
possible_macro_calls.append(func_name)
possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name))

# packages positional argument
macro_namespace = None
Expand All @@ -118,7 +115,7 @@ def statically_parse_adapter_dispatch(
# This will remain to enable static resolution
if type(kwarg.value).__name__ == "Const":
func_name = kwarg.value.value
possible_macro_calls.append(func_name)
possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name))
else:
raise MacroNameNotStringError(kwarg_value=kwarg.value.value)
elif kwarg.key == "macro_namespace":
Expand All @@ -143,14 +140,16 @@ def statically_parse_adapter_dispatch(
if db_wrapper:
macro = db_wrapper.dispatch(func_name, macro_namespace=macro_namespace).macro
func_name = f"{macro.package_name}.{macro.name}" # type: ignore[attr-defined]
possible_macro_calls.append(func_name)
possible_macro_calls.append(DbtMacroCall.from_call(func_call, func_name))
else: # this is only for tests/unit/test_macro_calls.py
if macro_namespace:
packages = [macro_namespace]
else:
packages = []
for package_name in packages:
possible_macro_calls.append(f"{package_name}.{func_name}")
possible_macro_calls.append(
DbtMacroCall.from_call(func_call, f"{package_name}.{func_name}")
)

return possible_macro_calls

Expand Down
25 changes: 16 additions & 9 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def load_and_parse_macros(self, project_parser_files):

self.build_macro_resolver()
# Look at changed macros and update the macro.depends_on.macros
self.macro_depends_on()
self.analyze_macros()

# Parse the files in the 'parser_files' dictionary, for parsers listed in
# 'parser_types'
Expand Down Expand Up @@ -776,32 +776,39 @@ def build_macro_resolver(self):
self.manifest.macros, self.root_project.project_name, internal_package_names
)

# Loop through macros in the manifest and statically parse
# the 'macro_sql' to find depends_on.macros
def macro_depends_on(self):
def analyze_macros(self):
"""Loop through macros in the manifest and statically parse the
'macro_sql' to find and set the value of depends_on.macros. Also,
perform type checking if flag is set.
"""
macro_ctx = generate_macro_context(self.root_project)
macro_namespace = TestMacroNamespace(self.macro_resolver, {}, None, MacroStack(), [])
adapter = get_adapter(self.root_project)
db_wrapper = ParseProvider().DatabaseWrapper(adapter, macro_namespace)
type_check = get_flags().TYPE_CHECK
for macro in self.manifest.macros.values():
if macro.created_at < self.started_at:
continue
possible_macro_calls = statically_extract_macro_calls(
macro.macro_sql, macro_ctx, db_wrapper
)
for macro_name in possible_macro_calls:
for macro_call in possible_macro_calls:
# adapter.dispatch calls can generate a call with the same name as the macro
# it ought to be an adapter prefix (postgres_) or default_
macro_name = macro_call.name
if macro_name == macro.name:
continue
package_name = macro.package_name
if "." in macro_name:
package_name, macro_name = macro_name.split(".")
dep_macro_id = self.macro_resolver.get_macro_id(package_name, macro_name)
if dep_macro_id:
macro.depends_on.add_macro(dep_macro_id) # will check for dupes
dep_macro = self.macro_resolver.get_macro(package_name, macro_name)
if dep_macro is not None and dep_macro.unique_id:
macro.depends_on.add_macro(dep_macro.unique_id) # will check for dupes

if type_check:
macro_call.check(dep_macro)

def write_manifest_for_partial_parse(self):
def write_manifest_for_partial_parse(self) -> None:
path = os.path.join(self.root_project.project_target_path, PARTIAL_PARSE_FILE_NAME)
try:
# This shouldn't be necessary, but we have gotten bug reports (#3757) of the
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/clients/test_jinja_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
)
from dbt.context.base import generate_base_context
from dbt.exceptions import ParsingError
from dbt_common.clients.jinja import MacroType
from dbt_common.clients.jinja_macro_call import DbtMacroCall


@pytest.mark.parametrize(
Expand Down Expand Up @@ -58,6 +60,34 @@ def test_extract_macro_calls(macro_string, expected_possible_macro_calls):
ctx = generate_base_context(cli_vars)

possible_macro_calls = statically_extract_macro_calls(macro_string, ctx)
assert [c.name for c in possible_macro_calls] == expected_possible_macro_calls


@pytest.mark.parametrize(
    "macro_string,expected_possible_macro_calls",
    [
        (
            "{% macro parent_macro() %} {% do nested_macro(12, 'string', 1.0, True, my_kwarg = 2) %} {% endmacro %}",
            [
                DbtMacroCall(
                    name="nested_macro",
                    source="",
                    # One entry per positional literal, in call order:
                    # 12, 'string', 1.0, True.
                    arg_types=[
                        MacroType(name=type_name)
                        for type_name in ("int", "str", "float", "bool")
                    ],
                    kwarg_types={"my_kwarg": MacroType(name="int")},
                ),
            ],
        ),
    ],
)
def test_extract_macro_calls_with_type_info(macro_string, expected_possible_macro_calls):
    """Check that extracted macro calls carry argument type information.

    The macro source is parsed with an empty context, and the resulting
    calls are compared for full equality (name, source, positional argument
    types and keyword argument types) against the expected ``DbtMacroCall``
    objects.
    """
    extracted_calls = statically_extract_macro_calls(macro_string, {})
    assert extracted_calls == expected_possible_macro_calls


Expand Down
Loading