Skip to content

Commit

Permalink
add more types
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Dec 23, 2024
1 parent 45b7600 commit c12fa70
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions src/pymatgen/analysis/pourbaix_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from typing import Any, Literal

import matplotlib.pyplot as plt
from numpy.typing import NDArray
from typing_extensions import Self

from pymatgen.core import DummySpecies, Species
Expand Down Expand Up @@ -526,15 +527,15 @@ def __init__(

self._stable_domains, self._stable_domain_vertices = self.get_pourbaix_domains(self._processed_entries)

def _convert_entries_to_points(self, pourbaix_entries):
def _convert_entries_to_points(self, pourbaix_entries: list[PourbaixEntry]) -> NDArray:
"""
Args:
pourbaix_entries ([PourbaixEntry]): list of Pourbaix entries
pourbaix_entries (list[PourbaixEntry]): Pourbaix entries
to process into vectors in nph-nphi-composition space.
Returns:
list of vectors, [[nph, nphi, e0, x1, x2, ..., xn-1]]
corresponding to each entry in nph-nphi-composition space
NDAarray: vectors as [[nph, nphi, e0, x1, x2, ..., xn-1]]
corresponding to each entry in nph-nphi-composition space
"""
vecs = [
[entry.npH, entry.nPhi, entry.energy] + [entry.composition.get(elt) for elt in self.pbx_elts[:-1]]
Expand All @@ -545,15 +546,18 @@ def _convert_entries_to_points(self, pourbaix_entries):
vecs *= norms
return vecs

def _get_hull_in_nph_nphi_space(self, entries) -> tuple[list[PourbaixEntry], list[Simplex]]:
def _get_hull_in_nph_nphi_space(
self,
entries: list[PourbaixEntry],
) -> tuple[list[PourbaixEntry], list[Simplex]]:
"""Generate convex hull of Pourbaix diagram entries in composition,
npH, and nphi space. This enables filtering of multi-entries
such that only compositionally stable combinations of entries
are included.
Args:
entries ([PourbaixEntry]): list of PourbaixEntries to construct
the convex hull
entries (list[PourbaixEntry]): PourbaixEntries to construct
the convex hull.
Returns:
tuple[list[PourbaixEntry], list[Simplex]]: PourbaixEntry list and stable
Expand Down Expand Up @@ -602,11 +606,15 @@ def _get_hull_in_nph_nphi_space(self, entries) -> tuple[list[PourbaixEntry], lis

return min_entries, valid_facets

def _preprocess_pourbaix_entries(self, entries, nproc=None):
def _preprocess_pourbaix_entries(
self,
entries: list[PourbaixEntry],
nproc: int | None = None,
) -> list[MultiEntry]:
"""Generate multi-entries for Pourbaix diagram.
Args:
entries ([PourbaixEntry]): list of PourbaixEntries to preprocess
entries (list[PourbaixEntry]): PourbaixEntries to preprocess
into MultiEntries
nproc (int): number of processes to be used in parallel
treatment of entry combos
Expand Down Expand Up @@ -651,7 +659,11 @@ def _preprocess_pourbaix_entries(self, entries, nproc=None):

return multi_entries

def _generate_multielement_entries(self, entries, nproc=None):
def _generate_multielement_entries(
self,
entries: list[PourbaixEntry],
nproc: int | None = None,
) -> list[MultiEntry]:
"""
Create entries for multi-element Pourbaix construction.
Expand Down Expand Up @@ -698,7 +710,9 @@ def _generate_multielement_entries(self, entries, nproc=None):
return processed_entries

@staticmethod
def process_multientry(entry_list, prod_comp, coeff_threshold=1e-4):
def process_multientry(
entry_list: list, prod_comp: Composition, coeff_threshold: float = 1e-4
) -> MultiEntry | None:
"""Static method for finding a multientry based on
a list of entries and a product composition.
Essentially checks to see if a valid aqueous
Expand Down Expand Up @@ -736,7 +750,10 @@ def process_multientry(entry_list, prod_comp, coeff_threshold=1e-4):
return None

@staticmethod
def get_pourbaix_domains(pourbaix_entries, limits=None):
def get_pourbaix_domains(
pourbaix_entries: list[PourbaixEntry],
limits: list[list[float]] | None = None,
) -> tuple[dict, dict]:
"""Get a set of Pourbaix stable domains (i.e. polygons) in
pH-V space from a list of pourbaix_entries.
Expand All @@ -750,12 +767,12 @@ def get_pourbaix_domains(pourbaix_entries, limits=None):
points.
Args:
pourbaix_entries ([PourbaixEntry]): Pourbaix entries
pourbaix_entries (list[PourbaixEntry]): Pourbaix entries
with which to construct stable Pourbaix domains
limits ([[float]]): limits in which to do the pourbaix
limits (list[list[float]]): limits in which to do the pourbaix
analysis
Returns:
Returns: # TODO: incorrect return type doc
Returns a dict of the form {entry: [boundary_points]}.
The list of boundary points are the sides of the N-1
dim polytope bounding the allowable ph-V range of each entry.
Expand Down Expand Up @@ -817,7 +834,7 @@ def get_pourbaix_domains(pourbaix_entries, limits=None):

return pourbaix_domains, pourbaix_domain_vertices

def find_stable_entry(self, pH, V):
def find_stable_entry(self, pH: float, V: float) -> PourbaixEntry:
"""Find stable entry at a pH,V condition.
Args:
Expand All @@ -830,7 +847,12 @@ def find_stable_entry(self, pH, V):
energies_at_conditions = [entry.normalized_energy_at_conditions(pH, V) for entry in self.stable_entries]
return self.stable_entries[np.argmin(energies_at_conditions)]

def get_decomposition_energy(self, entry, pH, V):
def get_decomposition_energy(
self,
entry: PourbaixEntry,
pH: float | list[float],
V: float | list[float],
) -> NDArray:
"""Find decomposition to most stable entries in eV/atom,
supports vectorized inputs for pH and V.
Expand Down Expand Up @@ -860,7 +882,7 @@ def get_decomposition_energy(self, entry, pH, V):
decomposition_energy /= entry.composition.num_atoms
return decomposition_energy

def get_hull_energy(self, pH: float | list[float], V: float | list[float]) -> np.ndarray:
def get_hull_energy(self, pH: float | list[float], V: float | list[float]) -> NDArray:
"""Get the minimum energy of the Pourbaix "basin" that is formed
from the stable Pourbaix planes. Vectorized.
Expand All @@ -874,7 +896,7 @@ def get_hull_energy(self, pH: float | list[float], V: float | list[float]) -> np
all_gs = np.array([entry.normalized_energy_at_conditions(pH, V) for entry in self.stable_entries])
return np.min(all_gs, axis=0)

def get_stable_entry(self, pH, V):
def get_stable_entry(self, pH: float, V: float) -> PourbaixEntry | MultiEntry:
"""Get the stable entry at a given pH, V condition.
Args:
Expand All @@ -889,26 +911,26 @@ def get_stable_entry(self, pH, V):
return self.stable_entries[np.argmin(all_gs)]

@property
def stable_entries(self):
def stable_entries(self) -> list:
"""The stable entries in the Pourbaix diagram."""
return list(self._stable_domains)

@property
def unstable_entries(self):
def unstable_entries(self) -> list:
"""All unstable entries in the Pourbaix diagram."""
return [entry for entry in self.all_entries if entry not in self.stable_entries]

@property
def all_entries(self):
def all_entries(self) -> list:
"""All entries used to generate the Pourbaix diagram."""
return self._processed_entries

@property
def unprocessed_entries(self):
def unprocessed_entries(self) -> list:
"""Unprocessed entries."""
return self._unprocessed_entries

def as_dict(self):
def as_dict(self) -> dict[str, Any]:
"""Get MSONable dict."""
return {
"@module": type(self).__module__,
Expand Down Expand Up @@ -940,7 +962,7 @@ def from_dict(cls, dct: dict) -> Self:
class PourbaixPlotter:
"""A plotter class for phase diagrams."""

def __init__(self, pourbaix_diagram):
def __init__(self, pourbaix_diagram: PourbaixDiagram) -> None:
"""
Args:
pourbaix_diagram (PourbaixDiagram): A PourbaixDiagram object.
Expand Down Expand Up @@ -1046,12 +1068,12 @@ def plot_entry_stability(
Args:
entry (Any): The entry to plot stability for.
pH_range (tuple[float, float], optional): pH range for the plot. Defaults to (-2, 16).
pH_resolution (int, optional): pH resolution. Defaults to 100.
V_range (tuple[float, float], optional): Voltage range for the plot. Defaults to (-3, 3).
V_resolution (int, optional): Voltage resolution. Defaults to 100.
e_hull_max (float, optional): Maximum energy above the hull. Defaults to 1.
cmap (str, optional): Colormap for the plot. Defaults to "RdYlBu_r".
pH_range (tuple[float, float]): pH range for the plot. Defaults to (-2, 16).
pH_resolution (int): pH resolution. Defaults to 100.
V_range (tuple[float, float]): Voltage range for the plot. Defaults to (-3, 3).
V_resolution (int): Voltage resolution. Defaults to 100.
e_hull_max (float): Maximum energy above the hull. Defaults to 1.
cmap (str): Colormap for the plot. Defaults to "RdYlBu_r".
ax (Axes, optional): Existing matplotlib Axes object for plotting. Defaults to None.
**kwargs (Any): Additional keyword arguments passed to `get_pourbaix_plot`.
Expand Down Expand Up @@ -1079,7 +1101,7 @@ def plot_entry_stability(

return ax

def domain_vertices(self, entry):
def domain_vertices(self, entry) -> list:
"""Get the vertices of the Pourbaix domain.
Args:
Expand All @@ -1091,7 +1113,7 @@ def domain_vertices(self, entry):
return self._pbx._stable_domain_vertices[entry]


def generate_entry_label(entry):
def generate_entry_label(entry: PourbaixEntry | MultiEntry) -> str:
"""
Generates a label for the Pourbaix plotter.
Expand Down

0 comments on commit c12fa70

Please sign in to comment.