Skip to content

Commit

Permalink
fix bug where output concatenation in batch_apply fails with states w…
Browse files Browse the repository at this point in the history
…ith different numbers of atoms
  • Loading branch information
svandenhaute committed Dec 2, 2024
1 parent 0dfbb35 commit 0153da8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
21 changes: 20 additions & 1 deletion psiflow/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,14 +415,33 @@ def _concatenate_multiple(*args: list[np.ndarray]) -> list[np.ndarray]:
Note:
This function is wrapped as a Parsl app and executed using the default_threads executor.
"""
def pad_arrays(
arrays: list[np.ndarray],
pad_dimension: int = 1,
) -> list[np.ndarray]:
ndims = np.array([len(a.shape) for a in arrays])
assert np.all(ndims == ndims[0])
assert np.all(pad_dimension < ndims)

pad_size = max([a.shape[pad_dimension] for a in arrays])
for i in range(len(arrays)):
shape = list(arrays[i].shape)
shape[pad_dimension] = pad_size - shape[pad_dimension]
padding = np.zeros(tuple(shape)) + np.nan
arrays[i] = np.concatenate((arrays[i], padding), axis=pad_dimension)
return arrays

narrays = len(args[0])
for arg in args:
assert isinstance(arg, list)
assert all([len(a) == narrays for a in args])

concatenated = []
for i in range(narrays):
concatenated.append(np.concatenate([arg[i] for arg in args]))
arrays = [arg[i] for arg in args]
if len(arrays[0].shape) > 1:
pad_arrays(arrays)
concatenated.append(np.concatenate(tuple(arrays)))
return concatenated


Expand Down
7 changes: 7 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from parsl.data_provider.files import File # type: ignore

import psiflow
from psiflow.data import Dataset
from psiflow.functions import (
EinsteinCrystalFunction,
HarmonicFunction,
Expand Down Expand Up @@ -254,6 +255,12 @@ def test_hamiltonian_arithmetic(dataset):
assert hamiltonian == hamiltonian + zero
assert 2 * hamiltonian + zero == 2 * hamiltonian

geometries = [dataset[i].result() for i in [0, -1]]
natoms = [len(geometry) for geometry in geometries]
forces = zero.compute(geometries, 'forces', batch_size=1).result()
assert np.all(forces[0, :natoms[0]] == 0.0)
assert np.all(forces[-1, :natoms[1]] == 0.0)


def test_subtract(dataset):
einstein = EinsteinCrystal(dataset[0], force_constant=1.0)
Expand Down

0 comments on commit 0153da8

Please sign in to comment.