Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use raise_if_not_cached when RunMode.Run #333

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 23 additions & 31 deletions dsl/pace/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ def __init__(
)
self.stencil_object: Optional[gt4py.StencilObject] = None

self._argument_names = tuple(inspect.getfullargspec(func).args)

Comment on lines -320 to -321
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exact line was run twice.

if "dace" in self.stencil_config.compilation_config.backend:
dace.Config.set(
"default_build_folder",
Expand All @@ -336,39 +334,33 @@ def __init__(
# If we orchestrate, move the compilation at call time to make sure
# disable_codegen do not lead to call to uncompiled stencils, which fails
# silently
if self.stencil_config.dace_config.is_dace_orchestrated():
self.stencil_object = gtscript.lazy_stencil(
definition=func,
externals=externals,
**stencil_kwargs,
build_info=(build_info := {}), # type: ignore
)
else:
compilation_config = stencil_config.compilation_config
if (
compilation_config.use_minimal_caching
and not compilation_config.is_compiling
and compilation_config.run_mode != RunMode.Run
):
block_waiting_for_compilation(MPI.COMM_WORLD, compilation_config)

self.stencil_object = gtscript.stencil(
definition=func,
externals=externals,
**stencil_kwargs,
build_info=(build_info := {}),
self.stencil_object = gtscript.lazy_stencil(
definition=func,
externals=externals,
**stencil_kwargs,
build_info=(build_info := {}), # type: ignore
)
if not self.stencil_config.dace_config.is_dace_orchestrated():
# Trigger compilation
do_block_waiting = (
self.stencil_config.compilation_config.use_minimal_caching
and not self.stencil_config.compilation_config.is_compiling
and self.stencil_config.compilation_config.run_mode != RunMode.Run
)

if (
compilation_config.use_minimal_caching
and compilation_config.is_compiling
and compilation_config.run_mode != RunMode.Run
):
if do_block_waiting:
block_waiting_for_compilation(
MPI.COMM_WORLD, self.stencil_config.compilation_config
)

# Referencing the implementation attribute triggers compilation
self.stencil_object.implementation

if do_block_waiting:
unblock_waiting_tiles(MPI.COMM_WORLD)

self._timing_collector.build_info[
_stencil_object_name(self.stencil_object)
] = build_info
obj_name = _stencil_object_name(self.stencil_object)
self._timing_collector.build_info[obj_name] = build_info
field_info = self.stencil_object.field_info

self._field_origins: Dict[
Expand Down
6 changes: 6 additions & 0 deletions dsl/pace/dsl/stencil_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,18 @@ def stencil_kwargs(
"name": func.__module__ + "." + func.__name__,
**self.backend_opts,
}

if not self.is_gpu_backend:
kwargs.pop("device_sync", None)

if skip_passes or kwargs.get("skip_passes", ()):
kwargs["oir_pipeline"] = StencilConfig._get_oir_pipeline(
list(kwargs.pop("skip_passes", ())) + list(skip_passes) # type: ignore
)

if self.compilation_config.run_mode == RunMode.Run:
kwargs["raise_if_not_cached"] = True

return kwargs

@property
Expand Down