diff --git a/dsl/pace/dsl/stencil.py b/dsl/pace/dsl/stencil.py index bcfea4735..146a445e0 100644 --- a/dsl/pace/dsl/stencil.py +++ b/dsl/pace/dsl/stencil.py @@ -317,8 +317,6 @@ def __init__( ) self.stencil_object: Optional[gt4py.StencilObject] = None - self._argument_names = tuple(inspect.getfullargspec(func).args) - if "dace" in self.stencil_config.compilation_config.backend: dace.Config.set( "default_build_folder", @@ -336,39 +334,33 @@ def __init__( # If we orchestrate, move the compilation at call time to make sure # disable_codegen do not lead to call to uncompiled stencils, which fails # silently - if self.stencil_config.dace_config.is_dace_orchestrated(): - self.stencil_object = gtscript.lazy_stencil( - definition=func, - externals=externals, - **stencil_kwargs, - build_info=(build_info := {}), # type: ignore - ) - else: - compilation_config = stencil_config.compilation_config - if ( - compilation_config.use_minimal_caching - and not compilation_config.is_compiling - and compilation_config.run_mode != RunMode.Run - ): - block_waiting_for_compilation(MPI.COMM_WORLD, compilation_config) - - self.stencil_object = gtscript.stencil( - definition=func, - externals=externals, - **stencil_kwargs, - build_info=(build_info := {}), + self.stencil_object = gtscript.lazy_stencil( + definition=func, + externals=externals, + **stencil_kwargs, + build_info=(build_info := {}), # type: ignore + ) + if not self.stencil_config.dace_config.is_dace_orchestrated(): + # Trigger compilation + do_block_waiting = ( + self.stencil_config.compilation_config.use_minimal_caching + and not self.stencil_config.compilation_config.is_compiling + and self.stencil_config.compilation_config.run_mode != RunMode.Run ) - if ( - compilation_config.use_minimal_caching - and compilation_config.is_compiling - and compilation_config.run_mode != RunMode.Run - ): + if do_block_waiting: + block_waiting_for_compilation( + MPI.COMM_WORLD, self.stencil_config.compilation_config + ) + + # Referencing the implementation attribute triggers compilation + self.stencil_object.implementation + + if do_block_waiting: unblock_waiting_tiles(MPI.COMM_WORLD) - self._timing_collector.build_info[ - _stencil_object_name(self.stencil_object) - ] = build_info + obj_name = _stencil_object_name(self.stencil_object) + self._timing_collector.build_info[obj_name] = build_info field_info = self.stencil_object.field_info self._field_origins: Dict[ diff --git a/dsl/pace/dsl/stencil_config.py b/dsl/pace/dsl/stencil_config.py index 9710784d6..a473042ab 100644 --- a/dsl/pace/dsl/stencil_config.py +++ b/dsl/pace/dsl/stencil_config.py @@ -255,12 +255,18 @@ def stencil_kwargs( "name": func.__module__ + "." + func.__name__, **self.backend_opts, } + if not self.is_gpu_backend: kwargs.pop("device_sync", None) + if skip_passes or kwargs.get("skip_passes", ()): kwargs["oir_pipeline"] = StencilConfig._get_oir_pipeline( list(kwargs.pop("skip_passes", ())) + list(skip_passes) # type: ignore ) + + if self.compilation_config.run_mode == RunMode.Run: + kwargs["raise_if_not_cached"] = True + return kwargs @property