Skip to content

Commit

Permalink
Add command for running tests with -X jit-auto
Browse files Browse the repository at this point in the history
Summary:
As I'm working on AutoJIT + type profiling, I'm finding that there are other
test failures that relate purely to AutoJIT, so adding that as a supported way
of running tests.

Most of the test failures I was seeing were JIT tests that expected functions to
be compiled after being called once, this is not true with AutoJIT.  This diff
adds a `cinderjit` function to read what the current AutoJIT threshold is so we
can tell if we're running with it or not.  It'll be zero when AutoJIT is not
enabled.

There is still an outstanding failure with a StaticPython test that I've tasked
up in T165854755.

Reviewed By: mpage

Differential Revision: D50021611

fbshipit-source-id: 0e2d8a15112251e0969b34859fccf87e85afd06b
  • Loading branch information
Alex Malyshev authored and facebook-github-bot committed Oct 9, 2023
1 parent ab6c198 commit 25ce292
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 13 deletions.
9 changes: 9 additions & 0 deletions Jit/pyjit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,10 @@ static PyObject* force_compile(PyObject* /* self */, PyObject* func) {
return nullptr;
}

static PyObject* auto_jit_threshold(PyObject* /* self */, PyObject*) {
return PyLong_FromLong(getConfig().auto_jit_threshold);
}

int _PyJIT_IsCompiled(PyObject* func) {
if (jit_ctx == nullptr) {
return 0;
Expand Down Expand Up @@ -1406,6 +1410,11 @@ static PyMethodDef jit_methods[] = {
METH_FASTCALL,
"Disable the jit."},
{"disassemble", disassemble, METH_O, "Disassemble JIT compiled functions"},
{"auto_jit_threshold",
auto_jit_threshold,
METH_NOARGS,
"Return the current AutoJIT threshold, only makes sense when the JIT is "
"enabled."},
{"is_jit_compiled",
is_jit_compiled,
METH_O,
Expand Down
44 changes: 35 additions & 9 deletions Lib/test/test_cinderjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ def f():
return 1
f()
# Depending on which JIT mode is being used, f might not have been
# compiled on the first call, but it will be after `force_compile`.
cinderjit.force_compile(f)
assert cinderjit.is_jit_compiled(f)
cinderjit.disable()
f()
"""
Expand Down Expand Up @@ -897,6 +902,8 @@ def test():
)
import tmp_b

cinderjit.force_compile(tmp_b.test)

self.assertEqual(tmp_b.test(), 3)
self.assertTrue(cinderjit.is_jit_compiled(tmp_b.test))
self.assertTrue(
Expand Down Expand Up @@ -1419,6 +1426,10 @@ def get_a():
cinderjit.clear_runtime_stats()
import tmp_a

# Force the compilation if this is running with AutoJIT.
if cinderjit:
cinderjit.force_compile(tmp_a.get_a)

# What happens on the first call is kinda undefined in principle
# given lazy imports; somebody could previously have imported B
# (not in this specific test, but in principle), or not, so the
Expand Down Expand Up @@ -1483,10 +1494,15 @@ def get_a():
),
encoding="utf8",
)

if cinderjit:
cinderjit.clear_runtime_stats()
import tmp_a

# Force the compilation if this is running with AutoJIT.
if cinderjit:
cinderjit.force_compile(tmp_a.get_a)

tmp_a.get_a()
self.assertEqual(tmp_a.get_a(), 5)
if cinderjit:
Expand Down Expand Up @@ -1524,6 +1540,10 @@ def get_a():
cinderjit.clear_runtime_stats()
import tmp_a

# Force the compilation if this is running with AutoJIT.
if cinderjit:
cinderjit.force_compile(tmp_a.get_a)

tmp_a.get_a()
self.assertEqual(tmp_a.get_a(), 5)
if cinderjit:
Expand Down Expand Up @@ -1795,9 +1815,9 @@ def __init__(self, x: bool) -> None:
"""
with self.in_module(codestr) as mod:
gc.immortalize_heap()
foo = mod.Foo(True)
if cinderjit:
self.assertTrue(cinderjit.is_jit_compiled(mod.Foo.__init__))
cinderjit.force_compile(mod.Foo.__init__)
foo = mod.Foo(True)

def test_restore_materialized_parent_pyframe_in_gen_throw(self):
# This reproduces a bug that causes the top frame in the shadow stack
Expand Down Expand Up @@ -1883,7 +1903,7 @@ async def main():
with self.assertRaises(asyncio.CancelledError):
asyncio.run(main())

if cinderjit:
if cinderjit and cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(a))
self.assertTrue(cinderjit.is_jit_compiled(b))
self.assertTrue(cinderjit.is_jit_compiled(c.__wrapped__))
Expand Down Expand Up @@ -2942,7 +2962,7 @@ async def call_x() -> None:
coro.send(None)
coro.close()
self.assertFalse(mod.x.last_awaited())
if cinderjit:
if cinderjit and cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(mod.await_x))
self.assertTrue(cinderjit.is_jit_compiled(mod.call_x))

Expand Down Expand Up @@ -2980,7 +3000,7 @@ async def call_x(x: X) -> None:
coro.send(None)
coro.close()
self.assertFalse(awaited_capturer.last_awaited())
if cinderjit:
if cinderjit and cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(mod.await_x))
self.assertTrue(cinderjit.is_jit_compiled(mod.call_x))

Expand Down Expand Up @@ -4081,7 +4101,7 @@ def testfunc():
testfunc = mod.testfunc
self.assertTrue(testfunc())

if cinderjit:
if cinderjit and cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(testfunc))


Expand Down Expand Up @@ -4119,7 +4139,9 @@ def g():
self.assertTrue(g())

self.assertFalse(cinderjit.is_jit_compiled(f))
self.assertTrue(cinderjit.is_jit_compiled(g))

if cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(g))

@unittest.skipIf(
not cinderjit or not cinderjit.is_hir_inliner_enabled(),
Expand All @@ -4142,7 +4164,9 @@ def g():
self.assertTrue(g())

self.assertFalse(cinderjit.is_jit_compiled(f))
self.assertTrue(cinderjit.is_jit_compiled(g))

if cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(g))

self.assertEqual(cinderjit.get_num_inlined_functions(g), 1)

Expand Down Expand Up @@ -5258,7 +5282,7 @@ def lme_test_func(self, flag=False):
def test_multiple_call_method_same_load_method(self):
self.assertEqual(self.lme_test_func(), "1")
self.assertEqual(self.lme_test_func(True), "1 flag")
if cinderjit:
if cinderjit and cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(is_jit_compiled(LoadMethodEliminationTests.lme_test_func))


Expand All @@ -5271,7 +5295,9 @@ def f1():
def func():
return f1() + f1()

cinderjit.force_compile(func)
self.assertEqual(func(), 10)

ops = cinderjit.get_function_hir_opcode_counts(func)
self.assertIsInstance(ops, dict)
self.assertEqual(ops.get("Return"), 1)
Expand Down
13 changes: 9 additions & 4 deletions Lib/test/test_jitlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ def test_py_function(self) -> None:
self.assertIn(func.__qualname__, py_funcs[__name__])
self.assertNotIn(_no_jit_function.__qualname__, py_funcs[__name__])
meth(None)
self.assertTrue(cinderjit.is_jit_compiled(meth))
if cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(meth))
func()
self.assertTrue(cinderjit.is_jit_compiled(func))
if cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(func))
_no_jit_function()
self.assertFalse(cinderjit.is_jit_compiled(_no_jit_function))

Expand All @@ -67,7 +69,8 @@ def test_py_code(self) -> None:
self.assertIn(code_obj.co_firstlineno, py_code_objs[code_obj.co_name][thisfile])
self.assertNotIn(_no_jit_function.__code__.co_name, py_code_objs)
_jit_function_2()
self.assertTrue(cinderjit.is_jit_compiled(_jit_function_2))
if cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(_jit_function_2))
_no_jit_function()
self.assertFalse(cinderjit.is_jit_compiled(_no_jit_function))

Expand All @@ -84,7 +87,9 @@ def inner_func():
self.assertFalse(cinderjit.is_jit_compiled(inner_func))
inner_func.__qualname__ += "_foo"
self.assertEqual(inner_func(), 24)
self.assertTrue(cinderjit.is_jit_compiled(inner_func))

if cinderjit.auto_jit_threshold() <= 1:
self.assertTrue(cinderjit.is_jit_compiled(inner_func))


if __name__ == "__main__":
Expand Down
7 changes: 7 additions & 0 deletions Makefile.pre.in
Original file line number Diff line number Diff line change
Expand Up @@ -1800,13 +1800,20 @@ define RUN_TESTCINDERJIT
# $(ASAN_TEST_ENV)$(TESTPYTHON) -X jit $(1) -X jit-multithreaded-compile-test -X jit-batch-compile-workers=10 $(srcdir)/Lib/test/multithreaded_compile_test.py $(TESTOPTS)
endef

define RUN_TESTCINDERJITAUTO
$(ASAN_TEST_ENV)$(TESTPYTHON) -X usepycompiler -X jit-auto=1000 -X jit-enable-inline-cache-stats-collection $(1) $(JIT_TEST_RUNNER) $(TESTOPTS)
endef

define RUN_TESTCINDERJIT_AUTOPROFILE
$(ASAN_TEST_ENV)$(TESTPYTHON) -X usepycompiler -X jit-auto=1000 -X jit-auto-profile=10 -X jit-enable-inline-cache-stats-collection $(1) $(JIT_TEST_RUNNER) $(TESTOPTS)
endef

testcinder_jit: @DEF_MAKE_RULE@ platform
$(call RUN_TESTCINDERJIT,)

testcinder_jit_auto: @DEF_MAKE_RULE@ platform
$(call RUN_TESTCINDERJITAUTO,)

testcinder_jit_profile: @DEF_MAKE_RULE@ platform
$(call RUN_TESTCINDERJIT_PROFILE,)

Expand Down

0 comments on commit 25ce292

Please sign in to comment.