diff --git a/src/rollbot/interpreter/calculator.py b/src/rollbot/interpreter/calculator.py index 8612cc7..376d967 100644 --- a/src/rollbot/interpreter/calculator.py +++ b/src/rollbot/interpreter/calculator.py @@ -4,7 +4,6 @@ from collections.abc import Callable from datetime import datetime, timedelta from functools import wraps -from math import factorial from random import SystemRandom from typing import TypeVar @@ -14,24 +13,13 @@ from rollbot.varenv import VarEnv +from .functions import funcs from .parser import parser, reconstructor random = SystemRandom() logger = structlog.get_logger() -def comb(n, k, *args): - return factorial(n) // factorial(k) // factorial(n - k) - - -def numerical_any(*args): - return 1 if any(a > 0 for a in args) else 0 - - -def one(*args): - return 1 - - def flatten(a): if isinstance(a, list): return functools.reduce(operator.iadd, (flatten(b) for b in a), []) @@ -39,7 +27,6 @@ def flatten(a): random = SystemRandom() -funcs = {"max": max, "min": min, "fac": factorial, "comb": comb, "any": numerical_any} class EvaluationError(Exception): @@ -187,16 +174,19 @@ def pick(self, args) -> tuple[int, str]: return c[0], f"{{{c[1]}}}" @visit_children_decor - def func(self, name: str, args) -> tuple[int, str]: + def func(self, name: str, args: list[tuple[int, str]] | None = None) -> tuple[int, str]: + if not args: + raise EvaluationError(f"Function '{name}' requires arguments") + self.count_complexity(10) args = flatten(args) self.count_complexity(len(args)) if name not in funcs: raise EvaluationError(f"No such function '{name}'") - values = (a[0] for a in args) + values = [a[0] for a in args] descr = ", ".join(a[1] for a in args) - return funcs.get(name, one)(*values), f"{name}({descr})" + return funcs[name](values), f"{name}({descr})" @visit_children_decor def var(self, name: str) -> tuple[int, str]: @@ -359,9 +349,9 @@ def func(self, name: str, args): # We have a nice selection of functions that are min/max preserving - minvalues = (a[0] for a in args) - maxvalues = (a[1] for a in args) - return funcs.get(name, one)(*minvalues), funcs.get(name, one)(*maxvalues) + minvalues = [a[0] for a in args] + maxvalues = [a[1] for a in args] + return funcs[name](minvalues), funcs[name](maxvalues) @visit_children_decor def var(self, name: str): @@ -465,7 +455,7 @@ def func(self, name: str, args): if name not in funcs: raise EvaluationError(f"No such function '{name}'") - return funcs.get(name, one)(*(a for a in flatten(args))) + return funcs.get[name](*(a for a in flatten(args))) def distribute(text: str, timeout: timedelta | None = None, env: VarEnv = None, num_bins: int = 1000): diff --git a/src/rollbot/interpreter/functions.py b/src/rollbot/interpreter/functions.py new file mode 100644 index 0000000..a970633 --- /dev/null +++ b/src/rollbot/interpreter/functions.py @@ -0,0 +1,35 @@ +from collections.abc import Callable +from math import factorial + + +def fn_max(args: list[int]) -> int: + return max(args) + + +def fn_min(args: list[int]) -> int: + return min(args) + + +def fn_fac(args: list[int]) -> int: + if args: + return factorial(args[0]) + return 1 + + +def fn_comb(args: list[int]) -> int: + if len(args) >= 2: + return factorial(args[0]) // factorial(args[1]) // factorial(args[0] - args[1]) + return 0 + + +def fn_any(args: list[int]) -> int: + return 1 if any(args) else 0 + + +funcs: dict[str, Callable[[list[int]], int]] = { + "max": fn_max, + "min": fn_min, + "fac": fn_fac, + "comb": fn_comb, + "any": fn_any, +} diff --git a/tests/test_regression.py b/tests/test_regression.py new file mode 100644 index 0000000..8431efb --- /dev/null +++ b/tests/test_regression.py @@ -0,0 +1,17 @@ +from rollbot.interpreter.calculator import evaluate, EvaluationError +import pytest + + +def test_func_arity(): + evaluate("max(d20)") + evaluate("max(d20, d20)") + evaluate("min(d20)") + evaluate("min(d20, d20)") + evaluate("fac(5)") + evaluate("fac(5, 5)") + + with pytest.raises(EvaluationError): + evaluate("nosuchfun()") + + with pytest.raises(EvaluationError): + evaluate("nosuchfun(1)")