Skip to content

Commit

Permalink
For fields getitem, use only valid cells internally
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgrote committed Oct 13, 2023
1 parent 419d47a commit 2b54f40
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions Python/pywarpx/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def _get_field(self, mfi):
device_arr = device_arr[tuple([slice(ng, -ng) for ng in nghosts[:self.dim]])]
return device_arr

def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop, with_internal_ghosts):
"""Return the slices where the block intersects with the global slice.
If the block does not intersect, return None.
This also shifts the block slices by the number of ghost cells in the
Expand All @@ -338,6 +338,9 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
The maximum component index of the global slice.
These can be negative.
with_internal_ghosts: bool
Whether the internal ghosts are included in the slices
Returns
-------
block_slices:
Expand All @@ -347,12 +350,23 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
The slice of the intersection relative to the global array where the data from individual block will go
"""
box = mfi.tilebox()
if self.include_ghosts:
if self.include_ghosts and with_internal_ghosts:
box.grow(self.mf.n_grow_vect())

ilo = self._get_indices(box.small_end, 0)
ihi = self._get_indices(box.big_end, 0)

if self.include_ghosts and not with_internal_ghosts:
# Only include the ghost cells on the outer edge of the full domain.
# Note that this could be done above, but needs box.growLo and box.growHi.
nghosts = self._get_indices(self.mf.n_grow_vect(), 0)
min_box = self.mf.box_array().minimal_box()
for i in range(3):
if ilo[i] == min_box.small_end[i]:
ilo[i] -= nghosts[i]
if ihi[i] == min_box.big_end[i]:
ihi[i] += nghosts[i]

# Add 1 to the upper end to be consistent with the slicing notation
ihi_p1 = [i + 1 for i in ihi]
i1 = np.maximum(starts, ilo)
Expand Down Expand Up @@ -423,7 +437,7 @@ def __getitem__(self, index):
stops = [ixstop, iystop, izstop]
datalist = []
for mfi in self.mf:
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop)
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop, False)
if global_slices is not None:
# Note that the array will always have 4 dimensions.
device_arr = self._get_field(mfi)
Expand Down Expand Up @@ -526,7 +540,7 @@ def __setitem__(self, index, value):
starts = [ixstart, iystart, izstart]
stops = [ixstop, iystop, izstop]
for mfi in self.mf:
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop)
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop, True)
if global_slices is not None:
mf_arr = self._get_field(mfi)
if isinstance(value, np.ndarray):
Expand Down

0 comments on commit 2b54f40

Please sign in to comment.