Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of Interchange.from_smirnoff on polymers #1122

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion openff/interchange/components/toolkit.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
"""Utilities for processing and interfacing with the OpenFF Toolkit."""

from functools import lru_cache
from typing import TYPE_CHECKING, Union

import networkx
import numpy
from openff.toolkit import ForceField, Molecule, Quantity, Topology
from openff.toolkit.topology._mm_molecule import _SimpleMolecule
from openff.toolkit.typing.engines.smirnoff.parameters import VirtualSiteHandler
from openff.toolkit.typing.engines.smirnoff.parameters import ParameterHandler, VirtualSiteHandler
from openff.toolkit.utils.collections import ValidatedList
from openff.utilities.utilities import has_package

from openff.interchange.models import (
PotentialKey,
)

if has_package("openmm") or TYPE_CHECKING:
import openmm.app


_IDIVF_1 = Quantity(1.0, "dimensionless")
_PERIODICITIES = {
1: Quantity(1, "dimensionless"),
2: Quantity(2, "dimensionless"),
3: Quantity(3, "dimensionless"),
4: Quantity(4, "dimensionless"),
5: Quantity(5, "dimensionless"),
6: Quantity(6, "dimensionless"),
}


def _get_num_h_bonds(topology: "Topology") -> int:
"""Get the number of (covalent) bonds containing a hydrogen atom."""
n_bonds_containing_hydrogen = 0
Expand Down Expand Up @@ -202,3 +218,42 @@ def _lookup_virtual_site_parameter(
raise ValueError(
f"No VirtualSiteType found with {smirks=}, name={name=}, and match={match=}.",
)


@lru_cache
def _cache_angle_parameter_lookup(
potential_key: PotentialKey,
parameter_handler: ParameterHandler,
) -> dict[str, Quantity]:
parameter = parameter_handler.parameters[potential_key.id]

return {parameter_name: getattr(parameter, parameter_name) for parameter_name in ["k", "angle"]}


@lru_cache
def _cache_torsion_parameter_lookup(
potential_key: PotentialKey,
parameter_handler: ParameterHandler,
idivf: float | None = None,
) -> dict[str, Quantity]:
smirks = potential_key.id
n = potential_key.mult
parameter = parameter_handler.parameters[smirks]

if idivf is not None:
# case of non-standard default_idivf in impropers
_idivf = idivf
elif parameter.idivf is None:
# This appears to only come from imports
_idivf = _IDIVF_1
elif parameter.idivf[n] == 1.0:
_idivf = _IDIVF_1
else:
_idivf = Quantity(parameter.idivf[n], "dimensionless")

return {
"k": parameter.k[n],
"periodicity": _PERIODICITIES[parameter.periodicity[n]],
"phase": parameter.phase[n],
"idivf": _idivf,
}
4 changes: 1 addition & 3 deletions openff/interchange/interop/openmm/_import/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ def from_openmm(
def _convert_constraints(
system: "openmm.System",
) -> ConstraintCollection | None:
from openff.toolkit import unit

from openff.interchange.components.potentials import Potential
from openff.interchange.models import BondKey, PotentialKey

Expand Down Expand Up @@ -162,7 +160,7 @@ def _convert_constraints(
potential_key = PotentialKey(id=f"Constraint{index}")
_keys[distance] = potential_key
constraints.potentials[potential_key] = Potential(
parameters={"distance": distance * unit.nanometer},
parameters={"distance": Quantity(distance, "nanometer")},
)

for index in range(system.getNumConstraints()):
Expand Down
2 changes: 0 additions & 2 deletions openff/interchange/smirnoff/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,6 @@ def _find_reference_matches(
@classmethod
def _assign_charges_from_molecules(
cls,
topology: Topology,
unique_molecule: Molecule,
molecules_with_preset_charges=list[Molecule] | None,
) -> tuple[bool, dict, dict]:
Expand Down Expand Up @@ -879,7 +878,6 @@ def store_matches(
unique_molecule = topology.molecule(unique_molecule_index)

flag, matches, potentials = self._assign_charges_from_molecules(
topology,
unique_molecule,
molecules_with_preset_charges,
)
Expand Down
47 changes: 22 additions & 25 deletions openff/interchange/smirnoff/_valence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ProperTorsionCollection,
)
from openff.interchange.components.potentials import Potential, WrappedPotential
from openff.interchange.components.toolkit import _cache_angle_parameter_lookup, _cache_torsion_parameter_lookup
from openff.interchange.exceptions import (
DuplicateMoleculeError,
InvalidParameterHandlerError,
Expand Down Expand Up @@ -448,15 +449,16 @@ def store_potentials(self, parameter_handler: AngleHandler) -> None:

"""
for potential_key in self.key_map.values():
smirks = potential_key.id
parameter = parameter_handler.parameters[smirks]
potential = Potential(
parameters={
parameter_name: getattr(parameter, parameter_name)
for parameter_name in self.potential_parameters()
self.potentials.update(
{
potential_key: Potential(
parameters=_cache_angle_parameter_lookup(
potential_key,
parameter_handler,
),
),
},
)
self.potentials[potential_key] = potential


class SMIRNOFFProperTorsionCollection(SMIRNOFFCollection, ProperTorsionCollection):
Expand Down Expand Up @@ -550,11 +552,11 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None:

"""
for topology_key, potential_key in self.key_map.items():
smirks = potential_key.id
n = potential_key.mult
parameter = parameter_handler.parameters[smirks]

if topology_key.bond_order: # type: ignore[union-attr]
smirks = potential_key.id
n = potential_key.mult
parameter = parameter_handler.parameters[smirks]

bond_order = topology_key.bond_order # type: ignore[union-attr]
data = parameter.k_bondorder[n]
coeffs = _get_interpolation_coeffs(
Expand All @@ -580,12 +582,7 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None:
{pot: coeff for pot, coeff in zip(pots, coeffs)},
)
else:
parameters = {
"k": parameter.k[n],
"periodicity": parameter.periodicity[n] * unit.dimensionless,
"phase": parameter.phase[n],
"idivf": parameter.idivf[n] * unit.dimensionless,
}
parameters = _cache_torsion_parameter_lookup(potential_key, parameter_handler)
potential = Potential(parameters=parameters) # type: ignore[assignment]
self.potentials[potential_key] = potential

Expand Down Expand Up @@ -734,11 +731,11 @@ def store_potentials(self, parameter_handler: ImproperTorsionHandler) -> None:
# Assumed to be a numerical value
idivf = _default_idivf * unit.dimensionless

parameters = {
"k": parameter.k[n],
"periodicity": parameter.periodicity[n] * unit.dimensionless,
"phase": parameter.phase[n],
"idivf": idivf,
}
potential = Potential(parameters=parameters)
self.potentials[potential_key] = potential
# parameter keys happen to be the same as keys in proper torsions
self.potentials[potential_key] = Potential(
parameters=_cache_torsion_parameter_lookup(
potential_key,
parameter_handler,
idivf=idivf,
),
)
Loading