Skip to content

Commit 270b307

Browse files
1041176461jiyuangQianruipku
authored
Fix: dos plot for nspin=2 (#3928)
* Use template to reconstruct parse_expression * Feature: output R matrix at each MD step * Modify'matrix_HS' to 'matrix' for R matrix output * Merge branches 'develop' and 'develop' of https://github.com/1041176461/abacus-develop into develop * Fix: modify index in parse_expression * Fix: add regfree for parse_expression * Doc: update phonopy doc * Doc: update phonopy doc * fix tdos plot for nspin=2 * optimize dosplot for nspin=2 * fix legend for dosplot --------- Co-authored-by: jiyuang <[email protected]> Co-authored-by: Qianrui <[email protected]>
1 parent ee0dc6f commit 270b307

File tree

1 file changed

+21
-13
lines changed
  • tools/plot-tools/abacus_plot

1 file changed

+21
-13
lines changed

tools/plot-tools/abacus_plot/dos.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
class DOS:
2121
"""Parse DOS data"""
2222

23-
def __init__(self) -> None:
24-
self.nspin = 1
23+
def __init__(self, nspin) -> None:
24+
self.nspin = nspin
2525
if self.nspin in [1, 4]:
2626
self._nsplit = 1
2727
elif self.nspin == 2:
@@ -89,10 +89,11 @@ def bandgap(cls, vb: namedtuple, cb: namedtuple):
8989
return gap
9090

9191

92-
class DOSPlot:
92+
class :
9393
"""Plot density of state(DOS)"""
9494

95-
def __init__(self, fig: Figure = None, ax: axes.Axes = None, **kwargs) -> None:
95+
def __init__(self, fig: Figure = None, ax: axes.Axes = None, nspin: int = 1, **kwargs) -> None:
96+
self.nspin = nspin
9697
self.fig = fig
9798
self.ax = ax
9899
self._lw = kwargs.pop('lw', 2)
@@ -145,19 +146,26 @@ def _set_figure(self, energy_range: Sequence = [], dos_range: Sequence = [], not
145146
else:
146147
self.ax.axvline(0, linestyle="--", c='b', lw=1.0)
147148

149+
if self.nspin == 2:
150+
self.ax.axhline(0, linestyle="--", c='gray', lw=1.0)
151+
152+
handles, labels = self.ax.get_legend_handles_labels()
153+
by_label = OrderedDict(zip(labels, handles))
148154
if "legend_prop" in self.plot_params.keys():
149-
self.ax.legend(prop=self.plot_params["legend_prop"])
155+
self.ax.legend(by_label.values(), by_label.keys(),
156+
prop=self.plot_params["legend_prop"])
150157
else:
151-
self.ax.legend(prop={'size': 15})
158+
self.ax.legend(by_label.values(),
159+
by_label.keys(), prop={'size': 15})
152160

153161

154162
class TDOS(DOS):
155163
"""Parse total DOS data"""
156164

157165
def __init__(self, tdosfile: PathLike = None) -> None:
158-
super().__init__()
159166
self.tdosfile = tdosfile
160167
self._read()
168+
super().__init__(self.nspin)
161169

162170
def _read(self) -> tuple:
163171
"""Read total DOS data file
@@ -171,7 +179,7 @@ def _read(self) -> tuple:
171179
self.energy, self.dos = np.split(data, self.nspin+1, axis=1)
172180
elif self.nspin == 2:
173181
self.energy, dos_up, dos_dw = np.split(data, self.nspin+1, axis=1)
174-
self.dos = np.hstack(dos_up, dos_dw)
182+
self.dos = np.hstack((dos_up, dos_dw))
175183

176184
def _shift_energy(self, efermi: float = 0, shift: bool = False, prec: float = 0.01):
177185
if shift:
@@ -188,7 +196,7 @@ def plot(self, fig: Figure, ax: Union[axes.Axes, Sequence[axes.Axes]], efermi: f
188196

189197
energy_f = self._shift_energy(efermi, shift, prec)
190198

191-
dosplot = DOSPlot(fig, ax, **kwargs)
199+
dosplot = DOSPlot(fig, ax, self.nspin, **kwargs)
192200
dosplot.ax = self._plot(dosplot, energy_f, self.dos, "TDOS")
193201
if "notes" in dosplot.plot_params.keys():
194202
dosplot._set_figure(energy_range, dos_range,
@@ -203,9 +211,9 @@ class PDOS(DOS):
203211
"""Parse partial DOS data"""
204212

205213
def __init__(self, pdosfile: PathLike = None) -> None:
206-
super().__init__()
207214
self.pdosfile = pdosfile
208215
self._read()
216+
super().__init__(self.nspin)
209217

210218
def _read(self):
211219
"""Read partial DOS data file
@@ -398,7 +406,7 @@ def _parial_plot(self,
398406
energy_f, tdos = self._shift_energy(efermi, shift, prec)
399407

400408
if not species:
401-
dosplot = DOSPlot(fig, ax, **kwargs)
409+
dosplot = DOSPlot(fig, ax, self.nspin, **kwargs)
402410
dosplot.ax = self._plot(dosplot, energy_f, tdos, "TDOS")
403411
if "notes" in dosplot.plot_params.keys():
404412
dosplot._set_figure(energy_range, dos_range,
@@ -409,7 +417,7 @@ def _parial_plot(self,
409417
return dosplot
410418

411419
if isinstance(species, (list, tuple)):
412-
dosplot = DOSPlot(fig, ax, **kwargs)
420+
dosplot = DOSPlot(fig, ax, self.nspin, **kwargs)
413421
if "xlabel_params" in dosplot.plot_params.keys():
414422
dosplot.ax.set_xlabel("Energy(eV)", **
415423
dosplot.plot_params["xlabel_params"])
@@ -431,7 +439,7 @@ def _parial_plot(self,
431439
assert len(ax) >= len(
432440
dos.keys()), "There must be enough `axes` to plot."
433441
for i, elem in enumerate(dos.keys()):
434-
dosplot = DOSPlot(fig, ax[i], **kwargs)
442+
dosplot = DOSPlot(fig, ax[i], self.nspin, **kwargs)
435443
for ang in dos[elem].keys():
436444
l_index = int(ang)
437445
if isinstance(dos[elem][ang], dict):

0 commit comments

Comments
 (0)