Skip to content

Commit

Permalink
reunited hamiltonian and reference compute interface
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Dec 19, 2024
1 parent 0d33c73 commit 8b8d41f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
17 changes: 4 additions & 13 deletions psiflow/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def get_length(arg):
def compute(
arg: Union[Dataset, AppFuture[list], list, AppFuture, Geometry],
*apply_apps: Union[PythonApp, Callable],
outputs_: Union[str, list[str], None] = None,
outputs_: Union[str, list[str], tuple[str, ...], None] = None,
reduce_func: Union[PythonApp, Callable] = aggregate_multiple,
batch_size: Optional[int] = None,
) -> Union[list[AppFuture], AppFuture]:
Expand All @@ -570,7 +570,7 @@ def compute(
Returns:
Union[list[AppFuture], AppFuture]: Future(s) representing computation results.
"""
if outputs_ is not None and not isinstance(outputs_, list):
if type(outputs_) is str:
outputs_ = [outputs_]
if batch_size is not None:
if isinstance(arg, Dataset):
Expand Down Expand Up @@ -627,7 +627,7 @@ class Computable:
def compute(
self,
arg: Union[Dataset, AppFuture[list], list, AppFuture, Geometry],
outputs: Union[str, list[str], None] = None,
*outputs: Optional[str],
batch_size: Optional[int] = -1, # if -1: take class default
) -> Union[list[AppFuture], AppFuture]:
"""
Expand All @@ -641,13 +641,4 @@ def compute(
Returns:
Union[list[AppFuture], AppFuture]: Future(s) representing computation results.
"""
if outputs is None:
outputs = list(self.__class__.outputs)
if batch_size == -1:
batch_size = self.__class__.batch_size
return compute(
arg,
self.app,
outputs_=outputs,
batch_size=batch_size,
)
raise NotImplementedError
17 changes: 17 additions & 0 deletions psiflow/hamiltonians.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ class Hamiltonian(Computable):
outputs: ClassVar[tuple] = ("energy", "forces", "stress")
batch_size = 1000

def compute(
self,
arg: Union[Dataset, AppFuture[list], list, AppFuture, Geometry],
*outputs: Optional[str],
batch_size: Optional[int] = -1, # if -1: take class default
) -> Union[list[AppFuture], AppFuture]:
if len(outputs) == 0:
outputs = tuple(self.__class__.outputs)
if batch_size == -1:
batch_size = self.__class__.batch_size
return compute(
arg,
self.app,
outputs_=outputs,
batch_size=batch_size,
)

def __eq__(self, hamiltonian: Hamiltonian) -> bool:
raise NotImplementedError

Expand Down
6 changes: 3 additions & 3 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_einstein_crystal(dataset):
hamiltonian = EinsteinCrystal(dataset[0], force_constant=1.0)

forces_, stress_, energy_ = hamiltonian.compute(
dataset[:4], outputs=["forces", "stress", "energy"]
dataset[:4], "forces", "stress", "energy"
)
assert np.allclose(
energy_.result(),
Expand All @@ -56,7 +56,7 @@ def test_einstein_crystal(dataset):
forces,
)

forces = hamiltonian.compute(dataset[:4], outputs=["forces"], batch_size=3)
forces = hamiltonian.compute(dataset[:4], "forces", batch_size=3)
assert np.allclose(
forces.result(),
forces_.result(),
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_plumed_function(tmp_path, dataset, dataset_h2):
distance = np.linalg.norm(positions[:, 0, :] - positions[:, 1, :], axis=1)
distance = distance.reshape(-1, 1)

energy = hamiltonian.compute(dataset[:10], ["energy"]).result()
energy = hamiltonian.compute(dataset[:10], "energy").result()

sigma = 2 * np.ones((1, 2))
height = np.array([70, 70]).reshape(1, -1) * (kJ / mol) # unit consistency
Expand Down

0 comments on commit 8b8d41f

Please sign in to comment.