Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't assume object hashes are unique in caching logic #215

Merged
merged 10 commits into from
Dec 19, 2024
24 changes: 19 additions & 5 deletions src/tyro/_unsafe_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import sys
from typing import Any, Callable, Dict, List, TypeVar

CallableType = TypeVar("CallableType", bound=Callable)
Expand All @@ -23,11 +24,20 @@ def unsafe_cache(maxsize: int) -> Callable[[CallableType], CallableType]:
def inner(f: CallableType) -> CallableType:
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
key = tuple(unsafe_hash(arg) for arg in args) + tuple(
("__kwarg__", k, unsafe_hash(v)) for k, v in kwargs.items()
key = tuple(_make_key(arg) for arg in args) + tuple(
("__kwarg__", k, _make_key(v)) for k, v in kwargs.items()
)

if key in local_cache:
# Fuzzy check for cache collisions if called from a pytest test.
if "pytest" in sys.modules:
import random

if random.random() < 0.5:
a = f(*args, **kwargs)
b = local_cache[key]
assert a == b or str(a) == str(b)

return local_cache[key]

out = f(*args, **kwargs)
Expand All @@ -41,8 +51,12 @@ def wrapped_f(*args, **kwargs):
return inner


def unsafe_hash(obj: Any) -> Any:
def _make_key(obj: Any) -> Any:
"""Some context: https://github.com/brentyi/tyro/issues/214"""
try:
return hash(obj)
# If the object is hashable, we can use it as a key directly.
hash(obj)
return obj
except TypeError:
return id(obj)
# If the object is not hashable, we'll use assume the type/id are unique...
return type(obj), id(obj)
47 changes: 46 additions & 1 deletion tests/test_dcargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
)

import pytest
from typing_extensions import Annotated, Final, Literal, TypeAlias
from typing_extensions import (
Annotated,
Final,
Literal,
Protocol,
TypeAlias,
runtime_checkable,
)

import tyro

Expand Down Expand Up @@ -953,3 +960,41 @@ class NumericTower:
assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2)
with pytest.raises(SystemExit):
tyro.cli(NumericTower, args="--d False".split(" "))


def test_runtime_checkable_edge_case() -> None:
"""From Kevin Black: https://github.com/brentyi/tyro/issues/214"""

@runtime_checkable
class DummyProtocol(Protocol):
pass

@dataclasses.dataclass(frozen=True)
class SubConfigA:
pass

@dataclasses.dataclass(frozen=True)
class SubConfigB:
pass

@dataclasses.dataclass
class Config:
subconfig: DummyProtocol

CONFIGS = {
"a": Config(subconfig=SubConfigA()),
"b": Config(subconfig=SubConfigB()),
}

assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["a"]
).subconfig
== SubConfigA()
)
assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["b"]
).subconfig
== SubConfigB()
)
40 changes: 40 additions & 0 deletions tests/test_py311_generated/test_dcargs_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
List,
Literal,
Optional,
Protocol,
Text,
Tuple,
TypeAlias,
TypeVar,
runtime_checkable,
)

import pytest
Expand Down Expand Up @@ -955,3 +957,41 @@ class NumericTower:
assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2)
with pytest.raises(SystemExit):
tyro.cli(NumericTower, args="--d False".split(" "))


def test_runtime_checkable_edge_case() -> None:
"""From Kevin Black: https://github.com/brentyi/tyro/issues/214"""

@runtime_checkable
class DummyProtocol(Protocol):
pass

@dataclasses.dataclass(frozen=True)
class SubConfigA:
pass

@dataclasses.dataclass(frozen=True)
class SubConfigB:
pass

@dataclasses.dataclass
class Config:
subconfig: DummyProtocol

CONFIGS = {
"a": Config(subconfig=SubConfigA()),
"b": Config(subconfig=SubConfigB()),
}

assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["a"]
).subconfig
== SubConfigA()
)
assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["b"]
).subconfig
== SubConfigB()
)
11 changes: 6 additions & 5 deletions tests/test_py311_generated/test_unsafe_cache_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,24 @@ def f(dummy: int):
nonlocal x
x += 1

# >= is because of fuzz testing inside of unsafe_cache
f(0)
f(0)
f(0)
assert x == 1
assert x >= 1
f(1)
f(1)
f(1)
assert x == 2
assert x >= 2
f(0)
f(0)
f(0)
assert x == 2
assert x >= 2
f(2)
f(2)
f(2)
assert x == 3
assert x >= 3
f(0)
f(0)
f(0)
assert x == 4
assert x >= 4
11 changes: 6 additions & 5 deletions tests/test_unsafe_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,24 @@ def f(dummy: int):
nonlocal x
x += 1

# >= is because of fuzz testing inside of unsafe_cache
f(0)
f(0)
f(0)
assert x == 1
assert x >= 1
f(1)
f(1)
f(1)
assert x == 2
assert x >= 2
f(0)
f(0)
f(0)
assert x == 2
assert x >= 2
f(2)
f(2)
f(2)
assert x == 3
assert x >= 3
f(0)
f(0)
f(0)
assert x == 4
assert x >= 4
Loading