fix: Better f-string spacing #228

Closed
wants to merge 4 commits
716 changes: 359 additions & 357 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ snakefmt = 'snakefmt.snakefmt:main'
[tool.poetry.dependencies]
python = "^3.8.1"
click = "^8.0.0"
black = "^24.1.1"
black = "^24.3.0"
toml = "^0.10.2"
importlib_metadata = {version = ">=1.7.0,<5.0", python = "<3.8"}

3 changes: 3 additions & 0 deletions snakefmt/__init__.py
@@ -15,6 +15,9 @@

__version__ = metadata.version("snakefmt")

# New f-string tokenizing was introduced in python 3.12 - we have to deal with it, too.
fstring_tokeniser_in_use = sys.version_info >= (3, 12)

DEFAULT_LINE_LENGTH = 88
DEFAULT_TARGET_VERSIONS = {
TargetVersion.PY38,
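A minimal illustration (not part of the diff) of why this version gate exists, using only the standard library: on Python 3.12+ the tokenizer splits an f-string into FSTRING_START / FSTRING_MIDDLE / FSTRING_END tokens instead of emitting a single STRING token, and snakefmt's spacing logic needs to know which regime it is in.

import io
import sys
import tokenize

source = 'a = f"total: {1 + 2}"\n'
names = [
    tokenize.tok_name[tok.type]
    for tok in tokenize.generate_tokens(io.StringIO(source).readline)
]
print(sys.version_info[:2], names)
# Python 3.11: the whole f-string arrives as a single STRING token.
# Python 3.12+: FSTRING_START, FSTRING_MIDDLE, OP '{', NUMBER, OP '+', NUMBER,
#               OP '}', FSTRING_END appear instead.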
56 changes: 54 additions & 2 deletions snakefmt/config.py
@@ -3,19 +3,71 @@
"""

from dataclasses import fields
from functools import lru_cache
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Tuple, Union

import click
import toml
from black import Mode, find_project_root
from black import Mode

from snakefmt import DEFAULT_LINE_LENGTH, DEFAULT_TARGET_VERSIONS
from snakefmt.exceptions import MalformattedToml

PathLike = Union[Path, str]


@lru_cache
def find_project_root(
srcs: Sequence[str], stdin_filename: Optional[str] = None
) -> Tuple[Path, str]:
"""Return a directory containing .git, .hg, or pyproject.toml.

That directory will be a common parent of all files and directories
passed in `srcs`.

If no directory in the tree contains a marker that would specify it's the
project root, the root of the file system is returned.

Returns a two-tuple with the first element as the project root path and
the second element as a string describing the method by which the
project root was discovered.

Note: taken directly from black v24.1.0 as they changed the behaviour of this
function in v24.2.0 to only find the root if the pyproject.toml file contained the
[tool.black] section. This is not the desired behaviour for snakefmt.
"""
if stdin_filename is not None:
srcs = tuple(stdin_filename if s == "-" else s for s in srcs)
if not srcs:
srcs = [str(Path.cwd().resolve())]

path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs]

# A list of lists of parents for each 'src'. 'src' is included as a
# "parent" of itself if it is a directory
src_parents = [
list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
]

common_base = max(
set.intersection(*(set(parents) for parents in src_parents)),
key=lambda path: path.parts,
)

for directory in (common_base, *common_base.parents):
if (directory / ".git").exists():
return directory, ".git directory"

if (directory / ".hg").is_dir():
return directory, ".hg directory"

if (directory / "pyproject.toml").is_file():
return directory, "pyproject.toml"

return directory, "file system root"


def find_pyproject_toml(start_path: Sequence[str]) -> Optional[str]:
root, _ = find_project_root(start_path)
config_file = root / "pyproject.toml"
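A hedged usage sketch of the vendored helper above (the directory layout and paths are invented): the first of .git, .hg, or pyproject.toml found while walking up from the common parent of the sources wins, regardless of whether the pyproject.toml carries a [tool.black] section; that is the behaviour snakefmt wants to keep.

from snakefmt.config import find_project_root

# Hypothetical layout: ~/pipelines/demo contains .git/ and workflow/Snakefile.
# The argument must be hashable (the function is lru_cache'd), hence a tuple.
root, method = find_project_root(("workflow/Snakefile",))
print(root, method)  # e.g. /home/user/pipelines/demo, ".git directory"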
5 changes: 4 additions & 1 deletion snakefmt/parser/parser.py
@@ -9,6 +9,7 @@
ParameterSyntax,
Vocabulary,
add_token_space,
fstring_processing,
is_newline,
re_add_curly_bracket_if_needed,
)
@@ -85,6 +86,7 @@ def __init__(self, snakefile: TokenIterator):
self.last_block_was_snakecode = False
self.block_indent = 0
self.queriable = True
self.in_fstring = False

status = self.get_next_queriable(self.snakefile)
self.buffer = status.buffer
@@ -277,6 +279,7 @@ def get_next_queriable(self, snakefile: TokenIterator) -> Status:
prev_token: Optional[Token] = Token(tokenize.NAME)
while True:
token = next(snakefile)
self.in_fstring = fstring_processing(token, prev_token, self.in_fstring)
if block_indent == -1 and not_a_comment_related_token(token):
block_indent = self.cur_indent
if token.type == tokenize.INDENT:
Expand Down Expand Up @@ -317,7 +320,7 @@ def get_next_queriable(self, snakefile: TokenIterator) -> Status:
token, block_indent, self.cur_indent, buffer, False, pythonable
)

if add_token_space(prev_token, token):
if add_token_space(prev_token, token, self.in_fstring):
buffer += " "
prev_token = token
if newline:
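A minimal sketch (not part of the diff) of the bookkeeping pattern this change adds to the parser: update an in_fstring flag from the token stream, then pass it to add_token_space so the f-string spacing table is used while inside one. The respace helper is hypothetical and ignores indentation, newlines and the curly-bracket re-addition the real parser also performs.

import io
import tokenize

from snakefmt.parser.syntax import add_token_space, fstring_processing


def respace(source: str) -> str:
    """Hypothetical: re-join a token stream using snakefmt's spacing decisions."""
    buffer = ""
    in_fstring = False
    prev_token = None
    for token in tokenize.generate_tokens(io.StringIO(source).readline):
        # Refresh the f-string state before the spacing decision, as the parser does.
        in_fstring = fstring_processing(token, prev_token, in_fstring)
        if add_token_space(prev_token, token, in_fstring):
            buffer += " "
        buffer += token.string
        prev_token = token
    return buffer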
61 changes: 51 additions & 10 deletions snakefmt/parser/syntax.py
@@ -2,11 +2,12 @@
Code in charge of parsing and validating Snakemake syntax
"""

import sys
import tokenize
from abc import ABC, abstractmethod
from re import match as re_match
from typing import Optional

from snakefmt import fstring_tokeniser_in_use
from snakefmt.exceptions import (
ColonError,
EmptyContextError,
@@ -38,17 +39,31 @@
tokenize.NUMBER: {tokenize.NAME, tokenize.OP},
tokenize.OP: {tokenize.NAME, tokenize.STRING, tokenize.NUMBER, tokenize.OP},
}
# add fstring start to spacing_triggers if python 3.12 or higher
if hasattr(tokenize, "FSTRING_START"):

if fstring_tokeniser_in_use:
spacing_triggers[tokenize.NAME].add(tokenize.FSTRING_START)
spacing_triggers[tokenize.OP].add(tokenize.FSTRING_START)
# A more restrictive set of spacing triggers, applied inside f-strings.
fstring_spacing_triggers = {
tokenize.NAME: {
tokenize.NAME,
tokenize.STRING,
tokenize.NUMBER,
},
tokenize.STRING: {tokenize.NAME, tokenize.OP},
tokenize.NUMBER: {tokenize.NAME},
tokenize.OP: {
tokenize.NAME,
tokenize.STRING,
},
}


def re_add_curly_bracket_if_needed(token: Token) -> str:
result = ""
if (
token is not None
and sys.version_info >= (3, 12)
fstring_tokeniser_in_use
and token is not None
and token.type == tokenize.FSTRING_MIDDLE
):
if token.string.endswith("}"):
@@ -58,6 +73,22 @@ def re_add_curly_bracket_if_needed(token: Token) -> str:
return result


def fstring_processing(
token: Token, prev_token: Optional[Token], in_fstring: bool
) -> bool:
"""
Returns True if we are entering, or have already entered and not exited,
an f-string.
"""
result = False
if fstring_tokeniser_in_use:
if prev_token is not None and prev_token.type == tokenize.FSTRING_START:
result = True
elif token.type != tokenize.FSTRING_END and in_fstring:
result = True
return result


def operator_skip_spacing(prev_token: Token, token: Token) -> bool:
if prev_token.type != tokenize.OP and token.type != tokenize.OP:
return False
Expand All @@ -76,11 +107,14 @@ def operator_skip_spacing(prev_token: Token, token: Token) -> bool:
return False


def add_token_space(prev_token: Token, token: Token) -> bool:
def add_token_space(prev_token: Token, token: Token, in_fstring: bool = False) -> bool:
result = False
if prev_token is not None and prev_token.type in spacing_triggers:
if prev_token is not None:
if not operator_skip_spacing(prev_token, token):
if token.type in spacing_triggers[prev_token.type]:
if not in_fstring:
if token.type in spacing_triggers.get(prev_token.type, {}):
result = True
elif token.type in fstring_spacing_triggers.get(prev_token.type, {}):
result = True
return result

@@ -148,8 +182,8 @@ def has_a_key(self) -> bool:
def has_value(self) -> bool:
return len(self.value) > 0

def add_elem(self, prev_token: Token, token: Token):
if add_token_space(prev_token, token) and len(self.value) > 0:
def add_elem(self, prev_token: Token, token: Token, in_fstring: bool = False):
if add_token_space(prev_token, token, in_fstring) and len(self.value) > 0:
self.value += " "

if self.is_empty():
@@ -322,6 +356,7 @@ def __init__(
self.eof = False
self.incident_vocab = incident_vocab
self._brackets = list()
self.in_fstring = False
self.in_lambda = False
self.found_newline = False

@@ -378,6 +413,12 @@ def check_exit(self, cur_param: Parameter):

def process_token(self, cur_param: Parameter, prev_token: Token) -> Parameter:
token_type = self.token.type
# f-string treatment (since python 3.12)
self.in_fstring = fstring_processing(self.token, prev_token, self.in_fstring)
if self.in_fstring:
cur_param.add_elem(prev_token, self.token, self.in_fstring)
return cur_param

# Eager treatment of comments: tag them onto params
if token_type == tokenize.COMMENT and not self.in_brackets:
cur_param.add_comment(self.token.string, self.keyword_indent)
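A small trace (assuming Python 3.12+, so fstring_tokeniser_in_use is True) of how fstring_processing drives the in_fstring flag that makes add_token_space fall through to fstring_spacing_triggers; the example string is the consecutive-braces case from issue #222.

import io
import tokenize

from snakefmt.parser.syntax import fstring_processing

source = 'f"{var1}{var2}"\n'
in_fstring = False
prev_token = None
for token in tokenize.generate_tokens(io.StringIO(source).readline):
    in_fstring = fstring_processing(token, prev_token, in_fstring)
    print(f"{tokenize.tok_name[token.type]:<13} {token.string!r:<9} in_fstring={in_fstring}")
    prev_token = token
# The flag becomes True on the token after FSTRING_START and drops back to
# False on the FSTRING_END token, so only the f-string's interior is spaced
# with the narrower fstring_spacing_triggers table.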
3 changes: 2 additions & 1 deletion tests/test_config.py
@@ -25,7 +25,8 @@ def test_black_and_snakefmt_default_line_lengths_aligned():
class TestFindPyprojectToml:
def test_find_pyproject_toml_nested_directory(self, tmp_path):
config_file = (tmp_path / "pyproject.toml").resolve()
config_file.touch()
# add [tool.black] to TOML to ensure black finds it (new in black v24.2.0)
config_file.write_text("[tool.black]\n")
dir1 = Path(tmp_path / "dir1").resolve()
dir1.mkdir()
snakefile = dir1 / "Snakefile"
35 changes: 27 additions & 8 deletions tests/test_formatter.py
@@ -394,9 +394,9 @@ def test_decorator_is_handled_correctly(self):
actual = formatter.get_formatted()
assert actual == snakecode

def test_f_strings(self):
def test_fstrings(self):
"""This is relevant for python3.12"""
snakecode = 'a = f"{1 + 2}" if 1 > 0 else f"{1 - 2}"\n'
snakecode = 'a = f"{1+2}" if 1 > 0 else f"{1-2}"\n'
formatter = setup_formatter(snakecode)

actual = formatter.get_formatted()
@@ -686,7 +686,7 @@ def test_keyword_with_tpq_inside_expression_left_alone(self):
formatter = setup_formatter(snakecode)
assert formatter.get_formatted() == snakecode

def test_rf_string_tpq_supported(self):
def test_r_and_fstring_tpq_supported(self):
"""Deliberately tests for consecutive r/f strings and with
single or double quotes"""
for preceding in {"r", "f"}:
@@ -846,7 +846,7 @@ def test_tpq_inside_run_block(self):

assert formatter.get_formatted() == snakecode

def test_f_string_with_double_braces_in_input(self):
def test_fstring_with_double_braces_in_input(self):
"""https://github.com/snakemake/snakefmt/issues/207"""
snakecode = (
"rule align:\n"
@@ -859,11 +859,8 @@ def test_f_string_with_double_braces_in_input(self):
formatter = setup_formatter(snakecode)
assert formatter.get_formatted() == snakecode

def test_f_string_with_double_braces_in_python_code(self):
def test_fstring_with_double_braces_in_python_code(self):
"""https://github.com/snakemake/snakefmt/issues/215"""
"""def get_test_regions(wildcards):
benchmark = config["variant-calls"][wildcards.callset]["benchmark"]
return f"resources/regions/{benchmark}/test-regions.cov-{{cov}}.bed"""
snakecode = (
"def get_test_regions(wildcards):\n"
f'{TAB * 1}benchmark = config["variant-calls"][wildcards.callset]["benchmark"]\n' # noqa: E501
@@ -872,6 +869,28 @@ def test_f_string_with_double_braces_in_python_code(self):
formatter = setup_formatter(snakecode)
assert formatter.get_formatted() == snakecode

def test_fstring_spacing_of_consecutive_braces(self):
"""https://github.com/snakemake/snakefmt/issues/222"""
snakecode = 'f"{var1}{var2}"\n'
formatter = setup_formatter(snakecode)
assert formatter.get_formatted() == snakecode

def test_fstring_with_equal_sign_inside_function_call(self):
"""https://github.com/snakemake/snakefmt/issues/220"""
snakecode = 'test = f"job_properties: {json.dumps(job_properties, indent=4)}"\n'
formatter = setup_formatter(snakecode)
assert formatter.get_formatted() == snakecode

def test_fstring_with_list_comprehension_inside_function_call(self):
"""https://github.com/snakemake/snakefmt/issues/227"""
snakecode = (
"rule subsample:\n"
f"{TAB * 1}input:\n"
f"{TAB * 2}f\"{{' '.join([i for i in range(10)])}}\",\n"
)
formatter = setup_formatter(snakecode)
assert formatter.get_formatted() == snakecode


class TestReformatting_SMK_BREAK:
"""
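A small usage note, assuming pytest is the test runner (the files above live under tests/): the f-string related tests added or renamed in this PR can be selected together by keyword.

import pytest

# "-k fstring" matches test_fstrings, test_fstring_* and
# test_r_and_fstring_tpq_supported from tests/test_formatter.py.
pytest.main(["tests/test_formatter.py", "-k", "fstring"])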