Skip to content

Commit 1f1ccbf

Browse files
author
Keian Noori
committed
Merge branch 'master' into update_vasp_outputs
2 parents b8819b7 + 55f70b2 commit 1f1ccbf

13 files changed

+1181
-363
lines changed

src/pymatgen/core/periodic_table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def __eq__(self, other: object) -> bool:
236236

237237
def __hash__(self) -> int:
238238
# multiply Z by 1000 to avoid hash collisions of element N with isotopes of elements N+/-1,2,3...
239-
return self.Z * 1000 + self.A if self._is_named_isotope else self.Z
239+
return self.Z * 1000 + self.A if self._is_named_isotope else self.Z * 137 * 100
240240

241241
def __repr__(self) -> str:
242242
return f"Element {self.symbol}"
@@ -1605,7 +1605,7 @@ def get_el_sp(obj: SpeciesLike) -> Element | Species | DummySpecies:
16051605
pass
16061606

16071607

1608-
@functools.lru_cache
1608+
@functools.lru_cache(maxsize=1024)
16091609
def get_el_sp(obj: int | SpeciesLike) -> Element | Species | DummySpecies:
16101610
"""Utility function to get an Element, Species or DummySpecies from any input.
16111611

src/pymatgen/io/jdftx/_output_utils.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import numpy as np
1717

18+
from pymatgen.electronic_structure.core import Orbital
19+
1820
if TYPE_CHECKING:
1921
from collections.abc import Callable
2022

@@ -476,6 +478,34 @@ def get_proj_tju_from_file(bandfile_filepath: Path | str) -> NDArray[np.float32
476478
return _parse_bandfile_complex(bandfile_filepath) if is_complex else _parse_bandfile_normalized(bandfile_filepath)
477479

478480

481+
def _parse_kptsfrom_bandprojections_file(bandfile_filepath: str | Path) -> tuple[list[float], list[NDArray]]:
482+
"""Parse kpts from bandprojections file.
483+
484+
Parse kpts from bandprojections file.
485+
486+
Args:
487+
bandfile_filepath (Path | str): Path to bandprojections file.
488+
489+
Returns:
490+
tuple[list[float], list[np.ndarray[float]]]: Tuple of k-point weights and k-points
491+
"""
492+
wk_list: list[float] = []
493+
k_points_list: list[NDArray] = []
494+
kpt_lines = []
495+
with open(bandfile_filepath) as f:
496+
for line in f:
497+
if line.startswith("#") and ";" in line:
498+
_line = line.split(";")[0].lstrip("#")
499+
kpt_lines.append(_line)
500+
for line in kpt_lines:
501+
k_points = line.split("[")[1].split("]")[0].strip().split()
502+
_k_points_floats: list[float] = [float(v) for v in k_points]
503+
k_points_list.append(np.array(_k_points_floats))
504+
wk = float(line.split("]")[1].strip().split()[0])
505+
wk_list.append(wk)
506+
return wk_list, k_points_list
507+
508+
479509
def _is_complex_bandfile_filepath(bandfile_filepath: str | Path) -> bool:
480510
"""Determine if bandprojections file is complex.
481511
@@ -507,6 +537,64 @@ def _is_complex_bandfile_filepath(bandfile_filepath: str | Path) -> bool:
507537
["dxy", "dyz", "dz2", "dxz", "dx2-y2"],
508538
["fy(3x2-y2)", "fxyz", "fyz2", "fz3", "fxz2", "fz(x2-y2)", "fx(x2-3y2)"],
509539
]
540+
# Maps each orbital label used in orb_ref_list to the integer value of the corresponding pymatgen
# Orbital member. Every label in orb_ref_list needs an entry here, because _get_u_to_oa_map looks
# labels up by key and a missing label raises KeyError.
orb_ref_to_o_dict = {
    "s": int(Orbital.s),
    "py": int(Orbital.py),
    "pz": int(Orbital.pz),
    "px": int(Orbital.px),
    "dxy": int(Orbital.dxy),
    "dyz": int(Orbital.dyz),
    "dz2": int(Orbital.dz2),
    "dxz": int(Orbital.dxz),
    "dx2-y2": int(Orbital.dx2),
    # Keep the f-orbitals arbitrary-ish until they get designated names in pymatgen.
    orb_ref_list[-1][0]: int(Orbital.f_3),
    orb_ref_list[-1][1]: int(Orbital.f_2),
    orb_ref_list[-1][2]: int(Orbital.f_1),
    orb_ref_list[-1][3]: int(Orbital.f0),
    orb_ref_list[-1][4]: int(Orbital.f1),
    orb_ref_list[-1][5]: int(Orbital.f2),
    # The f reference list has seven labels; the last one was previously left unmapped.
    orb_ref_list[-1][6]: int(Orbital.f3),
}
558+
559+
560+
def _get_atom_orb_labels_map_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
    """
    Return a dictionary mapping each atom symbol to pymatgen-compatible orbital projection string representations.

    Identical to _get_atom_orb_labels_ref_dict, but doesn't include the numbers in the labels.

    The first two lines of the file are skipped. Each following line is read as one species entry
    until a line containing "#" is reached. Within an entry, field 0 is the symbol, field 3 is the
    maximum angular momentum, and fields 4 onwards give the number of shells per angular momentum.

    Args:
        bandfile_filepath (str | Path): The path to the bandfile.

    Returns:
        dict[str, list[str]]: A dictionary mapping each atom symbol to all atomic orbital projection string
            representations.
    """
    labels_dict: dict[str, list[str]] = {}
    for line_idx, line in enumerate(read_file(bandfile_filepath)):
        if line_idx <= 1:
            continue
        if "#" in line:
            break
        fields = line.strip().split()
        sym = fields[0]
        orb_labels: list[str] = []
        labels_dict[sym] = orb_labels
        lmax = int(fields[3])
        for ang_mom in range(lmax + 1):
            mls = orb_ref_list[ang_mom]
            shell_count = int(fields[4 + ang_mom])
            # Every shell contributes the same un-numbered set of labels for this angular momentum.
            for _ in range(shell_count):
                orb_labels.extend(mls)
    return labels_dict
510598

511599

512600
def _get_atom_orb_labels_ref_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
@@ -623,6 +711,28 @@ def _get_orb_label(ion: str, idx: int, orb: str) -> str:
623711
return f"{ion}#{idx + 1}({orb})"
624712

625713

714+
def _get_u_to_oa_map(bandfile_filepath: Path) -> list[tuple[int, int]]:
    """
    Return a list, where the u'th element is a tuple of the atomic orbital index and the ion index.

    The ion index counts individual ions across all species, in the order given by
    _get_atom_count_list, starting from zero.

    Args:
        bandfile_filepath (str | Path): The path to the bandfile.

    Returns:
        list[tuple[int, int]]: A list, where the u'th element is a tuple of the atomic orbital index and the ion index.
    """
    labels_by_species = _get_atom_orb_labels_map_dict(bandfile_filepath)
    species_counts = _get_atom_count_list(bandfile_filepath)
    u_to_oa_map = []
    ion_idx = 0
    for species, n_ions in species_counts:
        for _ in range(n_ions):
            u_to_oa_map.extend((orb_ref_to_o_dict[label], ion_idx) for label in labels_by_species[species])
            ion_idx += 1
    return u_to_oa_map
734+
735+
626736
def _get_orb_label_list(bandfile_filepath: Path) -> tuple[str, ...]:
627737
"""
628738
Return a tuple of all atomic orbital projection string representations.

src/pymatgen/io/jdftx/generic_tags.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,56 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
7070
tuple[str, bool, Any]: The tag, whether the value is of the correct type, and the possibly fixed value.
7171
"""
7272

73+
def is_equal_to(self, val1: Any | list[Any], obj2: AbstractTag, val2: Any | list[Any]) -> bool:
    """Check if the two values are equal.

    For tags that cannot repeat, this defers directly to ``_is_equal_to``. For repeatable tags, the
    values are compared as unordered collections: each element of ``val1`` must pair with a distinct
    element of ``val2``. Pairing is one-to-one, so ``[1, 1]`` is not equal to ``[1, 2]``.

    Args:
        val1 (Any): The value of this tag object.
        obj2 (AbstractTag): The other tag object.
        val2 (Any): The value of the other tag object.

    Returns:
        bool: True if the two tag object/value pairs are equal, False otherwise.
    """
    if not self.can_repeat:
        return self._is_equal_to(val1, obj2, val2)
    if not obj2.can_repeat:
        return False
    vals1 = val1 if isinstance(val1, list) else [val1]
    # Work on a copy so matched elements can be removed without mutating the caller's list.
    unmatched = list(val2) if isinstance(val2, list) else [val2]
    if len(vals1) != len(unmatched):
        return False
    for v1 in vals1:
        match_idx = next((i for i, v2 in enumerate(unmatched) if self._is_equal_to(v1, obj2, v2)), None)
        if match_idx is None:
            return False
        # Consume the partner so it cannot be matched by a second element of val1.
        del unmatched[match_idx]
    return True
93+
94+
@abstractmethod
95+
def _is_equal_to(self, val1: Any, obj2: AbstractTag, val2: Any) -> bool:
96+
"""Check if the two values are equal.
97+
98+
Used to check if the two values are equal. Assumes val1 and val2 are single elements.
99+
100+
Args:
101+
val1 (Any): The value of this tag object.
102+
obj2 (AbstractTag): The other tag object.
103+
val2 (Any): The value of the other tag object.
104+
105+
Returns:
106+
bool: True if the two tag object/value pairs are equal, False otherwise.
107+
"""
108+
109+
def _is_same_tagtype(
110+
self,
111+
obj2: AbstractTag,
112+
) -> bool:
113+
"""Check if the two values are equal.
114+
115+
Args:
116+
obj2 (AbstractTag): The other tag object.
117+
118+
Returns:
119+
bool: True if the two tag object/value pairs are equal, False otherwise.
120+
"""
121+
return isinstance(self, type(obj2))
122+
73123
def _validate_value_type(
74124
self, type_check: type, tag: str, value: Any, try_auto_type_fix: bool = False
75125
) -> tuple[str, bool, Any]:
@@ -258,6 +308,19 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
258308
"""
259309
return self._validate_value_type(bool, tag, value, try_auto_type_fix=try_auto_type_fix)
260310

311+
def _is_equal_to(self, val1: Any, obj2: AbstractTag, val2: Any) -> bool:
312+
"""Check if the two values are equal.
313+
314+
Args:
315+
val1 (Any): The value of this tag object.
316+
obj2 (AbstractTag): The other tag object.
317+
val2 (Any): The value of the other tag object.
318+
319+
Returns:
320+
bool: True if the two tag object/value pairs are equal, False otherwise.
321+
"""
322+
return self._is_same_tagtype(obj2) and val1 == val2
323+
261324
def raise_value_error(self, tag: str, value: str) -> None:
262325
"""Raise a ValueError for the value string.
263326
@@ -335,6 +398,23 @@ def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = Fa
335398
"""
336399
return self._validate_value_type(str, tag, value, try_auto_type_fix=try_auto_type_fix)
337400

401+
def _is_equal_to(self, val1: Any, obj2: AbstractTag, val2: Any) -> bool:
402+
"""Check if the two values are equal.
403+
404+
Args:
405+
val1 (Any): The value of this tag object.
406+
obj2 (AbstractTag): The other tag object.
407+
val2 (Any): The value of the other tag object.
408+
409+
Returns:
410+
bool: True if the two tag object/value pairs are equal, False otherwise.
411+
"""
412+
if self._is_same_tagtype(obj2):
413+
if not all(isinstance(x, str) for x in (val1, val2)):
414+
raise ValueError("Both values must be strings for StrTag comparison")
415+
return val1.strip() == val2.strip()
416+
return False
417+
338418
def read(self, tag: str, value: str) -> str:
339419
"""Read the value string for this tag.
340420
@@ -379,6 +459,8 @@ class AbstractNumericTag(AbstractTag):
379459
ub: float | None = None # upper bound
380460
lb_incl: bool = True # lower bound inclusive
381461
ub_incl: bool = True # upper bound inclusive
462+
eq_atol: float = 1.0e-8 # absolute tolerance for equality check
463+
eq_rtol: float = 1.0e-5 # relative tolerance for equality check
382464

383465
def val_is_within_bounds(self, value: float) -> bool:
384466
"""Check if the value is within the bounds.
@@ -425,6 +507,22 @@ def validate_value_bounds(
425507
return False, self.get_invalid_value_error_str(tag, value)
426508
return True, ""
427509

510+
def _is_equal_to(self, val1, obj2, val2):
511+
"""Check if the two values are equal.
512+
513+
Used to check if the two values are equal. Doesn't need to be redefined for IntTag and FloatTag.
514+
515+
Args:
516+
val1 (Any): The value of this tag object.
517+
obj2 (AbstractTag): The other tag object.
518+
val2 (Any): The value of the other tag object.
519+
rtol (float, optional): Relative tolerance. Defaults to 1.e-5.
520+
atol (float, optional): Absolute tolerance. Defaults to 1.e-8.
521+
Returns:
522+
bool: True if the two tag object/value pairs are equal, False otherwise.
523+
"""
524+
return self._is_same_tagtype(obj2) and np.isclose(val1, val2, rtol=self.eq_rtol, atol=self.eq_atol)
525+
428526

429527
@dataclass
430528
class IntTag(AbstractNumericTag):
@@ -620,6 +718,10 @@ def get_token_len(self) -> int:
620718
"""
621719
return self._get_token_len()
622720

721+
def _is_equal_to(self, val1, obj2, val2):
722+
return True # TODO: We still need to actually implement initmagmom as a multi-format tag
723+
# raise NotImplementedError("equality not yet implemented for InitMagMomTag")
724+
623725

624726
@dataclass
625727
class TagContainer(AbstractTag):
@@ -1013,6 +1115,28 @@ def get_dict_representation(self, tag: str, value: list) -> dict | list[dict]:
10131115
list_value = self._make_str_for_dict(tag, value)
10141116
return self.read(tag, list_value)
10151117

1118+
def _is_equal_to(self, val1, obj2, val2):
1119+
"""Check if the two values are equal.
1120+
1121+
Return False if (checked in following order)
1122+
- obj2 is not a TagContainer
1123+
- all of val1's subtags are not in val2
1124+
- val1 and val2 are not the same length (different number of subtags)
1125+
- at least one subtag in val1 is not equal to the corresponding subtag in val2
1126+
"""
1127+
if self._is_same_tagtype(obj2):
1128+
if isinstance(val1, dict) and isinstance(val2, dict):
1129+
if all(subtag in val2 for subtag in val1) and (len(list(val1.keys())) == len(list(val2.keys()))):
1130+
for subtag, subtag_type in self.subtags.items():
1131+
if (subtag in val1) and (
1132+
not subtag_type.is_equal_to(val1[subtag], obj2.subtags[subtag], val2[subtag])
1133+
):
1134+
return False
1135+
return True
1136+
return False
1137+
raise ValueError("Values must be in dictionary format for TagContainer comparison")
1138+
return False
1139+
10161140

10171141
# TODO: Write StructureDefferedTagContainer back in (commented out code block removed
10181142
# on 11/4/24) and make usable for tags like initial-magnetic-moments
@@ -1162,6 +1286,9 @@ def get_token_len(self) -> int:
11621286
"""
11631287
raise NotImplementedError("This method is not supposed to be called directly on MultiformatTag objects!")
11641288

1289+
def _is_equal_to(self, val1, obj2, val2):
1290+
raise NotImplementedError("This method is not supposed to be called directly on MultiformatTag objects!")
1291+
11651292

11661293
@dataclass
11671294
class BoolTagContainer(TagContainer):

0 commit comments

Comments
 (0)