20
20
class DOS :
21
21
"""Parse DOS data"""
22
22
23
- def __init__ (self ) -> None :
24
- self .nspin = 1
23
+ def __init__ (self , nspin ) -> None :
24
+ self .nspin = nspin
25
25
if self .nspin in [1 , 4 ]:
26
26
self ._nsplit = 1
27
27
elif self .nspin == 2 :
@@ -89,10 +89,11 @@ def bandgap(cls, vb: namedtuple, cb: namedtuple):
89
89
return gap
90
90
91
91
92
- class DOSPlot :
92
+ class :
93
93
"""Plot density of state(DOS)"""
94
94
95
- def __init__ (self , fig : Figure = None , ax : axes .Axes = None , ** kwargs ) -> None :
95
+ def __init__ (self , fig : Figure = None , ax : axes .Axes = None , nspin : int = 1 , ** kwargs ) -> None :
96
+ self .nspin = nspin
96
97
self .fig = fig
97
98
self .ax = ax
98
99
self ._lw = kwargs .pop ('lw' , 2 )
@@ -145,19 +146,26 @@ def _set_figure(self, energy_range: Sequence = [], dos_range: Sequence = [], not
145
146
else :
146
147
self .ax .axvline (0 , linestyle = "--" , c = 'b' , lw = 1.0 )
147
148
149
+ if self .nspin == 2 :
150
+ self .ax .axhline (0 , linestyle = "--" , c = 'gray' , lw = 1.0 )
151
+
152
+ handles , labels = self .ax .get_legend_handles_labels ()
153
+ by_label = OrderedDict (zip (labels , handles ))
148
154
if "legend_prop" in self .plot_params .keys ():
149
- self .ax .legend (prop = self .plot_params ["legend_prop" ])
155
+ self .ax .legend (by_label .values (), by_label .keys (),
156
+ prop = self .plot_params ["legend_prop" ])
150
157
else :
151
- self .ax .legend (prop = {'size' : 15 })
158
+ self .ax .legend (by_label .values (),
159
+ by_label .keys (), prop = {'size' : 15 })
152
160
153
161
154
162
class TDOS (DOS ):
155
163
"""Parse total DOS data"""
156
164
157
165
def __init__ (self , tdosfile : PathLike = None ) -> None :
158
- super ().__init__ ()
159
166
self .tdosfile = tdosfile
160
167
self ._read ()
168
+ super ().__init__ (self .nspin )
161
169
162
170
def _read (self ) -> tuple :
163
171
"""Read total DOS data file
@@ -171,7 +179,7 @@ def _read(self) -> tuple:
171
179
self .energy , self .dos = np .split (data , self .nspin + 1 , axis = 1 )
172
180
elif self .nspin == 2 :
173
181
self .energy , dos_up , dos_dw = np .split (data , self .nspin + 1 , axis = 1 )
174
- self .dos = np .hstack (dos_up , dos_dw )
182
+ self .dos = np .hstack (( dos_up , dos_dw ) )
175
183
176
184
def _shift_energy (self , efermi : float = 0 , shift : bool = False , prec : float = 0.01 ):
177
185
if shift :
@@ -188,7 +196,7 @@ def plot(self, fig: Figure, ax: Union[axes.Axes, Sequence[axes.Axes]], efermi: f
188
196
189
197
energy_f = self ._shift_energy (efermi , shift , prec )
190
198
191
- dosplot = DOSPlot (fig , ax , ** kwargs )
199
+ dosplot = DOSPlot (fig , ax , self . nspin , ** kwargs )
192
200
dosplot .ax = self ._plot (dosplot , energy_f , self .dos , "TDOS" )
193
201
if "notes" in dosplot .plot_params .keys ():
194
202
dosplot ._set_figure (energy_range , dos_range ,
@@ -203,9 +211,9 @@ class PDOS(DOS):
203
211
"""Parse partial DOS data"""
204
212
205
213
def __init__ (self , pdosfile : PathLike = None ) -> None :
206
- super ().__init__ ()
207
214
self .pdosfile = pdosfile
208
215
self ._read ()
216
+ super ().__init__ (self .nspin )
209
217
210
218
def _read (self ):
211
219
"""Read partial DOS data file
@@ -398,7 +406,7 @@ def _parial_plot(self,
398
406
energy_f , tdos = self ._shift_energy (efermi , shift , prec )
399
407
400
408
if not species :
401
- dosplot = DOSPlot (fig , ax , ** kwargs )
409
+ dosplot = DOSPlot (fig , ax , self . nspin , ** kwargs )
402
410
dosplot .ax = self ._plot (dosplot , energy_f , tdos , "TDOS" )
403
411
if "notes" in dosplot .plot_params .keys ():
404
412
dosplot ._set_figure (energy_range , dos_range ,
@@ -409,7 +417,7 @@ def _parial_plot(self,
409
417
return dosplot
410
418
411
419
if isinstance (species , (list , tuple )):
412
- dosplot = DOSPlot (fig , ax , ** kwargs )
420
+ dosplot = DOSPlot (fig , ax , self . nspin , ** kwargs )
413
421
if "xlabel_params" in dosplot .plot_params .keys ():
414
422
dosplot .ax .set_xlabel ("Energy(eV)" , **
415
423
dosplot .plot_params ["xlabel_params" ])
@@ -431,7 +439,7 @@ def _parial_plot(self,
431
439
assert len (ax ) >= len (
432
440
dos .keys ()), "There must be enough `axes` to plot."
433
441
for i , elem in enumerate (dos .keys ()):
434
- dosplot = DOSPlot (fig , ax [i ], ** kwargs )
442
+ dosplot = DOSPlot (fig , ax [i ], self . nspin , ** kwargs )
435
443
for ang in dos [elem ].keys ():
436
444
l_index = int (ang )
437
445
if isinstance (dos [elem ][ang ], dict ):
0 commit comments