Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring type-checking back online #996

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
MAINT: Update annotations
  • Loading branch information
mattwthompson committed Jun 21, 2024
commit 6d2dcd6551eadb8acfa6aced7ca73e56f3df27ff
1 change: 0 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ jobs:
python devtools/scripts/molecule-regressions.py

- name: Run mypy
continue-on-error: true
if: ${{ matrix.python-version == '3.11' }}
run: |
# As of 01/23, JAX with mypy is too slow to use without a pre-built cache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ def test_tip5p_num_exceptions(self, water, tip5p, combine, n_molecules):
# Safeguard against some of the behavior seen in #919
for index in range(num_exceptions):
p1, p2, *_ = force.getExceptionParameters(index)
print(p1, p2)

if sorted([p1, p2]) == [0, 3]:
raise Exception(
Expand Down
29 changes: 15 additions & 14 deletions openff/interchange/components/_packmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _check_add_positive_mass(mass_to_add):
)


def _check_box_shape_shape(box_shape: ArrayLike):
def _check_box_shape_shape(box_shape: NDArray):
"""Check the .shape of the box_shape argument."""
if box_shape.shape != (3, 3):
raise PACKMOLValueError(
Expand Down Expand Up @@ -531,27 +531,28 @@ def _build_input_file(


def _center_topology_at(
center_solute: bool | Literal["BOX_VECS", "ORIGIN", "BRICK"],
center_solute: Literal["NO", "YES", "BOX_VECS", "ORIGIN", "BRICK"],
topology: Topology,
box_vectors: Quantity,
brick_size: Quantity,
) -> Topology:
"""Return a copy of the topology centered as requested."""
if isinstance(center_solute, str):
center_solute = center_solute.upper()
_center_solute = center_solute.upper()

topology = Topology(topology)

if center_solute is False:
if _center_solute == "NO":
return topology
elif center_solute in [True, "BOX_VECS"]:
elif _center_solute in ["YES", "BOX_VECS"]:
new_center = box_vectors.sum(axis=0) / 2.0
elif center_solute == "ORIGIN":
elif _center_solute == "ORIGIN":
new_center = numpy.zeros(3)
elif center_solute == "BRICK":
elif _center_solute == "BRICK":
new_center = brick_size / 2.0
else:
PACKMOLValueError(
f"center_solute must be a bool, 'BOX_VECS', 'ORIGIN', or 'BRICK', not {center_solute!r}",
"center_solute must be 'NO', 'YES', 'BOX_VECS', 'ORIGIN', or 'BRICK', "
f"not {center_solute!r}",
)

positions = topology.get_positions()
Expand All @@ -569,7 +570,7 @@ def pack_box(
box_vectors: Quantity | None = None,
mass_density: Quantity | None = None,
box_shape: ArrayLike = RHOMBIC_DODECAHEDRON,
center_solute: bool | Literal["BOX_VECS", "ORIGIN", "BRICK"] = False,
center_solute: Literal["NO", "YES", "BOX_VECS", "ORIGIN", "BRICK"] = "NO",
working_directory: str | None = None,
retain_working_files: bool = False,
) -> Topology:
Expand Down Expand Up @@ -609,12 +610,12 @@ def pack_box(
<http://docs.openmm.org/latest/userguide/theory/
05_other_features.html#periodic-boundary-conditions>`_.
center_solute
How to center ``solute`` in the simulation box. If ``True``
How to center ``solute`` in the simulation box. If ``"YES"``
or ``"box_vecs"``, the solute's center of geometry will be placed at
the center of the box's parallelopiped representation. If ``"origin"``,
the solute will centered at the origin. If ``"brick"``, the solute will
be centered in the box's rectangular brick representation. If
``False`` (the default), the solute will not be moved.
``"NO"`` (the default), the solute will not be moved.
working_directory: str, optional
The directory in which to generate the temporary working files. If
``None``, a temporary one will be created.
Expand Down Expand Up @@ -678,7 +679,7 @@ def pack_box(
brick_size = _compute_brick_from_box_vectors(box_vectors)

# Center the solute
if center_solute and solute is not None:
if center_solute != "NO" and solute is not None:
solute = _center_topology_at(
center_solute,
solute,
Expand Down Expand Up @@ -956,5 +957,5 @@ def solvate_topology_nonwater(
solute=topology,
tolerance=tolerance,
box_vectors=box_vectors,
center_solute=True,
center_solute="YES",
)
27 changes: 13 additions & 14 deletions openff/interchange/components/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Annotated, Any, Union

import numpy
from numpy.typing import NDArray
from openff.toolkit import Quantity
from openff.utilities.utilities import has_package, requires_package
from pydantic import (
Expand All @@ -28,12 +29,10 @@
from openff.interchange.warnings import InterchangeDeprecationWarning

if has_package("jax"):
from jax import numpy as jax_numpy

from numpy.typing import ArrayLike

if has_package("jax"):
# JAX stubs seem very broken, not adding this to type annotations
# even though many should be NDArray | Array
from jax import Array
from jax import numpy as jax_numpy


def __getattr__(name: str):
Expand Down Expand Up @@ -296,7 +295,7 @@ def _get_parameters(self, atom_indices: tuple[int]) -> dict:
def get_force_field_parameters(
self,
use_jax: bool = False,
) -> Union["ArrayLike", "Array"]:
) -> NDArray:
"""Return a flattened representation of the force field parameters."""
# TODO: Handle WrappedPotential
if any(
Expand All @@ -320,20 +319,20 @@ def get_force_field_parameters(
],
)

def set_force_field_parameters(self, new_p: "ArrayLike") -> None:
def set_force_field_parameters(self, new_p: NDArray) -> None:
"""Set the force field parameters from a flattened representation."""
mapping = self.get_mapping()
if new_p.shape[0] != len(mapping): # type: ignore
if new_p.shape[0] != len(mapping):
raise RuntimeError

for potential_key, potential_index in self.get_mapping().items():
potential = self.potentials[potential_key]
if len(new_p[potential_index, :]) != len(potential.parameters): # type: ignore
if len(new_p[potential_index, :]) != len(potential.parameters):
raise RuntimeError

for parameter_index, parameter_key in enumerate(potential.parameters):
parameter_units = potential.parameters[parameter_key].units
modified_parameter = new_p[potential_index, parameter_index] # type: ignore
modified_parameter = new_p[potential_index, parameter_index]

self.potentials[potential_key].parameters[parameter_key] = (
modified_parameter * parameter_units
Expand All @@ -343,7 +342,7 @@ def get_system_parameters(
self,
p=None,
use_jax: bool = False,
) -> Union["ArrayLike", "Array"]:
) -> NDArray:
"""
Return a flattened representation of system parameters.

Expand Down Expand Up @@ -385,7 +384,7 @@ def parametrize(
self,
p=None,
use_jax: bool = True,
) -> Union["ArrayLike", "Array"]:
) -> NDArray:
"""Return an array of system parameters, given an array of force field parameters."""
if p is None:
p = self.get_force_field_parameters(use_jax=use_jax)
Expand All @@ -402,7 +401,7 @@ def parametrize_partial(self):
)

@requires_package("jax")
def get_param_matrix(self) -> Union["Array", "ArrayLike"]:
def get_param_matrix(self) -> Array:
"""Get a matrix representing the mapping between force field and system parameters."""
from functools import partial

Expand All @@ -417,7 +416,7 @@ def get_param_matrix(self) -> Union["Array", "ArrayLike"]:
jac_parametrize = jax.jacfwd(parametrize_partial)
jac_res = jac_parametrize(p)

return jac_res.reshape(-1, p.flatten().shape[0]) # type: ignore[union-attr]
return jac_res.reshape(-1, p.flatten().shape[0])

def __getattr__(self, attr: str):
if attr == "slot_map":
Expand Down
2 changes: 2 additions & 0 deletions openff/interchange/drivers/lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def _get_lammps_energies(
) -> dict[str, Quantity]:
import lammps

assert interchange.positions is not None

if round_positions is not None:
interchange.positions = numpy.round(interchange.positions, round_positions)

Expand Down
4 changes: 2 additions & 2 deletions openff/interchange/drivers/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def diff(
self["Electrostatics"] and other["Electrostatics"]
) is not None:
for key in ("vdW", "Electrostatics"):
energy_differences[key] = self[key] - other[key]
energy_differences[key] = self[key] - other[key]
energy_differences[key] = self[key] - other[key] # type: ignore[operator]
energy_differences[key] = self[key] - other[key] # type: ignore[operator]

nonbondeds_processed = True

Expand Down
15 changes: 11 additions & 4 deletions openff/interchange/foyer/_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
FoyerRBImproperHandler,
FoyerRBProperHandler,
)
from openff.interchange.models import TopologyKey
from openff.interchange.models import LibraryChargeTopologyKey, TopologyKey

if has_package("foyer"):
from foyer.forcefield import Forcefield
Expand Down Expand Up @@ -53,7 +53,10 @@ def _create_interchange(

# This block is from a mega merge, unclear if it's still needed
for name, handler_class in get_handlers_callable().items():
interchange.collections[name] = handler_class()
interchange.collections[name] = handler_class(
type=name,
expression=f"FOYER_{name}",
)

vdw_handler = interchange["vdW"]
vdw_handler.scale_14 = force_field.lj14scale
Expand All @@ -76,7 +79,9 @@ def _create_interchange(

# TODO: Populate .mdconfig, but only after a reasonable number of state mutations have been tested

charges = electrostatics.charges
charges: dict[TopologyKey | LibraryChargeTopologyKey, Quantity] = (
electrostatics.charges
)

for molecule in interchange.topology.molecules:
molecule_charges = [
Expand All @@ -85,7 +90,9 @@ def _create_interchange(
].m
for atom in molecule.atoms
]
molecule.partial_charges = Quantity(

# Quantity(list[Quantity]) works ... but is a big magical to mypy
molecule.partial_charges = Quantity( # type: ignore[call-overload]
molecule_charges,
unit.elementary_charge,
)
Expand Down
10 changes: 8 additions & 2 deletions openff/interchange/foyer/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from openff.interchange.common._nonbonded import ElectrostaticsCollection, vdWCollection
from openff.interchange.components.potentials import Potential
from openff.interchange.foyer._base import _copy_params
from openff.interchange.models import PotentialKey, TopologyKey
from openff.interchange.models import (
LibraryChargeTopologyKey,
PotentialKey,
TopologyKey,
)

if has_package("foyer"):
from foyer.forcefield import Forcefield
Expand Down Expand Up @@ -60,7 +64,9 @@ class FoyerElectrostaticsHandler(ElectrostaticsCollection):
force_field_key: str = "atoms"
cutoff: _DistanceQuantity = 9.0 * unit.angstrom

_charges: dict[TopologyKey, Quantity] = PrivateAttr(default_factory=dict)
_charges: dict[TopologyKey | LibraryChargeTopologyKey, Quantity] = PrivateAttr(
default_factory=dict,
)

def store_charges(
self,
Expand Down
6 changes: 3 additions & 3 deletions openff/interchange/foyer/_valence.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def store_matches(
class FoyerRBImproperHandler(FoyerRBProperHandler):
"""Handler storing Ryckaert-Bellemans improper torsion potentials as produced by a Foyer force field."""

type: Literal["RBImpropers"] = "RBImpropers"
type: Literal["RBImpropers"] = "RBImpropers" # type: ignore[assignment]
connection_attribute: str = "impropers"


Expand All @@ -149,8 +149,8 @@ class FoyerPeriodicProperHandler(FoyerConnectedAtomsHandler, ProperTorsionCollec
force_field_key: str = "periodic_propers"
connection_attribute: str = "propers"
raise_on_missing_params: bool = False
type: str = "ProperTorsions"
expression: str = "k*(1+cos(periodicity*theta-phase))"
type: str = "ProperTorsions" # type: ignore[assignment]
expression: str = "k*(1+cos(periodicity*theta-phase))" # type: ignore[assignment]

def get_params_with_units(self, params):
"""Get the parameters of this handler, tagged with units."""
Expand Down
29 changes: 13 additions & 16 deletions openff/interchange/interop/_virtual_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

from collections import defaultdict
from collections.abc import Iterable
from typing import DefaultDict

import numpy
Expand Down Expand Up @@ -34,7 +33,7 @@ def _virtual_site_parent_molecule_mapping(
A dictionary mapping virtual site keys to the index of the molecule they belong to.

"""
mapping = dict()
mapping: dict[VirtualSiteKey, int] = dict()

if "VirtualSites" not in interchange.collections:
return mapping
Expand Down Expand Up @@ -164,7 +163,7 @@ def get_positions_with_virtual_sites(

def _get_separation_by_atom_indices(
interchange: Interchange,
atom_indices: Iterable[int],
atom_indices: tuple[int, ...],
prioritize_geometry: bool = False,
) -> Quantity:
"""
Expand All @@ -175,48 +174,46 @@ def _get_separation_by_atom_indices(
This is slow, but often necessary for converting virtual site "distances" to weighted
averages (unitless) of orientation atom positions.
"""
assert interchange.positions is not None

if prioritize_geometry:
p1 = interchange.positions[atom_indices[1]]
p0 = interchange.positions[atom_indices[0]]

return p1 - p0

if "Constraints" in interchange.collections:
collection = interchange["Constraints"]
constraints = interchange["Constraints"]

for key in collection.key_map:
for key in constraints.key_map:
if (key.atom_indices == atom_indices) or (
key.atom_indices[::-1] == atom_indices
):
return collection.potentials[collection.key_map[key]].parameters[
return constraints.potentials[constraints.key_map[key]].parameters[
"distance"
]

if "Bonds" in interchange.collections:
collection = interchange["Bonds"]
bonds = interchange["Bonds"]

for key in collection.key_map:
for key in bonds.key_map:
if (key.atom_indices == atom_indices) or (
key.atom_indices[::-1] == atom_indices
):
return collection.potentials[collection.key_map[key]].parameters[
"length"
]
return bonds.potentials[bonds.key_map[key]].parameters["length"]

# Two heavy atoms may be on opposite ends of an angle, in which case it's still
# possible to determine their separation as defined by the geometry of the force field
if "Angles" in interchange.collections:
collection = interchange["Angles"]
angles = interchange["Angles"]

index0 = atom_indices[0]
index1 = atom_indices[1]
for key in collection.key_map:
for key in angles.key_map:
if (key.atom_indices[0] == index0 and key.atom_indices[2] == index1) or (
key.atom_indices[2] == index0 and key.atom_indices[0] == index1
):
gamma = collection.potentials[collection.key_map[key]].parameters[
"angle"
]
gamma = angles.potentials[angles.key_map[key]].parameters["angle"]

a = _get_separation_by_atom_indices(
interchange,
Expand Down
Loading
Loading