diff --git a/armi/reactor/composites.py b/armi/reactor/composites.py index 4d2d73f27..2d48fa2c2 100644 --- a/armi/reactor/composites.py +++ b/armi/reactor/composites.py @@ -2701,6 +2701,28 @@ def _iterChildren( for c in self: yield from c._iterChildren(deep, generationNum - 1, checker) + def iterChildrenWithMaterials(self, *args, **kwargs) -> Iterator: + """Produce an iterator that also includes any materials found on descendants. + + Arguments are forwarded to :meth:`iterChildren` and control the depth of traversal + and filtering of objects. + + This is useful for sending state across MPI tasks where you need a more full + representation of the composite tree. Which includes the materials attached + to components. + """ + children = self.iterChildren(*args, **kwargs) + # Each entry is either (c, ) or (c, c.material) if the child has a material attribute + stitched = map( + lambda c: ( + (c,) if getattr(c, "material", None) is None else (c, c.material) + ), + children, + ) + # Iterator that iterates over each "sub" iterator. If we have ((c0, ), (c1, m1)), this produces a single + # iterator of (c0, c1, m1) + return itertools.chain.from_iterable(stitched) + def getChildren( self, deep=False, @@ -2770,22 +2792,15 @@ def getChildren( [grandchild1, grandchild3] """ - children = self.iterChildren( - deep=deep, generationNum=generationNum, predicate=predicate - ) if not includeMaterials: - return list(children) - # Each entry is either (c, ) or (c, c.material) if the child has a material attribute - stitched = map( - lambda c: ( - (c, ) if getattr(c, "material", None) is None else (c, c.material) - ), - children, - ) - # Iterator that iterates over each "sub" iterator. If we have ((c0, ), (c1, m1)), this produces a single - # iterator of (c0, c1, m1) - flattened = itertools.chain.from_iterable(stitched) - return list(flattened) + items = self.iterChildren( + deep=deep, generationNum=generationNum, predicate=predicate + ) + else: + items = self.iterChildrenWithMaterials( + deep=deep, generationNum=generationNum, predicate=predicate + ) + return list(items) def iterChildrenWithFlags(self, typeSpec: TypeSpec, exactMatch=False): """Produce an iterator over all children of a specific type.""" @@ -2856,8 +2871,11 @@ def syncMpiState(self): startTime = timeit.default_timer() # sync parameters... - allComps = [self] + self.getChildren(deep=True, includeMaterials=True) - allComps = [c for c in allComps if hasattr(c, "p")] + genItems = itertools.chain( + [self], + self.iterChildrenWithMaterials(deep=True), + ) + allComps = [c for c in genItems if hasattr(c, "p")] sendBuf = [c.p.getSyncData() for c in allComps] runLog.debug("syncMpiState has {} comps".format(len(allComps))) @@ -2962,7 +2980,11 @@ def _markSynchronized(self): SINCE_LAST_DISTRIBUTE_STATE. """ paramDefs = set() - for child in [self] + self.getChildren(deep=True, includeMaterials=True): + items = itertools.chain( + [self], + self.iterChildrenWithMaterials(deep=True), + ) + for child in items: # Materials don't have a "p" / Parameter attribute to sync if hasattr(child, "p"): # below reads as: assigned & everything_but(SINCE_LAST_DISTRIBUTE_STATE) @@ -3232,7 +3254,7 @@ class StateRetainer: """ - def __init__(self, composite, paramsToApply=None): + def __init__(self, composite: Composite, paramsToApply=None): """ Create an instance of a StateRetainer. @@ -3260,9 +3282,11 @@ def _enterExitHelper(self, func): ``backUp()`` or ``restoreBackup()``. """ paramDefs = set() - for child in [self.composite] + self.composite.getChildren( - deep=True, includeMaterials=True - ): + items = itertools.chain( + (self.composite,), + self.composite.iterChildrenWithMaterials(deep=True), + ) + for child in items: if hasattr(child, "p"): # materials don't have Parameters paramDefs.update(child.p.paramDefs)