Skip to content

Commit

Permalink
Add some explanatory comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Oct 23, 2024
1 parent 99c6829 commit 71c5011
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,9 @@ def map(self, r: Range) -> Optional[Range]:
while src_i < self.src.dims():
assert dst_i < self.dst.dims()

# Find the next smallest segments of `src` and `dst` whose volumes matches (and therefore can possibly have
# a mapping).
# TODO: It's possible to do this in a O(max(|src|, |dst|)) loop instead of O(|src| * |dst|).
src_j, dst_j = None, None
for sj in range(src_i + 1, self.src.dims() + 1):
for dj in range(dst_i + 1, self.dst.dims() + 1):
Expand All @@ -1404,12 +1407,14 @@ def map(self, r: Range) -> Optional[Range]:
continue
break
if src_j is None:
# Somehow, we couldn't find a matching segment. This should have been caught earlier.
return None

# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
src_segment, dst_segment, r_segment = Range(self.src.ranges[src_i: src_j]), Range(
self.dst.ranges[dst_i: dst_j]), Range(r.ranges[src_i: src_j])
src_segment = Range(self.src.ranges[src_i: src_j])
dst_segment = Range(self.dst.ranges[dst_i: dst_j])
r_segment = Range(r.ranges[src_i: src_j])
if r_segment.volume_exact() == 1:
# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
# Compute the local 1D coordinate of the point on `src`.
loc = 0
for (idx, _, _), (ridx, _, _), s in zip(reversed(src_segment.ranges),
Expand All @@ -1427,7 +1432,7 @@ def map(self, r: Range) -> Optional[Range]:
# its entirety too.
out.extend(self.dst.ranges[dst_i:dst_j])
elif src_j - src_i == 1 and dst_j - dst_i == 1:
# If the segment lengths on both sides are just 1, the mapping is easy to compute.
# If the segment lengths on both sides are just 1, the mapping is easy to compute -- it's just a shift.
sb, se, ss = self.src.ranges[src_i]
db, de, ds = self.dst.ranges[dst_i]
b, e, s = r.ranges[src_i]
Expand Down

0 comments on commit 71c5011

Please sign in to comment.