diff --git a/src/py/mat3ra/made/tools/utils/__init__.py b/src/py/mat3ra/made/tools/utils/__init__.py index 4c3dbd0b..1ac833cc 100644 --- a/src/py/mat3ra/made/tools/utils/__init__.py +++ b/src/py/mat3ra/made/tools/utils/__init__.py @@ -3,6 +3,7 @@ import numpy as np from mat3ra.made.material import Material +from mat3ra.made.utils import ArrayWithIds from mat3ra.utils.matrix import convert_2x2_to_3x3 from ..third_party import PymatgenStructure @@ -110,18 +111,32 @@ def transform_coordinate_to_supercell( def decorator_handle_periodic_boundary_conditions(cutoff): + """ + Decorator to handle periodic boundary conditions. + + Copies atoms near boundaries within the cutoff distance beyond the opposite side of the cell + creating the effect of periodic boundary conditions for edge atoms. + + Results of the function are filtered to remove atoms or coordinates outside the original cell. + + Args: + cutoff (float): The cutoff distance for a border slice in crystal coordinates. + + Returns: + Callable: The decorated function. + """ + def decorator(func): @wraps(func) def wrapper(material, *args, **kwargs): - augmented_material, original_count = augment_material_with_periodic_images(material, cutoff) + augmented_material, last_id = augment_material_with_periodic_images(material, cutoff) result = func(augmented_material, *args, **kwargs) if isinstance(result, list): - result = [idx for idx in result if idx < original_count] - - if isinstance(result, list) and all(isinstance(coord, float) for coord in result): - result = [coordinate for coordinate in result if 0 <= coordinate <= 1] - + if all(isinstance(x, int) for x in result): + result = [id for id in result if id <= last_id] + elif all(isinstance(x, list) and len(x) == 3 for x in result): + result = [coord for coord in result if all(0 <= c < 1 for c in coord)] return result return wrapper @@ -163,24 +178,18 @@ def augment_material_with_periodic_images(material: Material, cutoff: float = 0. Returns: Tuple[Material, int]: The augmented material and the original count of atoms. """ - original_count = len(material.basis.coordinates.values) + last_id = material.basis.coordinates.ids[-1] coordinates = np.array(material.basis.coordinates.values) elements = np.array(material.basis.elements.values) - augmented_coords = coordinates.tolist() - augmented_elems = elements.tolist() + augmented_material = material.clone() + new_basis = augmented_material.basis.copy() for axis in range(3): for direction in [-1, 1]: translated_coords, translated_elems = filter_and_translate(coordinates, elements, axis, cutoff, direction) - augmented_coords.extend(translated_coords) - augmented_elems.extend(translated_elems) - coordinates = np.array(augmented_coords) - elements = np.array(augmented_elems) - - augmented_material = material.clone() - new_basis = augmented_material.basis.copy() - for i, coord in enumerate(augmented_coords): - new_basis.add_atom(augmented_elems[i], coord) + for coord, elem in zip(translated_coords, translated_elems): + if not any(np.allclose(coord, existing_coord) for existing_coord in new_basis.coordinates.values): + new_basis.add_atom(elem, coord) augmented_material.basis = new_basis - return augmented_material, original_count + return augmented_material, last_id