Skip to content

Commit

Permalink
Merge pull request #255 from zincware/mixcalc
Browse files Browse the repository at this point in the history
MixCalculator
  • Loading branch information
M-R-Schaefer authored Jan 19, 2024
2 parents 87e75e7 + 329b451 commit 5569ed4
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 5 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime
RUN conda install git
RUN git clone https://github.com/zincware/ipsuite

COPY . /workspace/ipsuite

WORKDIR /workspace/ipsuite
RUN pip install .[comparison,gap,nequip,apax,allegro,mace]
RUN pip install --upgrade torch --extra-index-url https://download.pytorch.org/whl/cu116
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install git+https://github.com/PythonFZ/torch-dftd.git@patch-2
RUN pip install dvc-s3

COPY entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
Expand Down
2 changes: 2 additions & 0 deletions ipsuite/calculators/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ from .ase_md import (
from .ase_standard import EMTSinglePoint, LJSinglePoint
from .cp2k import CP2KSinglePoint, CP2KYaml
from .lammps import LammpsSimulator
from .mix import MixCalculator
from .orca import OrcaSinglePoint
from .torch_d3 import TorchD3
from .xtb import xTBSinglePoint
Expand All @@ -40,4 +41,5 @@ __all__ = [
"LammpsSimulator",
"TorchD3",
"FixedLayerConstraint",
"MixCalculator",
]
125 changes: 125 additions & 0 deletions ipsuite/calculators/mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import contextlib
import typing

import tqdm
import zntrack
from ase.calculators.calculator import (
Calculator,
PropertyNotImplementedError,
all_changes,
)

from ipsuite import base
from ipsuite.utils.ase_sim import freeze_copy_atoms


def _update_if_exists(results, key, atoms_list, func, mean: bool):
with contextlib.suppress(PropertyNotImplementedError):
value = sum(func(x) for x in atoms_list)
if mean and len(atoms_list) > 0:
value /= len(atoms_list)

if key in results:
results[key] += value
else:
results[key] = value


class _MixCalculator(Calculator):
def __init__(self, calculators: typing.List[Calculator], methods: list, **kwargs):
Calculator.__init__(self, **kwargs)
self.calculators = calculators
self.implemented_properties = self.calculators[0].implemented_properties
self.methods = methods

def calculate(
self,
atoms=None,
properties=None,
system_changes=all_changes,
):
if properties is None:
properties = self.implemented_properties

Calculator.calculate(self, atoms, properties, system_changes)

mean_results = []
sum_results = []

for i, calc in enumerate(self.calculators):
_atoms = atoms.copy()
_atoms.calc = calc
if self.methods[i] == "mean":
mean_results.append(_atoms)
elif self.methods[i] == "sum":
sum_results.append(_atoms)
else:
raise NotImplementedError

_update_if_exists(
self.results, "energy", mean_results, lambda x: x.get_potential_energy(), True
)
_update_if_exists(
self.results, "forces", mean_results, lambda x: x.get_forces(), True
)
_update_if_exists(
self.results, "stress", mean_results, lambda x: x.get_stress(), True
)

_update_if_exists(
self.results, "energy", sum_results, lambda x: x.get_potential_energy(), False
)
_update_if_exists(
self.results, "forces", sum_results, lambda x: x.get_forces(), False
)
_update_if_exists(
self.results, "stress", sum_results, lambda x: x.get_stress(), False
)


class CalculatorNode(typing.Protocol):
def get_calculator(self) -> typing.Type[Calculator]: ...


class MixCalculator(base.ProcessAtoms):
"""Combine multiple models or calculators into one.
Attributes:
calculators: list[CalculatorNode]
List of calculators to combine.
methods: str|list[str]
choose from ['mean', 'sum'] either for all calculators
as a string or for each calculator individually as a list.
All calculators that are assigned with 'mean' will be
computed first, then the calculators assigned with 'sum'
will be added.
"""

calculators: typing.List[CalculatorNode] = zntrack.deps()
methods: str | typing.List[str] = zntrack.params("sum")
# weights: list = zntrack.params(None) ?

def run(self) -> None:
calc = self.get_calculator()
self.atoms = []
for atoms in tqdm.tqdm(self.get_data(), ncols=70):
atoms.calc = calc
atoms.get_potential_energy()
self.atoms.append(freeze_copy_atoms(atoms))

def get_calculator(self, **kwargs) -> Calculator:
"""Property to return a model specific ase calculator object.
Returns
-------
calc:
ase calculator object
"""
if isinstance(self.methods, str):
methods = [self.methods] * len(self.calculators)
else:
methods = self.methods
return _MixCalculator(
calculators=[x.get_calculator(**kwargs) for x in self.calculators],
methods=methods,
)
5 changes: 1 addition & 4 deletions ipsuite/models/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import typing
from uuid import uuid4

import ase
import numpy as np
Expand Down Expand Up @@ -54,10 +53,8 @@ def calculate(
class EnsembleModel(base.IPSNode):
models: typing.List[MLModel] = zntrack.deps()

uuid = zntrack.zn.outs() # to connect this Node to other Nodes it requires an output.

def run(self) -> None:
self.uuid = str(uuid4())
pass

def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
"""Property to return a model specific ase calculator object.
Expand Down
1 change: 1 addition & 0 deletions ipsuite/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class _Nodes:
OrcaSinglePoint = "ipsuite.calculators.OrcaSinglePoint"
ApaxJaxMD = "ipsuite.calculators.ApaxJaxMD"
LammpsSimulator = "ipsuite.calculators.LammpsSimulator"
MixCalculator = "ipsuite.calculators.MixCalculator"

LangevinThermostat = "ipsuite.calculators.LangevinThermostat"
NPTThermostat = "ipsuite.calculators.NPTThermostat"
Expand Down
84 changes: 84 additions & 0 deletions tests/integration/calculators/test_mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy.testing as npt

import ipsuite as ips


def test_mix_calculators(proj_path, traj_file):
with ips.Project(automatic_node_names=True) as proj:
data = ips.AddData(traj_file)
lj1 = ips.calculators.LJSinglePoint(data=data.atoms)
lj2 = ips.calculators.LJSinglePoint(data=data.atoms)
lj3 = ips.calculators.LJSinglePoint(data=data.atoms)

mix1 = ips.calculators.MixCalculator(
data=data.atoms,
calculators=[lj1, lj2],
methods="mean",
)

mix2 = ips.calculators.MixCalculator(
data=data.atoms,
calculators=[lj1, lj2],
methods="sum",
)

mix3 = ips.calculators.MixCalculator(
data=data.atoms,
calculators=[lj1, lj2, lj3],
methods=["mean", "sum", "mean"],
)

proj.run()

lj1.load()
mix1.load()

for a, b in zip(lj1.atoms, mix1.atoms):
assert a.get_potential_energy() == b.get_potential_energy()
npt.assert_almost_equal(a.get_forces(), b.get_forces())

lj2.load()
mix2.load()

for a, b, c in zip(lj1.atoms, lj2.atoms, mix2.atoms):
assert (
a.get_potential_energy() + b.get_potential_energy()
== c.get_potential_energy()
)
npt.assert_almost_equal(a.get_forces() + b.get_forces(), c.get_forces())

lj3.load()
mix3.load()

for a, b, c, d in zip(lj1.atoms, lj2.atoms, lj3.atoms, mix3.atoms):

# (a + c / 2) + b
true_energy = a.get_potential_energy() + b.get_potential_energy()
true_forces = a.get_forces() + b.get_forces()

assert true_energy == d.get_potential_energy()
npt.assert_almost_equal(true_forces, d.get_forces())


def test_mix_calculator_external(proj_path, traj_file):
lj1 = ips.calculators.LJSinglePoint(data=None)
lj2 = ips.calculators.LJSinglePoint(data=None)

with ips.Project(automatic_node_names=True) as proj:
data = ips.AddData(traj_file)
lj3 = ips.calculators.LJSinglePoint(data=data.atoms)

mix1 = ips.calculators.MixCalculator(
data=data.atoms,
calculators=[lj1, lj2],
methods="mean",
)

proj.run()

lj3.load()
mix1.load()

for a, b in zip(lj3.atoms, mix1.atoms):
assert a.get_potential_energy() == b.get_potential_energy()
npt.assert_almost_equal(a.get_forces(), b.get_forces())

0 comments on commit 5569ed4

Please sign in to comment.