Skip to content

Commit

Permalink
update: fix exessive after decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Aug 29, 2024
1 parent e7df9f4 commit cbc4131
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions src/py/mat3ra/made/tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from mat3ra.made.material import Material
from mat3ra.made.utils import ArrayWithIds
from mat3ra.utils.matrix import convert_2x2_to_3x3

from ..third_party import PymatgenStructure
Expand Down Expand Up @@ -110,18 +111,32 @@ def transform_coordinate_to_supercell(


def decorator_handle_periodic_boundary_conditions(cutoff):
"""
Decorator to handle periodic boundary conditions.
Copies atoms near boundaries within the cutoff distance beyond the opposite side of the cell
creating the effect of periodic boundary conditions for edge atoms.
Results of the function are filtered to remove atoms or coordinates outside the original cell.
Args:
cutoff (float): The cutoff distance for a border slice in crystal coordinates.
Returns:
Callable: The decorated function.
"""

def decorator(func):
@wraps(func)
def wrapper(material, *args, **kwargs):
augmented_material, original_count = augment_material_with_periodic_images(material, cutoff)
augmented_material, last_id = augment_material_with_periodic_images(material, cutoff)
result = func(augmented_material, *args, **kwargs)

if isinstance(result, list):
result = [idx for idx in result if idx < original_count]

if isinstance(result, list) and all(isinstance(coord, float) for coord in result):
result = [coordinate for coordinate in result if 0 <= coordinate <= 1]

if all(isinstance(x, int) for x in result):
result = [id for id in result if id <= last_id]
elif all(isinstance(x, list) and len(x) == 3 for x in result):
result = [coord for coord in result if all(0 <= c < 1 for c in coord)]
return result

return wrapper
Expand Down Expand Up @@ -163,24 +178,18 @@ def augment_material_with_periodic_images(material: Material, cutoff: float = 0.
Returns:
Tuple[Material, int]: The augmented material and the original count of atoms.
"""
original_count = len(material.basis.coordinates.values)
last_id = material.basis.coordinates.ids[-1]
coordinates = np.array(material.basis.coordinates.values)
elements = np.array(material.basis.elements.values)
augmented_coords = coordinates.tolist()
augmented_elems = elements.tolist()
augmented_material = material.clone()
new_basis = augmented_material.basis.copy()

for axis in range(3):
for direction in [-1, 1]:
translated_coords, translated_elems = filter_and_translate(coordinates, elements, axis, cutoff, direction)
augmented_coords.extend(translated_coords)
augmented_elems.extend(translated_elems)
coordinates = np.array(augmented_coords)
elements = np.array(augmented_elems)

augmented_material = material.clone()
new_basis = augmented_material.basis.copy()
for i, coord in enumerate(augmented_coords):
new_basis.add_atom(augmented_elems[i], coord)
for coord, elem in zip(translated_coords, translated_elems):
if not any(np.allclose(coord, existing_coord) for existing_coord in new_basis.coordinates.values):
new_basis.add_atom(elem, coord)

augmented_material.basis = new_basis
return augmented_material, original_count
return augmented_material, last_id

0 comments on commit cbc4131

Please sign in to comment.