Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from attr to attrs #318

Merged
merged 4 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions iodata/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from numbers import Integral
from typing import Union

import attr
import attrs
import numpy as np
from numpy.typing import NDArray

from .attrutils import validate_shape

Expand Down Expand Up @@ -100,7 +101,7 @@ def angmom_its(angmom: Union[int, list[int]]) -> Union[str, list[str]]:
return ANGMOM_CHARS[angmom]


@attr.s(auto_attribs=True, slots=True, on_setattr=[attr.setters.validate, attr.setters.convert])
@attrs.define
class Shell:
"""A shell in a molecular basis representing (generalized) contractions with the same exponents.

Expand All @@ -126,11 +127,11 @@ class Shell:

"""

icenter: int
angmoms: list[int] = attr.ib(validator=validate_shape(("coeffs", 1)))
kinds: list[str] = attr.ib(validator=validate_shape(("coeffs", 1)))
exponents: np.ndarray = attr.ib(validator=validate_shape(("coeffs", 0)))
coeffs: np.ndarray = attr.ib(validator=validate_shape(("exponents", 0), ("kinds", 0)))
icenter: int = attrs.field()
angmoms: list[int] = attrs.field(validator=validate_shape(("coeffs", 1)))
kinds: list[str] = attrs.field(validator=validate_shape(("coeffs", 1)))
exponents: NDArray = attrs.field(validator=validate_shape(("coeffs", 0)))
coeffs: NDArray = attrs.field(validator=validate_shape(("exponents", 0), ("kinds", 0)))

@property
def nbasis(self) -> int:
Expand All @@ -156,7 +157,7 @@ def ncon(self) -> int:
return len(self.angmoms)


@attr.s(auto_attribs=True, slots=True, on_setattr=[attr.setters.validate, attr.setters.convert])
@attrs.define
class MolecularBasis:
"""A complete molecular orbital or density basis set.

Expand Down Expand Up @@ -205,9 +206,9 @@ class MolecularBasis:

"""

shells: list[Shell]
conventions: dict[str, str]
primitive_normalization: str
shells: list[Shell] = attrs.field()
conventions: dict[str, str] = attrs.field()
primitive_normalization: str = attrs.field()

@property
def nbasis(self) -> int:
Expand All @@ -222,12 +223,12 @@ def get_segmented(self):
shells.append(
Shell(shell.icenter, [angmom], [kind], shell.exponents, coeffs.reshape(-1, 1))
)
return attr.evolve(self, shells=shells)
return attrs.evolve(self, shells=shells)


def convert_convention_shell(
conv1: list[str], conv2: list[str], reverse=False
) -> tuple[np.ndarray, np.ndarray]:
) -> tuple[NDArray, NDArray]:
"""Return a permutation vector and sign changes to convert from 1 to 2.

The transformation from convention 1 to convention 2 can be done applying
Expand Down Expand Up @@ -289,7 +290,7 @@ def convert_convention_shell(

def convert_conventions(
molbasis: MolecularBasis, new_conventions: dict[str, list[str]], reverse=False
) -> tuple[np.ndarray, np.ndarray]:
) -> tuple[NDArray, NDArray]:
"""Return a permutation vector and sign changes to convert from 1 to 2.

The transformation from molbasis.convention to the new convention can be done
Expand Down Expand Up @@ -339,7 +340,7 @@ def convert_conventions(
return np.array(permutation), np.array(signs)


def iter_cart_alphabet(n: int) -> np.ndarray:
def iter_cart_alphabet(n: int) -> NDArray:
"""Loop over powers of Cartesian basis functions in alphabetical order.

See https://theochem.github.io/horton/2.1.1/tech_ref_gaussian_basis.html
Expand Down
3 changes: 2 additions & 1 deletion iodata/formats/chgcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..periodic import sym2num
Expand All @@ -37,7 +38,7 @@
PATTERNS = ["CHGCAR*", "AECCAR*"]


def _load_vasp_header(lit: LineIterator) -> tuple[str, np.ndarray, np.ndarray, np.ndarray]:
def _load_vasp_header(lit: LineIterator) -> tuple[str, NDArray, NDArray, NDArray]:
"""Load the cell and atoms from a VASP file format.

Parameters
Expand Down
15 changes: 7 additions & 8 deletions iodata/formats/cp2klog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Union

import numpy as np
from numpy.typing import NDArray
from scipy.special import factorialk

from ..basis import HORTON2_CONVENTIONS, MolecularBasis, Shell, angmom_sti
Expand All @@ -42,9 +43,7 @@
}


def _get_cp2k_norm_corrections(
ell: int, alphas: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
def _get_cp2k_norm_corrections(ell: int, alphas: Union[float, NDArray]) -> Union[float, NDArray]:
"""Compute the corrections for the normalization of the basis functions.

This correction is needed because the CP2K atom code works with a different
Expand Down Expand Up @@ -236,7 +235,7 @@ def _read_cp2k_occupations_energies(

def _read_cp2k_orbital_coeffs(
lit: LineIterator, oe: list[tuple[int, int, float, float]]
) -> dict[tuple[int, int], np.ndarray]:
) -> dict[tuple[int, int], NDArray]:
"""Read the expansion coefficients of the orbital from an open CP2K ATOM output.

Parameters
Expand Down Expand Up @@ -294,11 +293,11 @@ def _get_norb_nel(oe: list[tuple[int, int, float, float]]) -> tuple[int, float]:


def _fill_orbitals(
orb_coeffs: np.ndarray,
orb_energies: np.ndarray,
orb_occupations: np.ndarray,
orb_coeffs: NDArray,
orb_energies: NDArray,
orb_occupations: NDArray,
oe: list[tuple[int, int, float, float]],
coeffs: dict[tuple[int, int], np.ndarray],
coeffs: dict[tuple[int, int], NDArray],
obasis: MolecularBasis,
restricted: bool,
):
Expand Down
19 changes: 10 additions & 9 deletions iodata/formats/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import TextIO

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_dump_one, document_load_one
from ..iodata import IOData
Expand All @@ -42,7 +43,7 @@

def _read_cube_header(
lit: LineIterator,
) -> tuple[str, np.ndarray, np.ndarray, np.ndarray, dict[str, np.ndarray], np.ndarray]:
) -> tuple[str, NDArray, NDArray, NDArray, dict[str, NDArray], NDArray]:
"""Load header data from a CUBE file object.

Parameters
Expand All @@ -62,7 +63,7 @@ def _read_cube_header(
# skip the second line
next(lit)

def read_grid_line(line: str) -> tuple[int, np.ndarray]:
def read_grid_line(line: str) -> tuple[int, NDArray]:
"""Read a grid line from the cube file."""
words = line.split()
return (
Expand All @@ -83,7 +84,7 @@ def read_grid_line(line: str) -> tuple[int, np.ndarray]:
cellvecs = axes * shape.reshape(-1, 1)
cube = {"origin": origin, "axes": axes, "shape": shape}

def read_atom_line(line: str) -> tuple[int, float, np.ndarray]:
def read_atom_line(line: str) -> tuple[int, float, NDArray]:
"""Read an atomic number and coordinate from the cube file."""
words = line.split()
return (
Expand All @@ -106,7 +107,7 @@ def read_atom_line(line: str) -> tuple[int, float, np.ndarray]:
return title, atcoords, atnums, cellvecs, cube, atcorenums


def _read_cube_data(lit: LineIterator, cube: dict[str, np.ndarray]):
def _read_cube_data(lit: LineIterator, cube: dict[str, NDArray]):
"""Load cube data from a CUBE file object.

Parameters
Expand Down Expand Up @@ -150,10 +151,10 @@ def load_one(lit: LineIterator) -> dict:
def _write_cube_header(
f: TextIO,
title: str,
atcoords: np.ndarray,
atnums: np.ndarray,
cube: dict[str, np.ndarray],
atcorenums: np.ndarray,
atcoords: NDArray,
atnums: NDArray,
cube: dict[str, NDArray],
atcorenums: NDArray,
):
print(title, file=f)
print("OUTER LOOP: X, MIDDLE LOOP: Y, INNER LOOP: Z", file=f)
Expand All @@ -169,7 +170,7 @@ def _write_cube_header(
print(f"{atnums[i]:5d} {q: 11.6f} {x: 11.6f} {y: 11.6f} {z: 11.6f}", file=f)


def _write_cube_data(f: TextIO, cube_data: np.ndarray, block_size: int):
def _write_cube_data(f: TextIO, cube_data: NDArray, block_size: int):
counter = 0
for value in cube_data.flat:
f.write(f" {value: 12.5E}")
Expand Down
7 changes: 4 additions & 3 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Optional, TextIO

import numpy as np
from numpy.typing import NDArray

from ..basis import HORTON2_CONVENTIONS, MolecularBasis, Shell, convert_conventions
from ..docstrings import document_dump_one, document_load_many, document_load_one
Expand Down Expand Up @@ -473,7 +474,7 @@ def _load_dm(label: str, fchk: dict, result: dict, key: str):
result[key] = _triangle_to_dense(fchk[label])


def _triangle_to_dense(triangle: np.ndarray) -> np.ndarray:
def _triangle_to_dense(triangle: NDArray) -> NDArray:
"""Convert a symmetric matrix in triangular storage to a dense square matrix.

Parameters
Expand Down Expand Up @@ -512,7 +513,7 @@ def _dump_real_scalars(name: str, val: float, f: TextIO):
print(f"{name:40} R {float(val): 16.8E}", file=f)


def _dump_integer_arrays(name: str, val: np.ndarray, f: TextIO):
def _dump_integer_arrays(name: str, val: NDArray, f: TextIO):
"""Dumper for a array of integers."""
nval = val.size
if nval != 0:
Expand All @@ -527,7 +528,7 @@ def _dump_integer_arrays(name: str, val: np.ndarray, f: TextIO):
k = 0


def _dump_real_arrays(name: str, val: np.ndarray, f: TextIO):
def _dump_real_arrays(name: str, val: NDArray, f: TextIO):
"""Dumper for a array of float."""
nval = val.size
if nval != 0:
Expand Down
13 changes: 7 additions & 6 deletions iodata/formats/gamess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""GAMESS punch file format."""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..utils import LineIterator, angstrom
Expand All @@ -29,7 +30,7 @@
PATTERNS = ["*.dat"]


def _read_data(lit: LineIterator) -> tuple:
def _read_data(lit: LineIterator) -> tuple[str, str, list[str]]:
"""Extract ``title``, ``symmetry`` and ``symbols`` from the punch file."""
title = next(lit).strip()
symmetry = next(lit).split()[0]
Expand All @@ -46,7 +47,7 @@ def _read_data(lit: LineIterator) -> tuple:
return title, symmetry, symbols


def _read_coordinates(lit: LineIterator, result: dict) -> tuple:
def _read_coordinates(lit: LineIterator, result: dict[str]) -> tuple[NDArray, NDArray]:
"""Extract ``numbers`` and ``coordinates`` from the punch file."""
for _ in range(2):
next(lit)
Expand All @@ -67,7 +68,7 @@ def _read_coordinates(lit: LineIterator, result: dict) -> tuple:
return numbers, coordinates


def _read_energy(lit: LineIterator, result: dict) -> tuple:
def _read_energy(lit: LineIterator, result: dict[str]) -> tuple[float, NDArray]:
"""Extract ``energy`` and ``gradient`` from the punch file."""
energy = float(next(lit).split()[1])
natom = len(result["symbols"])
Expand All @@ -81,7 +82,7 @@ def _read_energy(lit: LineIterator, result: dict) -> tuple:
return energy, gradient


def _read_hessian(lit: LineIterator, result: dict) -> np.ndarray:
def _read_hessian(lit: LineIterator, result: dict[str]) -> NDArray:
"""Extract ``hessian`` from the punch file."""
# check that $HESS is not already parsed
if "athessian" in result:
Expand All @@ -102,7 +103,7 @@ def _read_hessian(lit: LineIterator, result: dict) -> np.ndarray:
return hessian


def _read_masses(lit: LineIterator, result: dict) -> np.ndarray:
def _read_masses(lit: LineIterator, result: dict[str]) -> NDArray:
"""Extract ``masses`` from the punch file."""
natom = len(result["symbols"])
masses = np.zeros(natom, float)
Expand All @@ -119,7 +120,7 @@ def _read_masses(lit: LineIterator, result: dict) -> np.ndarray:
"PUNCH",
["title", "energy", "grot", "atgradient", "athessian", "atmasses", "atnums", "atcoords"],
)
def load_one(lit: LineIterator) -> dict:
def load_one(lit: LineIterator) -> dict[str]:
"""Do not edit this docstring. It will be overwritten."""
result = {}
while True:
Expand Down
5 changes: 3 additions & 2 deletions iodata/formats/gaussianlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..utils import LineIterator, set_four_index_element
Expand Down Expand Up @@ -73,7 +74,7 @@ def load_one(lit: LineIterator) -> dict:
return result


def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> NDArray:
"""Load a two-index operator from a GAUSSIAN LOG file format.

Parameters
Expand Down Expand Up @@ -106,7 +107,7 @@ def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
return result


def _load_fourindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
def _load_fourindex_g09(lit: LineIterator, nbasis: int) -> NDArray:
"""Load a four-index operator from a GAUSSIAN LOG file.

Parameters
Expand Down
7 changes: 3 additions & 4 deletions iodata/formats/mol2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TextIO

import numpy as np
from numpy.typing import NDArray

from ..docstrings import (
document_dump_many,
Expand Down Expand Up @@ -83,9 +84,7 @@ def load_one(lit: LineIterator) -> dict:
return result


def _load_helper_atoms(
lit: LineIterator, natoms: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple]:
def _load_helper_atoms(lit: LineIterator, natoms: int) -> tuple[NDArray, NDArray, NDArray, tuple]:
"""Load element numbers, coordinates and atomic charges."""
atnums = np.empty(natoms)
atcoords = np.empty((natoms, 3))
Expand All @@ -112,7 +111,7 @@ def _load_helper_atoms(
return atnums, atcoords, atchgs, attypes


def _load_helper_bonds(lit: LineIterator, nbonds: int) -> tuple[np.ndarray]:
def _load_helper_bonds(lit: LineIterator, nbonds: int) -> NDArray:
"""Load bond information.

Each line in a bond definition has the following structure
Expand Down
Loading