Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

abacus: add checks on pp and orb in construction of STRU #737

Merged
merged 8 commits into from
Oct 15, 2024
92 changes: 56 additions & 36 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,8 @@
link_file : bool, optional
Whether to link the pseudo potential files and orbital files in the STRU file.
If True, then only filename will be written in the STRU file, and make a soft link to the real file.
dest_dir : str, optional
The destination directory to make the soft link of the pseudo potential files and orbital files.
For velocity, mag, angle1, angle2, sc, and lambda_, if the value is None, then the corresponding information will not be written.
ABACUS support defining "mag" and "angle1"/"angle2" at the same time, and in this case, the "mag" only define the norm of the magnetic moment, and "angle1" and "angle2" define the direction of the magnetic moment.
If data has spins, then it will be written as mag to STRU file; while if mag is passed at the same time, then mag will be used.
Expand Down Expand Up @@ -682,6 +684,27 @@
out = "ATOMIC_SPECIES\n"
if pp_file is not None:
pp_file = ndarray2list(pp_file)
ppfiles = None
if isinstance(pp_file, (list, tuple)):
if len(pp_file) != len(data["atom_names"]):
raise RuntimeError(
"ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types"
)
ppfiles = pp_file
elif isinstance(pp_file, dict):
for iele in data["atom_names"]:
if iele not in pp_file:
raise RuntimeError(
f"ERROR: make_unlabeled_stru: pp_file does not contain {iele}"
)
ppfiles = [
pp_file[data["atom_names"][i]] for i in range(len(data["atom_names"]))
]
else:
raise RuntimeError(f"ERROR: invalid pp_file: {pp_file}")

Check warning on line 704 in dpdata/abacus/scf.py

View check run for this annotation

Codecov / codecov/patch

dpdata/abacus/scf.py#L704

Added line #L704 was not covered by tests
else:
ppfiles = None

Check warning on line 706 in dpdata/abacus/scf.py

View check run for this annotation

Codecov / codecov/patch

dpdata/abacus/scf.py#L706

Added line #L706 was not covered by tests

for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
Expand All @@ -690,57 +713,54 @@
out += f"{mass[iele]:.3f} "
else:
out += "1 "
if pp_file is not None:
if isinstance(pp_file, (list, tuple)):
ipp_file = pp_file[iele]
elif isinstance(pp_file, dict):
if data["atom_names"][iele] not in pp_file:
print(
f"ERROR: make_unlabeled_stru: pp_file does not contain {data['atom_names'][iele]}"
)
ipp_file = None
else:
ipp_file = pp_file[data["atom_names"][iele]]
else:
ipp_file = None
if ppfiles is not None:
ipp_file = ppfiles[iele]
if ipp_file is not None:
if not link_file:
out += ipp_file
else:
out += os.path.basename(ipp_file.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, ipp_file)

out += "\n"
out += "\n"

# NUMERICAL_ORBITAL block
if numerical_orbital is not None:
assert len(numerical_orbital) == len(data["atom_names"])
numerical_orbital = ndarray2list(numerical_orbital)
out += "NUMERICAL_ORBITAL\n"
for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
if isinstance(numerical_orbital, (list, tuple)):
inum_orbital = numerical_orbital[iele]
elif isinstance(numerical_orbital, dict):
if data["atom_names"][iele] not in numerical_orbital:
print(
f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {data['atom_names'][iele]}"
orbfiles = []
if isinstance(numerical_orbital, (list, tuple)):
if len(numerical_orbital) != len(data["atom_names"]):
raise RuntimeError(
"ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types"
)
orbfiles = [
numerical_orbital[i]
for i in range(len(data["atom_names"]))
if data["atom_numbs"][i] != 0
]
elif isinstance(numerical_orbital, dict):
for iele in data["atom_names"]:
if iele not in numerical_orbital:
raise RuntimeError(
f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {iele}"
)
inum_orbital = None
else:
inum_orbital = numerical_orbital[data["atom_names"][iele]]
orbfiles = [
numerical_orbital[data["atom_names"][i]]
for i in range(len(data["atom_names"]))
if data["atom_numbs"][i] != 0
]
else:
raise RuntimeError(f"ERROR: invalid numerical_orbital: {numerical_orbital}")

Check warning on line 754 in dpdata/abacus/scf.py

View check run for this annotation

Codecov / codecov/patch

dpdata/abacus/scf.py#L754

Added line #L754 was not covered by tests

out += "NUMERICAL_ORBITAL\n"
for iorb in orbfiles:
if not link_file:
out += iorb
else:
inum_orbital = None
if inum_orbital is not None:
if not link_file:
out += inum_orbital
else:
out += os.path.basename(inum_orbital.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, inum_orbital)
out += os.path.basename(iorb.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, iorb)
out += "\n"
out += "\n"

Expand Down
53 changes: 53 additions & 0 deletions tests/test_abacus_stru_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,59 @@ def test_dumpStruLinkFile(self):
if os.path.isdir("abacus.scf/tmp"):
shutil.rmtree("abacus.scf/tmp")

def test_dump_stru_pporb_mismatch(self):
(
self.assertRaises(
RuntimeError,
self.system_ch4.to,
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file={"C": "C.upf", "O": "O.upf"},
numerical_orbital={"C": "C.orb", "H": "H.orb"},
),
"pp_file is a dict and lack of pp for H",
)

(
self.assertRaises(
RuntimeError,
self.system_ch4.to,
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file=["C.upf"],
numerical_orbital={"C": "C.orb", "H": "H.orb"},
),
"pp_file is a list and lack of pp for H",
)

(
self.assertRaises(
RuntimeError,
self.system_ch4.to,
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file={"C": "C.upf", "H": "H.upf"},
numerical_orbital={"C": "C.orb", "O": "O.orb"},
),
"numerical_orbital is a dict and lack of orbital for H",
)

(
self.assertRaises(
RuntimeError,
self.system_ch4.to,
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file=["C.upf", "H.upf"],
numerical_orbital=["C.orb"],
),
"numerical_orbital is a list and lack of orbital for H",
)

def test_dump_spinconstrain(self):
self.system_ch4.to(
"stru",
Expand Down