Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TKW] Kernel cache changes #351

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions .github/workflows/ci-tk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,21 @@ jobs:
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1 && export WAVE_CACHE_DIR=$PWD/.wave
rm -rf ./.wave
pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/runtime
WAVE_CACHE_ON=1 pytest --capture=tee-sys -vv ./tests/kernel/wave/runtime

- name: Run e2e tests on AMD GPU MI300
if: "contains(matrix.os, 'mi300') && !cancelled()"
run: |
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1 && export WAVE_CACHE_DIR=$PWD/.wave
rm -rf ./.wave
pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ --ignore=./tests/kernel/wave/runtime
export WAVE_RUN_E2E_TESTS=1
WAVE_CACHE_ON=0 pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/

- name: Run e2e tests on AMD GPU MI250
if: "contains(matrix.os, 'mi250') && !cancelled()"
run: |
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1 && export WAVE_CACHE_DIR=$PWD/.wave
rm -rf ./.wave
pytest -n 2 --capture=tee-sys -vv ./tests/kernel/wave/ --ignore=./tests/kernel/wave/runtime
export WAVE_RUN_E2E_TESTS=1
WAVE_CACHE_ON=0 pytest -n 2 --capture=tee-sys -vv ./tests/kernel/wave/

- name: Run LIT tests
if: ${{ !cancelled() }}
Expand Down
4 changes: 4 additions & 0 deletions iree/turbine/kernel/wave/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
WAVE_CACHE_LIMIT = int(os.environ.get("WAVE_CACHE_LIMIT", 16))


def is_cache_enabled() -> bool:
    """Report whether the filesystem kernel cache is turned on.

    Truthiness of the module-level ``WAVE_CACHE_ON`` flag decides the
    answer.  NOTE(review): if ``WAVE_CACHE_ON`` is ever read from the
    environment as a raw string, ``"0"`` would still be truthy here —
    confirm it is coerced to an int at definition time.
    """
    return True if WAVE_CACHE_ON else False


@dataclass
class WaveCache:
"""
Expand Down
60 changes: 34 additions & 26 deletions iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
KernelRegionGraph,
Launchable,
)
from .cache import get_cache_manager, invoke_cached_kernel
from .cache import is_cache_enabled, get_cache_manager, invoke_cached_kernel

import sympy

Expand Down Expand Up @@ -373,30 +373,32 @@ def test_execute(self, args, kwargs):
use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False)

# Get cached kernel when available.
cache_manager = get_cache_manager()
# TODO: Move use_scheduling, use_scheduling_barriers, etc. into the config so everything is contained there.
kernel_hash = cache_manager.get_hash(
self.constraints,
self._f,
IndexingContext.current().subs,
dynamic_symbols,
config,
use_scheduling=use_scheduling,
use_scheduling_barriers=use_scheduling_barriers,
run_bench=run_bench,
)
cached_kernel = cache_manager.load_kernel(kernel_hash)
if cached_kernel and (run or run_bench):
invoke_cached_kernel(
cached_kernel,
args,
config,
cache_enabled = is_cache_enabled()
if cache_enabled:
cache_manager = get_cache_manager()
# TODO: Move use_scheduling, use_scheduling_barriers, etc. into the config so everything is contained there.
kernel_hash = cache_manager.get_hash(
self.constraints,
self._f,
IndexingContext.current().subs,
dynamic_symbols,
dynamic_symbols_map,
run,
run_bench,
config,
use_scheduling=use_scheduling,
use_scheduling_barriers=use_scheduling_barriers,
run_bench=run_bench,
)
return cached_kernel
cached_kernel = cache_manager.load_kernel(kernel_hash)
if cached_kernel and (run or run_bench):
invoke_cached_kernel(
cached_kernel,
args,
config,
dynamic_symbols,
dynamic_symbols_map,
run,
run_bench,
)
return cached_kernel

# Recompile from kernel scratch if not found in cache.
(
Expand Down Expand Up @@ -439,9 +441,15 @@ def test_execute(self, args, kwargs):
binding.kernel_buffer_type.usage
for binding in kernel_sig.kernel_buffer_bindings
]
cache_manager.store_kernel(
compiled_wave_vmfb, kernel_usages, mb.module_op.get_asm(), kernel_hash
)

if cache_enabled:
cache_manager.store_kernel(
compiled_wave_vmfb,
kernel_usages,
mb.module_op.get_asm(),
kernel_hash,
)

invoke_vmfb(
compiled_wave_vmfb,
"isolated_benchmark",
Expand Down
8 changes: 8 additions & 0 deletions tests/kernel/wave/runtime/cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import iree.turbine.kernel.wave as tkw

from iree.turbine.kernel.wave.cache import (
is_cache_enabled,
get_cache_manager,
reset_cache_manager,
WaveCache,
Expand All @@ -41,6 +42,10 @@
_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")

# Marker for tests that exercise the filesystem kernel cache: skip them
# entirely when caching is disabled (is_cache_enabled() is False), since
# there is nothing to hit or store.
require_cache = pytest.mark.skipif(
    not is_cache_enabled(), reason="filesystem cache is disabled"
)


def generate_attention_kernel(constraints: list[Constraint]):
# Input sizes
Expand Down Expand Up @@ -114,6 +119,7 @@ def repeat(


@require_e2e
@require_cache
def testSameConfig(request):
reset_cache_manager()
shape = (8, 128, 128, 64, 256)
Expand Down Expand Up @@ -221,6 +227,7 @@ def testSameConfig(request):


@require_e2e
@require_cache
def testDifferentDynamicSameBlock(request):
reset_cache_manager()
# Input sizes
Expand Down Expand Up @@ -382,6 +389,7 @@ def testDifferentDynamicSameBlock(request):


@require_e2e
@require_cache
def testSameSizeDifferentBlock(request):
reset_cache_manager()
shape = (8, 128, 128, 64, 256)
Expand Down
Loading