Skip to content

Commit

Permalink
Migrate some type annotation tweaks from #4100 (#4290)
Browse files Browse the repository at this point in the history
* add types for analysis.eos

* add requires decorator to is_valid_bibtex

* revert adding requires

* fix typo in core.structure.rotate_sites

* skip prototype test if pybtex is not available

* remove import guard from cherry pick
  • Loading branch information
DanielYang59 authored Feb 18, 2025
1 parent 3744b4a commit 5aff054
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 154 deletions.
97 changes: 61 additions & 36 deletions src/pymatgen/analysis/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig, pretty_plot

if TYPE_CHECKING:
from typing import ClassVar
from collections.abc import Sequence
from typing import Any, ClassVar

import matplotlib.pyplot as plt

Expand All @@ -40,7 +41,11 @@ class EOSBase(ABC):
implementations.
"""

def __init__(self, volumes, energies):
def __init__(
self,
volumes: Sequence[float],
energies: Sequence[float],
) -> None:
"""
Args:
volumes (Sequence[float]): in Ang^3.
Expand All @@ -50,18 +55,28 @@ def __init__(self, volumes, energies):
self.energies = np.array(energies)
# minimum energy(e0), buk modulus(b0),
# derivative of bulk modulus w.r.t. pressure(b1), minimum volume(v0)
self._params = None
self._params: Sequence | None = None
# the eos function parameters. It is the same as _params except for
# equation of states that uses polynomial fits(delta_factor and
# numerical_eos)
self.eos_params = None
self.eos_params: Sequence | None = None

def _initial_guess(self):
def __call__(self, volume: float) -> float:
"""
Args:
volume (float | list[float]): volume(s) in Ang^3.
Returns:
Compute EOS with this volume.
"""
return self.func(volume)

def _initial_guess(self) -> tuple[float, float, float, float]:
"""
Quadratic fit to get an initial guess for the parameters.
Returns:
tuple: 4 floats for (e0, b0, b1, v0)
tuple[float, float, float, float]: e0, b0, b1, v0
"""
a, b, c = np.polyfit(self.volumes, self.energies, 2)
self.eos_params = [a, b, c]
Expand All @@ -78,7 +93,7 @@ def _initial_guess(self):

return e0, b0, b1, v0

def fit(self):
def fit(self) -> None:
"""
Do the fitting. Does least square fitting. If you want to use custom
fitting, must override this.
Expand Down Expand Up @@ -120,24 +135,20 @@ def func(self, volume):
"""
return self._func(np.array(volume), self.eos_params)

def __call__(self, volume: float) -> float:
"""
Args:
volume (float | list[float]): volume(s) in Ang^3.
Returns:
Compute EOS with this volume.
"""
return self.func(volume)

@property
def e0(self) -> float:
"""The min energy."""
if self._params is None:
raise RuntimeError("params have not be initialized.")

return self._params[0]

@property
def b0(self) -> float:
"""The bulk modulus in units of energy/unit of volume^3."""
if self._params is None:
raise RuntimeError("params have not be initialized.")

return self._params[1]

@property
Expand All @@ -156,11 +167,18 @@ def v0(self):
return self._params[3]

@property
def results(self):
def results(self) -> dict[str, Any]:
"""A summary dict."""
return {"e0": self.e0, "b0": self.b0, "b1": self.b1, "v0": self.v0}

def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
def plot(
self,
width: float = 8,
height: float | None = None,
ax: plt.Axes = None,
dpi: float | None = None,
**kwargs,
) -> plt.Axes:
"""
Plot the equation of state.
Expand All @@ -170,7 +188,7 @@ def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
golden ratio.
ax (plt.Axes): If supplied, changes will be made to the existing Axes.
Otherwise, new Axes will be created.
dpi:
dpi (float): DPI.
kwargs (dict): additional args fed to pyplot.plot.
supported keys: style, color, text, label
Expand Down Expand Up @@ -211,16 +229,18 @@ def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
return ax

@add_fig_kwargs
def plot_ax(self, ax: plt.Axes = None, fontsize=12, **kwargs):
def plot_ax(
self,
ax: plt.Axes | None = None,
fontsize: float = 12,
**kwargs,
) -> plt.Figure:
"""
Plot the equation of state on axis `ax`.
Args:
ax: matplotlib Axes or None if a new figure should be created.
fontsize: Legend fontsize.
color (str): plot color.
label (str): Plot label
text (str): Legend text (options)
Returns:
plt.Figure: matplotlib figure.
Expand Down Expand Up @@ -270,7 +290,7 @@ def plot_ax(self, ax: plt.Axes = None, fontsize=12, **kwargs):
class Murnaghan(EOSBase):
"""Murnaghan EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""From PRB 28,5480 (1983)."""
e0, b0, b1, v0 = tuple(params)
return e0 + b0 * volume / b1 * (((v0 / volume) ** b1) / (b1 - 1.0) + 1.0) - v0 * b0 / (b1 - 1.0)
Expand All @@ -279,7 +299,7 @@ def _func(self, volume, params):
class Birch(EOSBase):
"""Birch EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""From Intermetallic compounds: Principles and Practice, Vol. I:
Principles Chapter 9 pages 195-210 by M. Mehl. B. Klein,
D. Papaconstantopoulos.
Expand All @@ -296,7 +316,7 @@ def _func(self, volume, params):
class BirchMurnaghan(EOSBase):
"""BirchMurnaghan EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""BirchMurnaghan equation from PRB 70, 224107."""
e0, b0, b1, v0 = tuple(params)
eta = (v0 / volume) ** (1 / 3)
Expand All @@ -306,7 +326,7 @@ def _func(self, volume, params):
class PourierTarantola(EOSBase):
"""Pourier-Tarantola EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""Pourier-Tarantola equation from PRB 70, 224107."""
e0, b0, b1, v0 = tuple(params)
eta = (volume / v0) ** (1 / 3)
Expand All @@ -317,7 +337,7 @@ def _func(self, volume, params):
class Vinet(EOSBase):
"""Vinet EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""Vinet equation from PRB 70, 224107."""
e0, b0, b1, v0 = tuple(params)
eta = (volume / v0) ** (1 / 3)
Expand All @@ -335,7 +355,7 @@ class PolynomialEOS(EOSBase):
def _func(self, volume, params):
return np.poly1d(list(params))(volume)

def fit(self, order):
def fit(self, order: int) -> None:
"""
Do polynomial fitting and set the parameters. Uses numpy polyfit.
Expand All @@ -345,7 +365,7 @@ def fit(self, order):
self.eos_params = np.polyfit(self.volumes, self.energies, order)
self._set_params()

def _set_params(self):
def _set_params(self) -> None:
"""
Use the fit polynomial to compute the parameter e0, b0, b1 and v0
and set to the _params attribute.
Expand All @@ -372,7 +392,7 @@ def _func(self, volume, params):
x = volume ** (-2 / 3.0)
return np.poly1d(list(params))(x)

def fit(self, order=3):
def fit(self, order: int = 3) -> None:
"""Overridden since this eos works with volume**(2/3) instead of volume."""
x = self.volumes ** (-2 / 3.0)
self.eos_params = np.polyfit(x, self.energies, order)
Expand Down Expand Up @@ -407,7 +427,12 @@ def _set_params(self):
class NumericalEOS(PolynomialEOS):
"""A numerical EOS."""

def fit(self, min_ndata_factor=3, max_poly_order_factor=5, min_poly_order=2):
def fit(
self,
min_ndata_factor: int = 3,
max_poly_order_factor: int = 5,
min_poly_order: int = 2,
) -> None:
"""Fit the input data to the 'numerical eos', the equation of state employed
in the quasiharmonic Debye model described in the paper:
10.1103/PhysRevB.90.174107.
Expand Down Expand Up @@ -539,7 +564,7 @@ class EOS:
eos_fit.plot()
"""

MODELS: ClassVar = {
MODELS: ClassVar[dict[str, Any]] = {
"murnaghan": Murnaghan,
"birch": Birch,
"birch_murnaghan": BirchMurnaghan,
Expand All @@ -549,7 +574,7 @@ class EOS:
"numerical_eos": NumericalEOS,
}

def __init__(self, eos_name="murnaghan"):
def __init__(self, eos_name: str = "murnaghan") -> None:
"""
Args:
eos_name (str): Type of EOS to fit.
Expand All @@ -562,7 +587,7 @@ def __init__(self, eos_name="murnaghan"):
self._eos_name = eos_name
self.model = self.MODELS[eos_name]

def fit(self, volumes, energies):
def fit(self, volumes: Sequence[float], energies: Sequence[float]) -> EOSBase:
"""Fit energies as function of volumes.
Args:
Expand Down
39 changes: 24 additions & 15 deletions src/pymatgen/analysis/prototypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,27 @@ class AflowPrototypeMatcher:
https://doi.org/10.1016/j.commatsci.2017.01.017
"""

def __init__(self, initial_ltol=0.2, initial_stol=0.3, initial_angle_tol=5):
def __init__(
self,
initial_ltol: float = 0.2,
initial_stol: float = 0.3,
initial_angle_tol: float = 5,
) -> None:
"""
Tolerances as defined in StructureMatcher. Tolerances will be
gradually decreased until only a single match is found (if possible).
Args:
initial_ltol: fractional length tolerance
initial_stol: site tolerance
initial_angle_tol: angle tolerance
initial_ltol (float): fractional length tolerance.
initial_stol (float): site tolerance.
initial_angle_tol (float): angle tolerance.
"""
self.initial_ltol = initial_ltol
self.initial_stol = initial_stol
self.initial_angle_tol = initial_angle_tol

# Preprocess AFLOW prototypes
self._aflow_prototype_library = []
self._aflow_prototype_library: list[tuple[Structure, dict]] = []
for dct in AFLOW_PROTOTYPE_LIBRARY:
structure: Structure = dct["snl"].structure
reduced_structure = self._preprocess_structure(structure)
Expand All @@ -73,7 +78,11 @@ def __init__(self, initial_ltol=0.2, initial_stol=0.3, initial_angle_tol=5):
def _preprocess_structure(structure: Structure) -> Structure:
return structure.get_reduced_structure(reduction_algo="niggli").get_primitive_structure()

def _match_prototype(self, structure_matcher: StructureMatcher, reduced_structure: Structure):
def _match_prototype(
self,
structure_matcher: StructureMatcher,
reduced_structure: Structure,
) -> list[dict]:
tags = []
for aflow_reduced_structure, dct in self._aflow_prototype_library:
# Since both structures are already reduced, we can skip the structure reduction step
Expand All @@ -84,7 +93,7 @@ def _match_prototype(self, structure_matcher: StructureMatcher, reduced_structur
tags.append(dct)
return tags

def _match_single_prototype(self, structure: Structure):
def _match_single_prototype(self, structure: Structure) -> list[dict]:
sm = StructureMatcher(
ltol=self.initial_ltol,
stol=self.initial_stol,
Expand All @@ -102,23 +111,23 @@ def _match_single_prototype(self, structure: Structure):
break
return tags

def get_prototypes(self, structure: Structure) -> list | None:
def get_prototypes(self, structure: Structure) -> list[dict] | None:
"""Get prototype(s) structures for a given input structure. If you use this method in
your work, please cite the appropriate AFLOW publication:
Mehl, M. J., Hicks, D., Toher, C., Levy, O., Hanson, R. M., Hart, G., & Curtarolo,
S. (2017). The AFLOW library of crystallographic prototypes: part 1. Computational
Materials Science, 136, S1-S828. https://doi.org/10.1016/j.commatsci.2017.01.017
Mehl, M. J., Hicks, D., Toher, C., Levy, O., Hanson, R. M., Hart, G., & Curtarolo,
S. (2017). The AFLOW library of crystallographic prototypes: part 1. Computational
Materials Science, 136, S1-S828. https://doi.org/10.1016/j.commatsci.2017.01.017
Args:
structure: structure to match
structure (Structure): structure to match
Returns:
list | None: A list of dicts with keys 'snl' for the matched prototype and
'tags', a dict of tags ('mineral', 'strukturbericht' and 'aflow') of that
list[dict] | None: A list of dicts with keys "snl" for the matched prototype and
"tags", a dict of tags ("mineral", "strukturbericht" and "aflow") of that
prototype. This should be a list containing just a single entry, but it is
possible a material can match multiple prototypes.
"""
tags = self._match_single_prototype(structure)
tags: list[dict] = self._match_single_prototype(structure)

return tags or None
10 changes: 4 additions & 6 deletions src/pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4673,9 +4673,8 @@ def rotate_sites(
the structure in place.
Args:
indices (list): List of site indices on which to perform the
translation.
theta (float): Angle in radians
indices (list): Site indices on which to perform the rotation.
theta (float): Angle in radians.
axis (3x1 array): Rotation axis vector.
anchor (3x1 array): Point of rotation.
to_unit_cell (bool): Whether new sites are transformed to unit cell
Expand Down Expand Up @@ -5294,9 +5293,8 @@ def rotate_sites(
"""Rotate specific sites by some angle around vector at anchor.
Args:
indices (list): List of site indices on which to perform the
translation.
theta (float): Angle in radians
indices (list): Site indices on which to perform the rotation.
theta (float): Angle in radians.
axis (3x1 array): Rotation axis vector.
anchor (3x1 array): Point of rotation.
Expand Down
Loading

0 comments on commit 5aff054

Please sign in to comment.