Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

abacus: add checks on pp and orb in construction of STRU #737

Merged
merged 8 commits into from
Oct 15, 2024
90 changes: 44 additions & 46 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@
def make_unlabeled_stru(
data,
frame_idx,
pp_file=None,
pp_file,
numerical_orbital=None,
numerical_descriptor=None,
mass=None,
Expand All @@ -601,7 +601,7 @@
System data
frame_idx : int
The index of the frame to dump
pp_file : list of string or dict, optional
pp_file : list of string or dict
List of pseudo potential files, or a dictionary of pseudo potential files for each atomnames
numerical_orbital : list of string or dict, optional
List of orbital files, or a dictionary of orbital files for each atomnames
Expand All @@ -628,6 +628,8 @@
link_file : bool, optional
Whether to link the pseudo potential files and orbital files in the STRU file.
If True, then only filename will be written in the STRU file, and make a soft link to the real file.
dest_dir : str, optional
The destination directory to make the soft link of the pseudo potential files and orbital files.
For velocity, mag, angle1, angle2, sc, and lambda_, if the value is None, then the corresponding information will not be written.
ABACUS support defining "mag" and "angle1"/"angle2" at the same time, and in this case, the "mag" only define the norm of the magnetic moment, and "angle1" and "angle2" define the direction of the magnetic moment.
If data has spins, then it will be written as mag to STRU file; while if mag is passed at the same time, then mag will be used.
Expand Down Expand Up @@ -655,6 +657,23 @@
else:
return i

def process_file_input(file_input, atom_names, input_name):
# For pp_file and numerical_orbital, process the file input, and return a list of file names
# file_input can be a list of file names, or a dictionary of file names for each atom names
if isinstance(file_input, (list, tuple)):
if len(file_input) != len(atom_names):
raise ValueError(
f"{input_name} length is not equal to the number of atom types"
)
return file_input
elif isinstance(file_input, dict):
for element in atom_names:
if element not in file_input:
raise KeyError(f"{input_name} does not contain {element}")
return [file_input[element] for element in atom_names]
else:
raise ValueError(f"Invalid {input_name}: {file_input}")

Check warning on line 675 in dpdata/abacus/scf.py

View check run for this annotation

Codecov / codecov/patch

dpdata/abacus/scf.py#L675

Added line #L675 was not covered by tests

if link_file and dest_dir is None:
print(
"WARNING: make_unlabeled_stru: link_file is True, but dest_dir is None. Will write the filename to STRU but not making soft link."
Expand All @@ -680,8 +699,8 @@

# ATOMIC_SPECIES block
out = "ATOMIC_SPECIES\n"
if pp_file is not None:
pp_file = ndarray2list(pp_file)
ppfiles = process_file_input(ndarray2list(pp_file), data["atom_names"], "pp_file")

for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
Expand All @@ -690,57 +709,36 @@
out += f"{mass[iele]:.3f} "
else:
out += "1 "
if pp_file is not None:
if isinstance(pp_file, (list, tuple)):
ipp_file = pp_file[iele]
elif isinstance(pp_file, dict):
if data["atom_names"][iele] not in pp_file:
print(
f"ERROR: make_unlabeled_stru: pp_file does not contain {data['atom_names'][iele]}"
)
ipp_file = None
else:
ipp_file = pp_file[data["atom_names"][iele]]
else:
ipp_file = None
if ipp_file is not None:
if not link_file:
out += ipp_file
else:
out += os.path.basename(ipp_file.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, ipp_file)

ipp_file = ppfiles[iele]
if not link_file:
out += ipp_file
else:
out += os.path.basename(ipp_file.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, ipp_file)
out += "\n"
out += "\n"

# NUMERICAL_ORBITAL block
if numerical_orbital is not None:
assert len(numerical_orbital) == len(data["atom_names"])
numerical_orbital = ndarray2list(numerical_orbital)
orbfiles = process_file_input(
numerical_orbital, data["atom_names"], "numerical_orbital"
)
orbfiles = [
orbfiles[i]
for i in range(len(data["atom_names"]))
if data["atom_numbs"][i] != 0
]
out += "NUMERICAL_ORBITAL\n"
for iele in range(len(data["atom_names"])):
if data["atom_numbs"][iele] == 0:
continue
if isinstance(numerical_orbital, (list, tuple)):
inum_orbital = numerical_orbital[iele]
elif isinstance(numerical_orbital, dict):
if data["atom_names"][iele] not in numerical_orbital:
print(
f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {data['atom_names'][iele]}"
)
inum_orbital = None
else:
inum_orbital = numerical_orbital[data["atom_names"][iele]]
for iorb in orbfiles:
if not link_file:
out += iorb
else:
inum_orbital = None
if inum_orbital is not None:
if not link_file:
out += inum_orbital
else:
out += os.path.basename(inum_orbital.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, inum_orbital)
out += os.path.basename(iorb.rstrip("/"))
if dest_dir is not None:
_link_file(dest_dir, iorb)
out += "\n"
out += "\n"

Expand Down
43 changes: 43 additions & 0 deletions tests/test_abacus_stru_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,49 @@ def test_dumpStruLinkFile(self):
if os.path.isdir("abacus.scf/tmp"):
shutil.rmtree("abacus.scf/tmp")

def test_dump_stru_pporb_mismatch(self):
with self.assertRaises(KeyError, msg="pp_file is a dict and lack of pp for H"):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file={"C": "C.upf", "O": "O.upf"},
numerical_orbital={"C": "C.orb", "H": "H.orb"},
)

with self.assertRaises(
ValueError, msg="pp_file is a list and lack of pp for H"
):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file=["C.upf"],
numerical_orbital={"C": "C.orb", "H": "H.orb"},
)

with self.assertRaises(
KeyError, msg="numerical_orbital is a dict and lack of orbital for H"
):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file={"C": "C.upf", "H": "H.upf"},
numerical_orbital={"C": "C.orb", "O": "O.orb"},
)

with self.assertRaises(
ValueError, msg="numerical_orbital is a list and lack of orbital for H"
):
self.system_ch4.to(
"stru",
"STRU_tmp",
mass=[12, 1],
pp_file=["C.upf", "H.upf"],
numerical_orbital=["C.orb"],
)

def test_dump_spinconstrain(self):
self.system_ch4.to(
"stru",
Expand Down