# Source code for sisl.viz.plots.bands

from typing import Dict, Literal, Optional, Sequence, Tuple

import numpy as np

from sisl.viz.types import OrbitalQueries, StyleSpec

from ..data.bands import BandsData
from ..figure import Figure, get_figure
from ..plot import Plot
from ..plotters.plot_actions import combined
from ..plotters.xarray import draw_xarray_xy
from ..plotutils import random_color
from ..processors.bands import calculate_gap, draw_gaps, filter_bands, style_bands
from ..processors.data import accept_data
from ..processors.logic import matches
from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data
from ..processors.xarray import scale_variable
from .orbital_groups_plot import OrbitalGroupsPlot


def bands_plot(bands_data: BandsData,
    Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", 
    bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, 
    bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, 
    spindown_style: StyleSpec = {"color": "blue", "width": 1}, 
    colorscale: Optional[str] = None,
    gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, 
    custom_gaps: Sequence[Dict] = [],
    line_mode: Literal["line", "scatter", "area_line"] = "line", 
    backend: str = "plotly"
) -> Figure:
    """Plots band structure energies, with plenty of customization options.

    The bands data is validated, filtered (energy range, band range, spin) and
    styled. It is then converted into drawing actions for the band lines and
    the band gap(s). The combined actions are rendered with ``backend``.

    Parameters
    ----------
    bands_data: 
        The object containing the data to plot. It is passed through
        ``accept_data(..., cls=BandsData, check=True)`` before being used.
    Erange:
        The energy range to plot.
        If None, the range is determined by ``bands_range``.
    E0:
        The energy reference.
    E_axis:
        Axis ("x" or "y") on which to plot the energies. The other axis
        holds the k coordinate.
    bands_range:
        The bands to plot. Only used if ``Erange`` is None.
        If None, the 15 bands above and below the Fermi level are plotted.
    spin:
        Which spin channel to display. Only meaningful for spin-polarized calculations.
        If None and the calculation is spin polarized, both are plotted.
    bands_style:
        Styling attributes for bands.
    spindown_style:
        Styling attributes for the spin down bands (if present). Any missing attribute
        will be taken from ``bands_style``.
    colorscale:
        Colorscale to use for the bands in case the color attribute is an array of values.
        If None, the default colorscale is used for each backend.
        It is only applied to the band lines, not to the gap markers.
    gap:
        Whether to display the gap. The gap information is always calculated
        from the filtered bands. This flag is forwarded to ``draw_gaps``,
        which presumably decides whether anything is drawn.
    gap_tol:
        Tolerance in k for determining whether two gaps are the same.
    gap_color:
        Color of the gap.
    gap_marker:
        Marker styles for the gap (as plotly marker's styles).
    direct_gaps_only:
        Whether to only display direct gaps.
    custom_gaps:
        List of custom gaps to display. See the showcase notebooks for examples.
    line_mode:
        The method used to draw the band lines.
    backend:
        The backend to use to generate the figure.

    Returns
    -------
    Figure
        The figure produced by ``get_figure`` for the requested ``backend``.
    """
    # NOTE(review): the dict/list defaults in the signature are shared between
    # calls. This assumes that no downstream processor (style_bands, draw_gaps...)
    # mutates them. TODO confirm.
    # NOTE(review): this body is kept as a flat sequence of calls, and
    # ``matches`` is used instead of conditional expressions. Presumably this
    # is so that ``Plot`` can convert the function into a workflow, so avoid
    # adding Python control flow here.

    bands_data = accept_data(bands_data, cls=BandsData, check=True)

    # Filter the bands
    filtered_bands = filter_bands(bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin)

    # Add the styles
    styled_bands = style_bands(filtered_bands, bands_style=bands_style, spindown_style=spindown_style)

    # Determine what goes on each axis: energies ("E") on E_axis, k on the other one.
    x = matches(E_axis, "x", ret_true="E", ret_false="k")
    y = matches(E_axis, "y", ret_true="E", ret_false="k")
    
    # Get the actions to plot lines
    bands_plottings = draw_xarray_xy(data=styled_bands, x=x, y=y, set_axrange=True, what=line_mode, colorscale=colorscale, dependent_axis=E_axis)

    # Gap calculation (on the filtered bands).
    # NOTE(review): draw_gaps below receives the unfiltered ``bands_data`` while
    # gap_info comes from ``filtered_bands``. Confirm this mix is intended.
    gap_info = calculate_gap(filtered_bands)
    # Plot it if the user has asked for it.
    gaps_plottings = draw_gaps(bands_data, gap, gap_info, gap_tol, gap_color, gap_marker, direct_gaps_only, custom_gaps, E_axis=E_axis)

    # Band lines first, then gaps.
    all_plottings = combined(bands_plottings, gaps_plottings, composite_method=None)

    return get_figure(backend=backend, plot_actions=all_plottings)

def _default_random_color(x):
    return x.get("color") or random_color()


def _group_traces(actions):

    seen_groups = []

    new_actions = []
    for action in actions:
        if action["method"].startswith("draw_"):
            group = action["kwargs"].get("name")
            action = action.copy()
            action['kwargs']['legendgroup'] = group

            if group in seen_groups:
                action["kwargs"]["showlegend"] = False
            else:
                seen_groups.append(group)
            
        new_actions.append(action)
    
    return new_actions


# I keep the fatbands plot here so that one can see how similar they are.
# I am yet to find a nice solution for extending workflows.
def fatbands_plot(bands_data: BandsData, 
    Erange: Optional[Tuple[float, float]] = None, E0: float = 0., E_axis: Literal["x", "y"] = "y", 
    bands_range: Optional[Tuple[int, int]] = None, spin: Optional[Literal[0, 1]] = None, 
    bands_style: StyleSpec = {'color': 'black', 'width': 1, "opacity": 1}, 
    spindown_style: StyleSpec = {"color": "blue", "width": 1}, 
    gap: bool = False, gap_tol: float = 0.01, gap_color: str = "red", gap_marker: dict = {"size": 7}, direct_gaps_only: bool = False, 
    custom_gaps: Sequence[Dict] = [],
    bands_mode: Literal["line", "scatter", "area_line"] = "line", 
    # Fatbands inputs
    groups: OrbitalQueries = [],
    fatbands_var: str = "norm2", 
    fatbands_mode: Literal["line", "scatter", "area_line"] = "area_line",
    fatbands_scale: float = 1., 
    backend: str = "plotly"
) -> Figure:
    """Plots band structure energies showing the contribution of orbitals to each state.

    This follows the same pipeline as ``bands_plot``. In addition, the
    orbital data of the filtered bands is reduced into one entry per orbital
    group. Each group's contribution is drawn as a "fat band", using
    ``fatbands_var`` as the width.

    Parameters
    ----------
    bands_data: 
        The object containing the data to plot. It is passed through
        ``accept_data(..., cls=BandsData, check=True)`` before being used.
    Erange:
        The energy range to plot.
        If None, the range is determined by ``bands_range``.
    E0:
        The energy reference.
    E_axis:
        Axis ("x" or "y") on which to plot the energies. The other axis
        holds the k coordinate.
    bands_range:
        The bands to plot. Only used if ``Erange`` is None.
        If None, the 15 bands above and below the Fermi level are plotted.
    spin:
        Which spin channel to display. Only meaningful for spin-polarized calculations.
        If None and the calculation is spin polarized, both are plotted.
    bands_style:
        Styling attributes for bands.
    spindown_style:
        Styling attributes for the spin down bands (if present). Any missing attribute
        will be taken from ``bands_style``.
    gap:
        Whether to display the gap. The gap information is always calculated
        from the filtered bands. This flag is forwarded to ``draw_gaps``,
        which presumably decides whether anything is drawn.
    gap_tol:
        Tolerance in k for determining whether two gaps are the same.
    gap_color:
        Color of the gap.
    gap_marker:
        Marker styles for the gap (as plotly marker's styles).
    direct_gaps_only:
        Whether to only display direct gaps.
    custom_gaps:
        List of custom gaps to display. See the showcase notebooks for examples.
    bands_mode:
        The method used to draw the band lines.
    groups:
        Orbital groups to plot. See showcase notebook for examples.
        Groups without a ``"color"`` are assigned a random color. Groups that
        end up empty are dropped. If ``groups`` is empty, the fatbands drawing
        mode is set to ``"none"`` (presumably nothing is drawn for fatbands).
    fatbands_var:
        The variable to use from bands_data to determine the width of the fatbands.
        This variable must have as coordinates (k, band, orb, [spin]).
        Contributions are summed over the spin dimension.
    fatbands_mode:
        The method used to draw the fatbands.
    fatbands_scale:
        Factor that scales the size of all fatbands.
    backend:
        The backend to use to generate the figure.

    Returns
    -------
    Figure
        The figure produced by ``get_figure`` for the requested ``backend``.
    """
    # NOTE(review): the dict/list defaults in the signature are shared between
    # calls. This assumes that no downstream processor mutates them. TODO confirm.
    # NOTE(review): as in bands_plot, the body is a flat sequence of calls
    # (``matches`` instead of if/else), presumably so ``Plot`` can turn it
    # into a workflow.
    bands_data = accept_data(bands_data, cls=BandsData, check=True)

    # Filter the bands
    filtered_bands = filter_bands(bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin)

    # Add the styles
    styled_bands = style_bands(filtered_bands, bands_style=bands_style, spindown_style=spindown_style)

    # Process fatbands
    # Groups that do not define a color get a random one.
    orbital_manager = get_orbital_queries_manager(
        bands_data,
        key_gens={
            "color": _default_random_color,
        }
    )
    # Reduce the orbital data of each group along the "orb" dimension
    # (summing over spin). Groups that end up empty are dropped.
    fatbands_data = reduce_orbital_data(
        filtered_bands, groups=groups, orb_dim="orb", spin_dim="spin", sanitize_group=orbital_manager,
        group_vars=('color', 'dash'), groups_dim="group", drop_empty=True,
        spin_reduce=np.sum,
    )
    # Scale the width variable. If fatbands_var is missing, scale_variable is
    # told to tolerate it, with default_value=1 (presumably a constant width).
    scaled_fatbands_data = scale_variable(fatbands_data, var=fatbands_var, scale=fatbands_scale, default_value=1, allow_not_present=True)

    # Determine what goes on each axis: energies ("E") on E_axis, k on the other one.
    x = matches(E_axis, "x", ret_true="E", ret_false="k")
    y = matches(E_axis, "y", ret_true="E", ret_false="k")

    # With no groups requested, switch the fatbands drawing mode to "none".
    sanitized_fatbands_mode = matches(groups, [], ret_true="none", ret_false=fatbands_mode)
    
    # Get the actions to plot lines
    fatbands_plottings = draw_xarray_xy(
        data=scaled_fatbands_data, x=x, y=y, color="color", width=fatbands_var, what=sanitized_fatbands_mode, dependent_axis=E_axis,
        name="group"
    )
    # Make each group appear only once in the legend.
    grouped_fatbands_plottings = _group_traces(fatbands_plottings)
    bands_plottings = draw_xarray_xy(data=styled_bands, x=x, y=y, set_axrange=True, what=bands_mode, dependent_axis=E_axis)

    # Gap calculation (on the filtered bands).
    gap_info = calculate_gap(filtered_bands)
    # Plot it if the user has asked for it.
    gaps_plottings = draw_gaps(bands_data, gap, gap_info, gap_tol, gap_color, gap_marker, direct_gaps_only, custom_gaps, E_axis=E_axis)

    # Fatbands go first, presumably so that band lines and gaps are drawn on top of them.
    all_plottings = combined(grouped_fatbands_plottings, bands_plottings, gaps_plottings, composite_method=None)

    return get_figure(backend=backend, plot_actions=all_plottings)

# Plot class backed by ``bands_plot`` (stored as its ``function``).
class BandsPlot(Plot): function = staticmethod(bands_plot)
# Orbital-groups plot class backed by ``fatbands_plot`` (stored as its ``function``).
class FatbandsPlot(OrbitalGroupsPlot): function = staticmethod(fatbands_plot)