Skip to content

Commit

Permalink
update: simplify and OOP
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Aug 27, 2024
1 parent 1623de1 commit 6199ed4
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 34 deletions.
55 changes: 39 additions & 16 deletions src/py/mat3ra/made/tools/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,40 +376,63 @@ def get_surface_atom_indices(
for idx, (x, y, z) in enumerate(coordinates):
if height_check(z, z_extremum, depth, surface):
neighbors_indices = kd_tree.query_ball_point([x, y, z], r=shadowing_radius)
print("neighbors_indices", type(neighbors_indices), neighbors_indices)
if shadowing_check(z, neighbors_indices, surface, coordinates):
exposed_atoms_indices.append(ids[idx])

return exposed_atoms_indices


def get_undercoordinated_atom_indices(
material: Material, surface: SurfaceTypes = SurfaceTypes.TOP, coordination_number: int = 3, cutoff: float = 3.0
def get_coordination_numbers(
material: Material,
indices: Optional[List[int]] = None,
cutoff: float = 3.0,
) -> List[int]:
"""
Identify undercoordinated atoms on the top or bottom surface of the material.
Calculate the coordination numbers of atoms in the material.
Args:
material (Material): Material object to get undercoordinated atoms from.
surface (SurfaceTypes): Specify "top" or "bottom" to detect the respective surface atoms.
coordination_number (int): The coordination number to detect undercoordinated atoms.
material (Material): Material object to calculate coordination numbers for.
indices (List[int]): List of atom indices to calculate coordination numbers for.
cutoff (float): The cutoff radius for identifying neighbors.
Returns:
List[int]: List of indices of undercoordinated surface atoms.
List[int]: List of coordination numbers for each atom in the material.
"""
new_material = material.clone()
new_material.to_cartesian()
if indices is not None:
new_material.basis.coordinates.filter_by_indices(indices)
coordinates = np.array(new_material.basis.coordinates.values)
ids = new_material.basis.coordinates.ids
kd_tree = cKDTree(coordinates)

z_extremum = np.max(coordinates[:, 2]) if surface == SurfaceTypes.TOP else np.min(coordinates[:, 2])

undercoordinated_atoms_indices = []
coordination_numbers = []
for idx, (x, y, z) in enumerate(coordinates):
if z == z_extremum:
neighbors = kd_tree.query_ball_point([x, y, z], r=cutoff)
if len(neighbors) < coordination_number:
undercoordinated_atoms_indices.append(ids[idx])
neighbors = kd_tree.query_ball_point([x, y, z], r=cutoff)
# Explicitly remove the atom itself from the list of neighbors
neighbors = [n for n in neighbors if n != idx]
coordination_numbers.append(len(neighbors))

return coordination_numbers


def get_undercoordinated_atom_indices(
material: Material,
indices: List[int],
cutoff: float = 3.0,
coordination_threshold: int = 3,
) -> List[int]:
"""
Identify undercoordinated atoms among the specified indices in the material.
Args:
material (Material): Material object to identify undercoordinated atoms in.
indices (List[int]): List of atom indices to check for undercoordination.
cutoff (float): The cutoff radius for identifying neighbors.
coordination_threshold (int): The coordination number threshold for undercoordination.
Returns:
List[int]: List of indices of undercoordinated atoms.
"""
coordination_numbers = get_coordination_numbers(material, indices, cutoff)
undercoordinated_atoms_indices = [i for i, cn in enumerate(coordination_numbers) if cn <= coordination_threshold]
return undercoordinated_atoms_indices
65 changes: 47 additions & 18 deletions src/py/mat3ra/made/tools/build/passivation/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from pydantic import BaseModel

from .enums import SurfaceTypes
from ...analyze import get_surface_atom_indices, get_undercoordinated_atom_indices, get_nearest_neighbors_atom_indices
from ...analyze import (
get_surface_atom_indices,
get_undercoordinated_atom_indices,
get_nearest_neighbors_atom_indices,
get_coordination_numbers,
)
from ...modify import translate_to_z_level
from ...build import BaseBuilder
from .configuration import (
Expand Down Expand Up @@ -117,14 +122,13 @@ def _get_passivant_coordinates(
return (np.array(surface_atoms_coordinates) + np.array(passivant_bond_vector_crystal)).tolist()


class UndercoordinationPassivationBuilderParameters(BaseModel):
class UndercoordinationPassivationBuilderParameters(SurfacePassivationBuilderParameters):
"""
Parameters for the UndercoordinationPassivationBuilder.
Args:
coordination_threshold (int): The coordination threshold for undercoordination.
"""

cutoff: float = 3.0
coordination_threshold: int = 3


Expand All @@ -135,43 +139,51 @@ class UndercoordinationPassivationBuilder(PassivationBuilder):
Detects atoms with coordination number below a threshold and passivates them.
"""

build_parameters: UndercoordinationPassivationBuilderParameters = UndercoordinationPassivationBuilderParameters()
_BuildParametersType = UndercoordinationPassivationBuilderParameters
_DefaultBuildParameters = UndercoordinationPassivationBuilderParameters()

def create_passivated_material(self, configuration: PassivationConfiguration) -> Material:
material = super().create_passivated_material(configuration)
passivant_coordinates_values = self._get_passivant_coordinates(material, configuration)
surface_atoms_indices = get_surface_atom_indices(
material=material,
surface=SurfaceTypes.TOP,
shadowing_radius=self.build_parameters.shadowing_radius,
depth=self.build_parameters.depth,
)
undercoordinated_atoms_indices = get_undercoordinated_atom_indices(
material=material,
indices=surface_atoms_indices,
cutoff=self.build_parameters.shadowing_radius,
coordination_threshold=self.build_parameters.coordination_threshold,
)
passivant_coordinates_values = self._get_passivant_coordinates(
material, configuration, undercoordinated_atoms_indices
)
return self._add_passivant_atoms(material, passivant_coordinates_values, configuration.passivant)

def _get_passivant_coordinates(self, material: Material, configuration: PassivationConfiguration):
def _get_passivant_coordinates(
self, material: Material, configuration: PassivationConfiguration, undercoordinated_atoms_indices: list
):
"""
Calculate the coordinates for placing passivants based on the specified edge type.
Calculate the coordinates for placing passivating atoms based on the specified edge type.
Args:
material (Material): Material to passivate.
configuration (SurfacePassivationConfiguration): Configuration for passivation.
undercoordinated_atoms_indices (list): Indices of undercoordinated atoms.
"""
undercoordinated_atoms_indices = get_undercoordinated_atom_indices(
material=material,
surface=SurfaceTypes.TOP,
coordination_number=self.build_parameters.coordination_threshold,
cutoff=self.build_parameters.cutoff,
)

passivant_coordinates = []

for idx in undercoordinated_atoms_indices:
nearest_neighbors = get_nearest_neighbors_atom_indices(
material=material,
coordinate=material.basis.coordinates.get_element_value_by_index(idx),
cutoff=self.build_parameters.cutoff,
cutoff=self.build_parameters.shadowing_radius,
)

if nearest_neighbors is None:
continue
average_coordinate = np.mean(
[material.basis.coordinates.get_element_value_by_index(i) for i in nearest_neighbors], axis=0
)

bond_vector = material.basis.coordinates.get_element_value_by_index(idx) - average_coordinate
bond_vector = bond_vector / np.linalg.norm(bond_vector) * configuration.bond_length
passivant_bond_vector_crystal = material.basis.cell.convert_point_to_crystal(bond_vector)
Expand All @@ -181,3 +193,20 @@ def _get_passivant_coordinates(self, material: Material, configuration: Passivat
)

return passivant_coordinates

def get_coordination_numbers(self, material: Material):
"""
Get the coordination numbers for all atoms in the material.
Args:
material (Material): The material object.
Returns:
set: The coordination numbers for all atoms in the material.
"""

coordination_numbers = set(
get_coordination_numbers(material=material, cutoff=self.build_parameters.shadowing_radius)
)
print("coordination numbers:", coordination_numbers)
return coordination_numbers

0 comments on commit 6199ed4

Please sign in to comment.