Skip to content

Commit

Permalink
Support module wildcards everywhere (#235)
Browse files Browse the repository at this point in the history
Add support for wildcards in forbidden and independence contracts
  • Loading branch information
fbinz authored Sep 11, 2024
1 parent d3efc54 commit bb2c116
Show file tree
Hide file tree
Showing 17 changed files with 617 additions and 183 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Contributors
* piglei - https://github.com/piglei
* Anton Gruebel - https://github.com/gruebel
* Peter Byfield - https://github.com/Peter554
* Fabian Binz - https://github.com/fbinz
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

latest
------

* Add support for wildcards in forbidden and independence contracts.


2.0 (2024-1-9)
--------------

Expand Down
31 changes: 17 additions & 14 deletions docs/contract_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Contract types
==============


.. _forbidden modules:

Forbidden modules
Expand Down Expand Up @@ -59,8 +58,6 @@ External packages may also be forbidden.
**Configuration options**

Configuration options:

- ``source_modules``: A list of modules that should not import the forbidden modules.
- ``forbidden_modules``: A list of modules that should not be imported by the source modules. These may include
root level external packages (i.e. ``django``, but not ``django.db.models``). If external packages are included,
Expand Down Expand Up @@ -335,14 +332,26 @@ Options used by multiple contracts

- ``ignore_imports``: Optional list of imports, each in the form ``mypackage.foo.importer -> mypackage.bar.imported``.
These imports will be ignored: if the import would cause a contract to be broken, adding it to the list will cause the
contract be kept instead.
contract be kept instead. Supports :ref:`wildcards`.

- ``unmatched_ignore_imports_alerting``: The alerting level for handling expressions supplied in ``ignore_imports``
that do not match any imports in the graph. Choices are:

Wildcards are supported. ``*`` stands in for a module name, without including subpackages. ``**`` includes
subpackages too.
- ``error``: Error if there are any unmatched expressions (default).
- ``warn``: Print a warning for each unmatched expression.
- ``none``: Do not alert.

Note that this wildcard format is only supported for the ``ignore_imports`` fields. It can't currently be used for
other fields, such as in the ``source_modules`` field of a :ref:`forbidden modules` contract.
.. _wildcards:

Wildcards
---------

Wildcards are supported in most places where a module name is required to express a set of modules.
``*`` stands in for a module name, without including subpackages. ``**`` includes subpackages too.

Note that at the moment, layer contracts only support wildcards in `illegal_imports`.
If you have a use case for this, please file an issue.

Examples:

- ``mypackage.*``: matches ``mypackage.foo`` but not ``mypackage.foo.bar``.
Expand All @@ -352,9 +361,3 @@ Options used by multiple contracts
- ``mypackage.**.qux``: matches ``mypackage.foo.bar.qux`` and ``mypackage.foo.bar.baz.qux``.
- ``mypackage.foo*``: not a valid expression. (The wildcard must replace a whole module name.)

- ``unmatched_ignore_imports_alerting``: The alerting level for handling expressions supplied in ``ignore_imports``
that do not match any imports in the graph. Choices are:

- ``error``: Error if there are any unmatched expressions (default).
- ``warn``: Print a warning for each unmatched expression.
- ``none``: Do not alert.
1 change: 1 addition & 0 deletions src/importlinter/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module for public-facing Python functions.
"""

from __future__ import annotations

from importlinter.application import use_cases
Expand Down
1 change: 1 addition & 0 deletions src/importlinter/application/contract_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
from typing import List, Optional, Sequence, Set


from importlinter.domain import helpers
from importlinter.domain.helpers import MissingImport
from importlinter.domain.imports import ImportExpression
Expand Down
1 change: 1 addition & 0 deletions src/importlinter/contracts/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
relying on it for a custom contract type, be aware things may change
without warning.
"""

from __future__ import annotations

import itertools
Expand Down
60 changes: 42 additions & 18 deletions src/importlinter/contracts/forbidden.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List, cast
from typing import Iterable, List, cast

from grimp import ImportGraph

Expand All @@ -9,6 +9,7 @@
from importlinter.configuration import settings
from importlinter.domain import fields
from importlinter.domain.contract import Contract, ContractCheck
from importlinter.domain.helpers import module_expressions_to_modules
from importlinter.domain.imports import Module

from ._common import format_line_numbers
Expand All @@ -33,8 +34,8 @@ class ForbiddenContract(Contract):

type_name = "forbidden"

source_modules = fields.ListField(subfield=fields.ModuleField())
forbidden_modules = fields.ListField(subfield=fields.ModuleField())
source_modules = fields.ListField(subfield=fields.ModuleExpressionField())
forbidden_modules = fields.ListField(subfield=fields.ModuleExpressionField())
ignore_imports = fields.SetField(subfield=fields.ImportExpressionField(), required=False)
allow_indirect_imports = fields.BooleanField(required=False, default=False)
unmatched_ignore_imports_alerting = fields.EnumField(AlertLevel, default=AlertLevel.ERROR)
Expand All @@ -49,16 +50,30 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck:
unmatched_alerting=self.unmatched_ignore_imports_alerting, # type: ignore
)

self._check_all_modules_exist_in_graph(graph)
self._check_external_forbidden_modules()
source_modules = list(
module_expressions_to_modules(
graph,
self.source_modules, # type: ignore
)
)
forbidden_modules = list(
module_expressions_to_modules(
graph,
self.forbidden_modules, # type: ignore
)
)

self._check_all_modules_exist_in_graph(source_modules, graph)
self._check_external_forbidden_modules(forbidden_modules)

# We only need to check for illegal imports for forbidden modules that are in the graph.
forbidden_modules_in_graph = [
m for m in self.forbidden_modules if m.name in graph.modules # type: ignore
]
forbidden_modules_in_graph = [m for m in forbidden_modules if m.name in graph.modules]

def sort_key(module):
return module.name

for source_module in self.source_modules: # type: ignore
for forbidden_module in forbidden_modules_in_graph:
for source_module in sorted(source_modules, key=sort_key):
for forbidden_module in sorted(forbidden_modules_in_graph, key=sort_key):
output.verbose_print(
verbose,
"Searching for import chains from "
Expand Down Expand Up @@ -95,7 +110,7 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck:
"line_numbers": line_numbers,
}
)
subpackage_chain_data["chains"].append(chain_data)
subpackage_chain_data["chains"].append(chain_data) # type: ignore
if subpackage_chain_data["chains"]:
invalid_chains.append(subpackage_chain_data)
if verbose:
Expand All @@ -106,8 +121,15 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck:
f"in {timer.duration_in_s}s.",
)

# Sorting by upstream and downstream module ensures that the output is deterministic
# and that the same upstream and downstream modules are always adjacent in the output.
def chain_sort_key(chain_data):
return (chain_data["upstream_module"], chain_data["downstream_module"])

return ContractCheck(
kept=is_kept, warnings=warnings, metadata={"invalid_chains": invalid_chains}
kept=is_kept,
warnings=warnings,
metadata={"invalid_chains": sorted(invalid_chains, key=chain_sort_key)},
)

def render_broken_contract(self, check: "ContractCheck") -> None:
Expand All @@ -133,13 +155,15 @@ def render_broken_contract(self, check: "ContractCheck") -> None:

output.new_line()

def _check_all_modules_exist_in_graph(self, graph: ImportGraph) -> None:
for module in self.source_modules: # type: ignore
def _check_all_modules_exist_in_graph(
self, modules: Iterable[Module], graph: ImportGraph
) -> None:
for module in modules:
if module.name not in graph.modules:
raise ValueError(f"Module '{module.name}' does not exist.")

def _check_external_forbidden_modules(self) -> None:
external_forbidden_modules = self._get_external_forbidden_modules()
def _check_external_forbidden_modules(self, forbidden_modules) -> None:
external_forbidden_modules = self._get_external_forbidden_modules(forbidden_modules)
if external_forbidden_modules:
if self._graph_was_built_with_externals():
for module in external_forbidden_modules:
Expand All @@ -154,11 +178,11 @@ def _check_external_forbidden_modules(self) -> None:
"when there are external forbidden modules."
)

def _get_external_forbidden_modules(self) -> set[Module]:
def _get_external_forbidden_modules(self, forbidden_modules) -> set[Module]:
root_packages = [Module(name) for name in self.session_options["root_packages"]]
return {
forbidden_module
for forbidden_module in cast(List[Module], self.forbidden_modules)
for forbidden_module in cast(List[Module], forbidden_modules)
if not any(
forbidden_module.is_in_package(root_package) for root_package in root_packages
)
Expand Down
19 changes: 13 additions & 6 deletions src/importlinter/contracts/independence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@
from importlinter.application.contract_utils import AlertLevel
from importlinter.domain import fields
from importlinter.domain.contract import Contract, ContractCheck
from importlinter.domain.helpers import module_expressions_to_modules
from importlinter.domain.imports import Module

from ._common import DetailedChain, Link, build_detailed_chain_from_route, render_chain_data
from ._common import (
DetailedChain,
Link,
build_detailed_chain_from_route,
render_chain_data,
)


class _SubpackageChainData(TypedDict):
Expand Down Expand Up @@ -41,7 +47,7 @@ class IndependenceContract(Contract):

type_name = "independence"

modules = fields.ListField(subfield=fields.ModuleField())
modules = fields.ListField(subfield=fields.ModuleExpressionField())
ignore_imports = fields.SetField(subfield=fields.ImportExpressionField(), required=False)
unmatched_ignore_imports_alerting = fields.EnumField(AlertLevel, default=AlertLevel.ERROR)

Expand All @@ -52,11 +58,12 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck:
unmatched_alerting=self.unmatched_ignore_imports_alerting, # type: ignore
)

self._check_all_modules_exist_in_graph(graph)
modules = list(module_expressions_to_modules(graph, self.modules)) # type: ignore
self._check_all_modules_exist_in_graph(graph, modules)

dependencies = graph.find_illegal_dependencies_for_layers(
# A single layer consisting of siblings.
layers=({module.name for module in self.modules},), # type: ignore
layers=({module.name for module in modules},),
)
invalid_chains = self._build_invalid_chains(dependencies, graph)

Expand All @@ -81,8 +88,8 @@ def render_broken_contract(self, check: "ContractCheck") -> None:

output.new_line()

def _check_all_modules_exist_in_graph(self, graph: ImportGraph) -> None:
for module in self.modules: # type: ignore
def _check_all_modules_exist_in_graph(self, graph: ImportGraph, modules) -> None:
for module in modules:
if module.name not in graph.modules:
raise ValueError(f"Module '{module.name}' does not exist.")

Expand Down
56 changes: 36 additions & 20 deletions src/importlinter/domain/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Generic, Iterable, List, Set, Type, TypeVar, Union, cast

from importlinter.domain.imports import ImportExpression, Module
from importlinter.domain.imports import ImportExpression, Module, ModuleExpression

FieldValue = TypeVar("FieldValue")

Expand Down Expand Up @@ -34,7 +34,6 @@ def __init__(
required: Union[bool, Type[NotSupplied]] = NotSupplied,
default: Union[FieldValue, Type[NotSupplied]] = NotSupplied,
) -> None:

if default is NotSupplied:
if required is NotSupplied:
self.required = True
Expand Down Expand Up @@ -159,6 +158,37 @@ def parse(self, raw_data: Union[str, List]) -> Module:
return Module(StringField().parse(raw_data))


class ModuleExpressionField(Field):
"""
A field for ModuleExpressions.
Accepts strings in the form:
"mypackage.foo.importer"
"mypackage.foo.*"
"mypackage.*.importer"
"mypackage.**"
"""

def parse(self, expression: Union[str, List[str]]) -> ModuleExpression:
if isinstance(expression, list):
raise ValidationError("Expected a single value, got multiple values.")

last_wildcard = None
for part in expression.split("."):
if "**" == last_wildcard and ("*" == part or "**" == part):
raise ValidationError("A recursive wildcard cannot be followed by a wildcard.")
if "*" == last_wildcard and "**" == part:
raise ValidationError("A wildcard cannot be followed by a recursive wildcard.")
if "*" == part or "**" == part:
last_wildcard = part
continue
if "*" in part:
raise ValidationError("A wildcard can only replace a whole module.")
last_wildcard = None

return ModuleExpression(expression)


class ImportExpressionField(Field):
"""
A field for ImportExpressions.
Expand All @@ -181,24 +211,10 @@ def parse(self, raw_data: Union[str, List]) -> ImportExpression:
if not (importer and imported):
raise ValidationError('Must be in the form "package.importer -> package.imported".')

self._validate_wildcard(importer)
self._validate_wildcard(imported)

return ImportExpression(importer=importer, imported=imported)

def _validate_wildcard(self, expression: str) -> None:
last_wildcard = None
for part in expression.split("."):
if "**" == last_wildcard and ("*" == part or "**" == part):
raise ValidationError("A recursive wildcard cannot be followed by a wildcard.")
if "*" == last_wildcard and "**" == part:
raise ValidationError("A wildcard cannot be followed by a recursive wildcard.")
if "*" == part or "**" == part:
last_wildcard = part
continue
if "*" in part:
raise ValidationError("A wildcard can only replace a whole module.")
last_wildcard = None
return ImportExpression(
importer=ModuleExpressionField().parse(importer),
imported=ModuleExpressionField().parse(imported),
)


class EnumField(Field):
Expand Down
Loading

0 comments on commit bb2c116

Please sign in to comment.