Skip to content

Commit

Permalink
Merge branch 'feature/SOF-7413' into feature/SOF-7442
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Aug 30, 2024
2 parents e763294 + 008264d commit 1acc6f8
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/py/mat3ra/made/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel

from .cell import Cell
from .utils import ArrayWithIds
from .utils import ArrayWithIds, get_overlapping_coordinates


class Basis(RoundNumericValuesMixin, BaseModel):
Expand Down Expand Up @@ -76,7 +76,15 @@ def to_crystal(self):
self.coordinates.map_array_in_place(self.cell.convert_point_to_crystal)
self.units = AtomicCoordinateUnits.crystal

def add_atom(self, element="Si", coordinate=[0.5, 0.5, 0.5]):
def add_atom(self, element="Si", coordinate=None, force=False):
if coordinate is None:
coordinate = [0, 0, 0]
if get_overlapping_coordinates(coordinate, self.coordinates.values, threshold=0.01):
if force:
print(f"Warning: Overlapping coordinates found for {coordinate}. Adding atom anyway.")
else:
print(f"Warning: Overlapping coordinates found for {coordinate}. Not adding atom.")
return
self.elements.add_item(element)
self.coordinates.add_item(coordinate)

Expand Down
9 changes: 7 additions & 2 deletions src/py/mat3ra/made/tools/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .build.passivation.enums import SurfaceTypes
from .convert import decorator_convert_material_args_kwargs_to_atoms, to_pymatgen
from .third_party import ASEAtoms, PymatgenIStructure, PymatgenVoronoiNN
from .utils import decorator_handle_periodic_boundary_conditions


@decorator_convert_material_args_kwargs_to_atoms
Expand Down Expand Up @@ -299,8 +300,10 @@ def get_nearest_neighbors_atom_indices(
site_index = len(structure.sites) - 1

remove_dummy_atom = True

neighbors = voronoi_nn.get_nn_info(structure, site_index)
try:
neighbors = voronoi_nn.get_nn_info(structure, site_index)
except ValueError:
return None
neighboring_atoms_pymatgen_ids = [n["site_index"] for n in neighbors]
if remove_dummy_atom:
structure.remove_sites([-1])
Expand Down Expand Up @@ -389,6 +392,7 @@ def shadowing_check(z: float, neighbors_indices: List[int], surface: SurfaceType
)


@decorator_handle_periodic_boundary_conditions(cutoff=0.1)
def get_surface_atom_indices(
material: Material, surface: SurfaceTypes = SurfaceTypes.TOP, shadowing_radius: float = 2.5, depth: float = 5
) -> List[int]:
Expand Down Expand Up @@ -455,6 +459,7 @@ def get_coordination_numbers(
return coordination_numbers


@decorator_handle_periodic_boundary_conditions(cutoff=0.1)
def get_undercoordinated_atom_indices(
material: Material,
indices: List[int],
Expand Down
86 changes: 86 additions & 0 deletions src/py/mat3ra/made/tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Callable, List, Optional

import numpy as np
from mat3ra.made.material import Material
from mat3ra.made.utils import ArrayWithIds
from mat3ra.utils.matrix import convert_2x2_to_3x3

from ..third_party import PymatgenStructure
Expand Down Expand Up @@ -106,3 +108,87 @@ def transform_coordinate_to_supercell(
if reverse:
converted_array = (np_coordinate - np_translation_vector) * np_scaling_factor
return converted_array.tolist()


def decorator_handle_periodic_boundary_conditions(cutoff):
"""
Decorator to handle periodic boundary conditions.
Copies atoms near boundaries within the cutoff distance beyond the opposite side of the cell
creating the effect of periodic boundary conditions for edge atoms.
Results of the function are filtered to remove atoms or coordinates outside the original cell.
Args:
cutoff (float): The cutoff distance for a border slice in crystal coordinates.
Returns:
Callable: The decorated function.
"""

def decorator(func):
@wraps(func)
def wrapper(material, *args, **kwargs):
augmented_material, last_id = augment_material_with_periodic_images(material, cutoff)
result = func(augmented_material, *args, **kwargs)

if isinstance(result, list):
if all(isinstance(x, int) for x in result):
result = [id for id in result if id <= last_id]
elif all(isinstance(x, list) and len(x) == 3 for x in result):
result = [coord for coord in result if all(0 <= c < 1 for c in coord)]
return result

return wrapper

return decorator


def filter_and_translate(coordinates: np.ndarray, elements: np.ndarray, axis: int, cutoff: float, direction: int):
"""
Filter and translate atom coordinates based on the axis and direction.
Args:
coordinates (np.ndarray): The coordinates of the atoms.
elements (np.ndarray): The elements of the atoms.
axis (int): The axis to filter and translate.
cutoff (float): The cutoff value for filtering.
direction (int): The direction to translate.
Returns:
Tuple[np.ndarray, np.ndarray]: The filtered and translated coordinates and elements.
"""
mask = (coordinates[:, axis] < cutoff) if direction == 1 else (coordinates[:, axis] > (1 - cutoff))
filtered_coordinates = coordinates[mask]
filtered_elements = elements[mask]
translation_vector = np.zeros(3)
translation_vector[axis] = direction
translated_coordinates = filtered_coordinates + translation_vector
return translated_coordinates, filtered_elements


def augment_material_with_periodic_images(material: Material, cutoff: float = 0.1):
"""
Augment the material's dataset by adding atoms from periodic images near boundaries.
Args:
material (Material): The material to augment.
cutoff (float): The cutoff value for filtering atoms near boundaries.
Returns:
Tuple[Material, int]: The augmented material and the original count of atoms.
"""
last_id = material.basis.coordinates.ids[-1]
coordinates = np.array(material.basis.coordinates.values)
elements = np.array(material.basis.elements.values)
augmented_material = material.clone()
new_basis = augmented_material.basis.copy()

for axis in range(3):
for direction in [-1, 1]:
translated_coords, translated_elems = filter_and_translate(coordinates, elements, axis, cutoff, direction)
for coord, elem in zip(translated_coords, translated_elems):
new_basis.add_atom(elem, coord)

augmented_material.basis = new_basis
return augmented_material, last_id
17 changes: 17 additions & 0 deletions src/py/mat3ra/made/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,23 @@ def get_center_of_coordinates(coordinates: List[List[float]]) -> List[float]:
return np.mean(np.array(coordinates), axis=0).tolist()


def get_overlapping_coordinates(
coordinate: List[float], coordinates: List[List[float]], threshold: float = 0.01
) -> List[List[float]]:
"""
Find coordinates that are within a certain threshold of a given coordinate.
Args:
coordinate (List[float]): The coordinate.
coordinates (List[List[float]]): The list of coordinates.
threshold (float): The threshold.
Returns:
List[List[float]]: The list of overlapping coordinates.
"""
return [c for c in coordinates if np.linalg.norm(np.array(c) - np.array(coordinate)) < threshold]


class ValueWithId(RoundNumericValuesMixin, BaseModel):
id: int = 0
value: Any = None
Expand Down

0 comments on commit 1acc6f8

Please sign in to comment.