Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
new znframe versionn
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Jan 19, 2024
1 parent f019777 commit 7121215
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znframe"
version = "0.1.4"
version = "0.1.5"
description = "ZnFrame - ASE-like Interface based on dataclasses"
authors = ["zincwarecode <[email protected]>"]
readme = "README.md"
Expand Down
38 changes: 32 additions & 6 deletions znframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@
from copy import deepcopy
import json
import typing as t
import enum

from znframe.bonds import ASEComputeBonds


class ComputeProperties(enum.Enum):
bonds = "bonds"
radii = "radii"
colors = "colors"


def _cell_to_array(cell: t.Union[np.ndarray, ase.cell.Cell]) -> np.ndarray:
if isinstance(cell, np.ndarray):
return cell
Expand Down Expand Up @@ -79,21 +86,40 @@ class Frame:
converter=_cell_to_array, eq=cmp_using(np.array_equal), default=np.zeros(3)
)

recompute: t.List[ComputeProperties] = field(
factory=lambda: [
ComputeProperties.bonds,
ComputeProperties.radii,
ComputeProperties.colors,
]
)

def __attrs_post_init__(self):
if ComputeProperties.bonds in self.recompute:
self.connectivity = None
if ComputeProperties.radii in self.recompute:
self.arrays.pop("radii", None)
if ComputeProperties.colors in self.recompute:
self.arrays.pop("colors", None)
if self.connectivity is None:
ase_bond_calculator = ASEComputeBonds()
self.connectivity = ase_bond_calculator.build_graph(self.to_atoms())
self.connectivity = ase_bond_calculator.get_bonds(self.connectivity)

if "colors" not in self.arrays:
self.arrays["colors"] = [
rgb2hex(jmol_colors[number]) for number in self.numbers
]
self.arrays["colors"] = np.array(
[rgb2hex(jmol_colors[number]) for number in self.numbers]
)
if "radii" not in self.arrays:
self.arrays["radii"] = [get_radius(number) for number in self.numbers]
self.arrays["radii"] = np.array(
[get_radius(number) for number in self.numbers]
)

@classmethod
def from_atoms(cls, atoms: ase.Atoms):
def from_atoms(
cls,
atoms: ase.Atoms,
):
arrays = deepcopy(atoms.arrays)
info = deepcopy(atoms.info)

Expand Down Expand Up @@ -158,7 +184,7 @@ def to_atoms(self) -> ase.Atoms:
return atoms

def to_dict(self, built_in_types: bool = True) -> dict:
data = attrs.asdict(self)
data = attrs.asdict(self, filter=lambda attr, _: attr.name != "recompute")
if built_in_types:
return data
else:
Expand Down

0 comments on commit 7121215

Please sign in to comment.