From b1a7249f3b376e9ad78ed2b8ea30249e5381c1aa Mon Sep 17 00:00:00 2001 From: Dino Viehland Date: Fri, 12 Jan 2024 16:04:19 -0800 Subject: [PATCH] Fix issue with patching staticmethod w/ staticmethod Summary: Fixes a small issue I noticed while working on D52525922. When we go to handle a value being patched we don't do the unwrapping of the static/class method. So this moves that logic into `update_thunk` and moves the `_PyClassLoader_ResolveFunction` unwrapping to only happen in the case where we're not returning a thunk. Reviewed By: carljm Differential Revision: D52527349 fbshipit-source-id: fbe49c2ced72c68b1467a607ad20a2987c90c36c --- cinderx/StaticPython/classloader.c | 23 +++++++++++-------- .../test_compiler/test_static/patch.py | 22 ++++++++++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/cinderx/StaticPython/classloader.c b/cinderx/StaticPython/classloader.c index 161a94ba89c..224c1450d27 100644 --- a/cinderx/StaticPython/classloader.c +++ b/cinderx/StaticPython/classloader.c @@ -2807,8 +2807,13 @@ update_thunk(_Py_StaticThunk *thunk, PyObject *previous, PyObject *new_value) { Py_CLEAR(thunk->thunk_tcs.tcs_value); if (new_value != NULL) { - thunk->thunk_tcs.tcs_value = new_value; - Py_INCREF(new_value); + PyObject *unwrapped_new = classloader_maybe_unwrap_callable(new_value); + if (unwrapped_new != NULL) { + thunk->thunk_tcs.tcs_value = unwrapped_new; + } else { + thunk->thunk_tcs.tcs_value = new_value; + Py_INCREF(new_value); + } } PyObject *funcref; if (new_value == previous) { @@ -4258,6 +4263,13 @@ _PyClassLoader_ResolveFunction(PyObject *path, PyObject **container) original = NULL; } + if (original != NULL) { + PyObject *res = (PyObject *)get_or_make_thunk(func, original, *container, containerkey); + Py_DECREF(func); + assert(res != NULL); + return res; + } + if (func != NULL) { if (Py_TYPE(func) == &PyStaticMethod_Type) { PyObject *res = Ci_PyStaticMethod_GetFunc(func); @@ -4273,13 +4285,6 @@ _PyClassLoader_ResolveFunction(PyObject *path, PyObject **container) } } - if (original != NULL) { - PyObject *res = (PyObject *)get_or_make_thunk(func, original, *container, containerkey); - Py_DECREF(func); - assert(res != NULL); - return res; - } - return func; } diff --git a/cinderx/test_cinderx/test_compiler/test_static/patch.py b/cinderx/test_cinderx/test_compiler/test_static/patch.py index 74ce7c28657..76ddfe54caa 100644 --- a/cinderx/test_cinderx/test_compiler/test_static/patch.py +++ b/cinderx/test_cinderx/test_compiler/test_static/patch.py @@ -257,6 +257,28 @@ def g(): with patch(f"{mod.__name__}.C.f", autospec=True, return_value=100) as p: self.assertEqual(g(), 100) + def test_patch_staticmethod_with_staticmethod(self): + codestr = """ + class C: + @staticmethod + def f(): + return 42 + + def g(): + return C.f() + """ + with self.in_module(codestr) as mod: + g = mod.g + for i in range(100): + self.assertEqual(g(), 42) + + @staticmethod + def new(): + return 100 + + mod.C.f = new + self.assertEqual(g(), 100) + def test_patch_static_function_non_autospec(self): codestr = """ class C: