Source code for ase.dft.band_structure

import numpy as np

from ase.dft.kpoints import labels_from_kpts
from ase.io.jsonio import encode, decode
from ase.parallel import paropen


def get_band_structure(atoms=None, calc=None):
    """Create band structure object from Atoms or calculator.

    atoms: Atoms or None
        Provides the unit cell.  Defaults to ``calc.atoms``.
    calc: calculator or None
        Provides the k-points (``get_ibz_k_points``), the eigenvalues and
        the Fermi level, which becomes the energy reference.  Defaults to
        ``atoms.calc``.

    Returns a BandStructure whose energies are indexed as
    ``energies[spin][kpt]`` (one eigenvalue array per spin and k-point).

    Raises ValueError if neither argument is given or if no calculator
    can be found.
    """
    if atoms is None and calc is None:
        raise ValueError('Either atoms or calc must be given')
    if atoms is None:
        atoms = calc.atoms
    if calc is None:
        calc = atoms.calc
        if calc is None:
            raise ValueError('No calculator attached to atoms')

    kpts = calc.get_ibz_k_points()
    nkpts = len(kpts)

    # Outer index: spin channel; inner index: k-point.
    energies = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                          for k in range(nkpts)]
                         for s in range(calc.get_number_of_spins())])

    return BandStructure(cell=atoms.cell,
                         kpts=kpts,
                         energies=energies,
                         reference=calc.get_fermi_level())


class BandStructurePlot:
    """Matplotlib plotter for a BandStructure.

    The first plot call prepares an Axes (special-point ticks, energy
    window, reference line) and stores it in ``self.ax``.  Later calls
    without an explicit ``ax`` draw on the same Axes, so several band
    structures can be overlaid.
    """

    def __init__(self, bs):
        self.bs = bs
        self.ax = None  # Axes prepared by prepare_plot()
        self.xcoords = None  # path coordinate of every k-point
        self.show_legend = False

    def _get_axes(self, ax, emin, emax, ylabel):
        """Return the Axes to draw on, preparing it if it is new."""
        if ax is None:
            ax = self.ax
        if ax is None or ax is not self.ax:
            ax = self.prepare_plot(ax, emin, emax, ylabel)
        return ax

    def plot(self, ax=None, spin=None, emin=-10, emax=5, filename=None,
             show=None, ylabel=None, colors=None, label=None,
             spin_labels=('spin up', 'spin down'), loc=None, **plotkwargs):
        """Plot band-structure.

        ax: Axes
            MatPlotLib Axes object.  If omitted, the Axes of a previous call
            is reused, or a new one is created.
        spin: int or None
            Spin channel.  Default behaviour is to plot both spin up and down
            for spin-polarized calculations.
        emin, emax: float
            Minimum and maximum energy relative to the reference.  Only used
            when a new Axes is prepared.
        filename: str
            Write image to a file.
        show: bool
            Show the image.  Defaults to True unless filename is given.
        ylabel: str
            Label of the energy axis.
        colors: sequence of colors
            One color per plotted spin channel.  Defaults to 'g' for one
            channel and 'yb' for two.
        label: str
            Legend label.  With two spin channels the corresponding entry of
            spin_labels is appended.
        spin_labels: sequence of two str
            Legend text for the two spin channels.
        loc: str
            Legend location passed to matplotlib.
        plotkwargs:
            Extra keyword arguments for ``Axes.plot``; they override color.

        Returns the Axes.
        """
        ax = self._get_axes(ax, emin, emax, ylabel)

        if spin is None:
            e_skn = self.bs.energies
        else:
            # Keep a leading spin axis of length one.
            e_skn = self.bs.energies[spin, np.newaxis]

        nspins = len(e_skn)
        if colors is None:
            colors = 'g' if nspins == 1 else 'yb'

        for s, e_kn in enumerate(e_skn):
            kwargs = dict(color=colors[s])
            kwargs.update(plotkwargs)
            if nspins == 2:
                if label:
                    lbl = label + ' ' + spin_labels[s]
                else:
                    lbl = spin_labels[s]
            else:
                lbl = label
            # Only the first band carries the legend label, so each
            # channel appears once in the legend.
            ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)
            for e_k in e_kn.T[1:]:
                ax.plot(self.xcoords, e_k, **kwargs)

        self.show_legend = label is not None or nspins == 2
        self.finish_plot(filename, show, loc)

        return ax

    def plot_with_colors(self, ax=None, emin=-10, emax=5, filename=None,
                         show=None, energies=None, colors=None,
                         ylabel=None, clabel='$s_z$', cmin=-1.0, cmax=1.0,
                         sortcolors=False, loc=None, s=2):
        """Plot band-structure with colors.

        energies: array of shape (nbands, nkpts)
            Energies to plot.  Each row is one band sampled at the k-points
            of the band structure.
        colors: array of the same shape as energies
            Values mapped through the colormap, clipped to [cmin, cmax].
        clabel: str
            Label of the colorbar.
        sortcolors: bool
            Draw points in order of increasing color value, so the largest
            values end up on top.
        s: float
            Marker size.

        The remaining arguments are as for plot().  Returns the Axes.
        Raises ValueError when energies/colors are missing or mis-shaped.
        """
        import matplotlib.pyplot as plt

        if energies is None or colors is None:
            raise ValueError('energies and colors must be given')

        ax = self._get_axes(ax, emin, emax, ylabel)

        energies = np.asarray(energies)
        colors = np.asarray(colors)
        nkpts = len(self.xcoords)
        if energies.ndim != 2 or energies.shape[1] != nkpts:
            raise ValueError('energies must have shape (nbands, {}), got {}'
                             .format(nkpts, energies.shape))
        if colors.shape != energies.shape:
            raise ValueError('colors must have the same shape as energies')

        # One row of path coordinates per band, matching energies
        # element-wise.
        xcoords = np.tile(self.xcoords, (energies.shape[0], 1))

        energies = energies.ravel()
        colors = colors.ravel()
        xcoords = xcoords.ravel()
        if sortcolors:
            perm = colors.argsort()
            energies = energies[perm]
            colors = colors[perm]
            xcoords = xcoords[perm]

        # A single scatter call draws points in array order, i.e. band by
        # band, or by increasing color when sorted.
        things = ax.scatter(xcoords, energies, c=colors, s=s,
                            vmin=cmin, vmax=cmax)

        cbar = plt.colorbar(things, ax=ax)
        cbar.set_label(clabel)

        self.finish_plot(filename, show, loc)

        return ax

    @staticmethod
    def _format_labels(names, label_xcoords):
        """Return tick labels for the special points.

        'G' becomes a Gamma symbol, and two-character names such as 'K1'
        get a mathtext subscript.  When two consecutive special points share
        an x-coordinate (presumably a jump in the path), their labels are
        joined as 'A,B' on the first tick and the second tick is left empty.
        """
        def pretty(name):
            if name == 'G':
                return r'$\Gamma$'
            if len(name) == 2:
                return name[0] + '$_' + name[1] + '$'
            return name

        labels = [pretty(name) for name in names]
        for i in range(1, len(labels)):
            if label_xcoords[i - 1] == label_xcoords[i]:
                labels[i - 1] = labels[i - 1] + ',' + labels[i]
                labels[i] = ''
        return labels

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up an Axes for band-structure plotting and store it.

        Creates a new figure if ax is None.  Sets special-point ticks with
        vertical separators, limits the energy axis to
        [emin, emax] + reference and draws the reference energy as a dotted
        line.  Also sets self.xcoords and self.ax.  Returns the Axes.
        """
        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.figure().add_subplot(111)

        emin += self.bs.reference
        emax += self.bs.reference

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        labels = self._format_labels(orig_labels, label_xcoords)

        # Separators at interior special points only.
        for x in label_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        if ylabel is None:
            ylabel = 'energies [eV]'

        ax.set_xticks(label_xcoords)
        ax.set_xticklabels(labels)
        ax.set_xlim(0, self.xcoords[-1])
        ax.set_ylim(emin, emax)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc):
        """Add a legend if requested, save to filename and maybe show.

        show=None means: show only when no filename is given.
        """
        import matplotlib.pyplot as plt

        ax = self.ax if self.ax is not None else plt.gca()

        if self.show_legend:
            leg = ax.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            # Save the figure that holds our Axes, not whichever is current.
            ax.figure.savefig(filename)

        if show is None:
            show = not filename

        if show:
            plt.show()


class BandStructure:
    """Band energies sampled at a list of k-points in a given cell."""

    def __init__(self, cell, kpts, energies, reference=0.0):
        """Create band structure object from energies and k-points.

        cell: array-like of shape (3, 3)
            Unit cell; stored as given.
        kpts: array-like of shape (nkpts, 3)
            K-points; stored as a numpy array.
        energies: array-like
            Band energies, stored as a numpy array.  get_band_structure
            builds them as (nspins, nkpts, nbands), which is the layout
            BandStructurePlot.plot indexes.
        reference: float
            Reference energy, e.g. the Fermi level.

        Raises ValueError if cell or kpts have the wrong shape.
        """
        if np.shape(cell) != (3, 3):
            raise ValueError('cell must have shape (3, 3), got {}'
                             .format(np.shape(cell)))
        self.cell = cell
        kpts = np.asarray(kpts)
        if kpts.ndim != 2 or kpts.shape[1] != 3:
            raise ValueError('kpts must have shape (nkpts, 3), got {}'
                             .format(kpts.shape))
        self.kpts = kpts
        self.energies = np.asarray(energies)
        self.reference = reference

    def get_labels(self):
        """Return labels_from_kpts(kpts, cell).

        BandStructurePlot unpacks this as (x-coordinates of all k-points,
        x-coordinates of special points, special-point names).
        """
        return labels_from_kpts(self.kpts, self.cell)

    def todict(self):
        """Return the constructor arguments as a dict (used for JSON)."""
        return {key: getattr(self, key)
                for key in ['cell', 'kpts', 'energies', 'reference']}

    def write(self, filename):
        """Write to json file."""
        with paropen(filename, 'w') as f:
            f.write(encode(self))

    @classmethod
    def read(cls, filename):
        """Read from json file written by write()."""
        with open(filename, 'r') as f:
            dct = decode(f.read())
        return cls(**dct)

    def plot(self, *args, **kwargs):
        """Plot with BandStructurePlot.plot(); returns the Axes."""
        bsp = BandStructurePlot(self)
        # Maybe return bsp? But for now run the plot, for compatibility
        return bsp.plot(*args, **kwargs)