Source code for sisl.viz.plots.bands

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import Dict, Literal, Optional, Sequence, Tuple

import numpy as np

from sisl.viz.types import OrbitalQueries, StyleSpec

from ..data.bands import BandsData
from ..figure import Figure, get_figure
from ..plot import Plot
from ..plotters.plot_actions import combined
from ..plotters.xarray import draw_xarray_xy
from ..plotutils import random_color
from ..processors.bands import calculate_gap, draw_gaps, filter_bands, style_bands
from ..processors.data import accept_data
from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data
from ..processors.xarray import scale_variable
from .orbital_groups_plot import OrbitalGroupsPlot


def bands_plot(
    bands_data: BandsData,
    Erange: Optional[Tuple[float, float]] = None,
    E0: float = 0.0,
    E_axis: Literal["x", "y"] = "y",
    bands_range: Optional[Tuple[int, int]] = None,
    spin: Optional[Literal[0, 1]] = None,
    bands_style: StyleSpec = {"color": "black", "width": 1, "opacity": 1},
    spindown_style: StyleSpec = {"color": "blue", "width": 1},
    colorscale: Optional[str] = None,
    gap: bool = False,
    gap_tol: float = 0.01,
    gap_color: str = "red",
    gap_marker: dict = {"size": 7},
    direct_gaps_only: bool = False,
    custom_gaps: Sequence[Dict] = [],
    line_mode: Literal["line", "scatter", "area_line"] = "line",
    backend: str = "plotly",
) -> Figure:
    """Plots band structure energies, with plenty of customization options.

    The bands are filtered and styled, converted into drawing actions, and
    merged with the gap drawing actions before the figure is built.

    Parameters
    ----------
    bands_data:
        The object containing the data to plot.
    Erange:
        The energy range to plot.
        If None, the range is determined by ``bands_range``.
    E0:
        The energy reference.
    E_axis:
        Axis to plot the energies.
    bands_range:
        The bands to plot. Only used if ``Erange`` is None.
        If None, the 15 bands above and below the Fermi level are plotted.
    spin:
        Which spin channel to display. Only meaningful for spin-polarized calculations.
        If None and the calculation is spin polarized, both are plotted.
    bands_style:
        Styling attributes for bands.
    spindown_style:
        Styling attributes for the spin down bands (if present). Any missing attribute
        will be taken from ``bands_style``.
    colorscale:
        Colorscale to use for the bands in case the color attribute is an array of values.
        If None, the default colorscale is used for each backend.
    gap:
        Whether to display the gap.
    gap_tol:
        Tolerance in k for determining whether two gaps are the same.
    gap_color:
        Color of the gap.
    gap_marker:
        Marker styles for the gap (as plotly marker's styles).
    direct_gaps_only:
        Whether to only display direct gaps.
    custom_gaps:
        List of custom gaps to display. See the showcase notebooks for examples.
    line_mode:
        The method used to draw the band lines.
    backend:
        The backend to use to generate the figure.

    Returns
    -------
    Figure
        The figure built by ``get_figure`` for ``backend`` from the combined
        band and gap plot actions.
    """
    # NOTE(review): the dict/list defaults in the signature are shared between
    # calls. This function only forwards them to helpers; whether those helpers
    # mutate them is not visible from here -- confirm before relying on it.
    # NOTE(review): this function is attached to ``BandsPlot`` as its
    # ``function``; the plot machinery that consumes it is defined elsewhere, so
    # check it before renaming locals or reordering the steps below.

    # Validate/convert the input into a BandsData object.
    bands_data = accept_data(bands_data, cls=BandsData, check=True)

    # Keep only the requested bands (energy window, band indices, spin channel).
    filtered_bands = filter_bands(
        bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin
    )

    # Attach the style attributes to the filtered bands.
    styled_bands = style_bands(
        filtered_bands, bands_style=bands_style, spindown_style=spindown_style
    )

    # Determine what goes on each axis: the energy goes on ``E_axis`` and k on
    # the other one.
    x = "E" if E_axis == "x" else "k"
    y = "E" if E_axis == "y" else "k"

    # Get the actions to plot the band lines.
    bands_plottings = draw_xarray_xy(
        data=styled_bands,
        x=x,
        y=y,
        set_axrange=True,
        what=line_mode,
        colorscale=colorscale,
        dependent_axis=E_axis,
    )

    # Gap calculation. It is computed on the filtered bands regardless of the
    # value of ``gap``.
    gap_info = calculate_gap(filtered_bands)
    # Build the gap drawing actions. ``gap`` is forwarded to ``draw_gaps``,
    # which presumably skips the computed gap when it is False -- confirm in the
    # helper. Note that the unfiltered ``bands_data`` is what gets passed here.
    gaps_plottings = draw_gaps(
        bands_data,
        gap,
        gap_info,
        gap_tol,
        gap_color,
        gap_marker,
        direct_gaps_only,
        custom_gaps,
        E_axis=E_axis,
    )

    # Merge band and gap actions into a single set of plot actions.
    all_plottings = combined(bands_plottings, gaps_plottings, composite_method=None)

    return get_figure(backend=backend, plot_actions=all_plottings)


def _default_random_color(x):
    return x.get("color") or random_color()


def _group_traces(actions):
    seen_groups = []

    new_actions = []
    for action in actions:
        if action["method"].startswith("draw_"):
            group = action["kwargs"].get("name")
            action = action.copy()
            action["kwargs"]["legendgroup"] = group

            if group in seen_groups:
                action["kwargs"]["showlegend"] = False
            else:
                seen_groups.append(group)

        new_actions.append(action)

    return new_actions


# I keep the fatbands plot here so that one can see how similar they are.
# I am yet to find a nice solution for extending workflows.
def fatbands_plot(
    bands_data: BandsData,
    Erange: Optional[Tuple[float, float]] = None,
    E0: float = 0.0,
    E_axis: Literal["x", "y"] = "y",
    bands_range: Optional[Tuple[int, int]] = None,
    spin: Optional[Literal[0, 1]] = None,
    bands_style: StyleSpec = {"color": "black", "width": 1, "opacity": 1},
    spindown_style: StyleSpec = {"color": "blue", "width": 1},
    gap: bool = False,
    gap_tol: float = 0.01,
    gap_color: str = "red",
    gap_marker: dict = {"size": 7},
    direct_gaps_only: bool = False,
    custom_gaps: Sequence[Dict] = [],
    bands_mode: Literal["line", "scatter", "area_line"] = "line",
    # Fatbands inputs
    groups: OrbitalQueries = [],
    fatbands_var: str = "norm2",
    fatbands_mode: Literal["line", "scatter", "area_line"] = "area_line",
    fatbands_scale: float = 1.0,
    backend: str = "plotly",
) -> Figure:
    """Plots band structure energies showing the contribution of orbitals to each state.

    The pipeline is the same as in ``bands_plot`` (filter, style, draw bands,
    draw gaps), with an extra set of "fatbands" drawing actions: one per orbital
    group, whose width is taken from ``fatbands_var``.

    Parameters
    ----------
    bands_data:
        The object containing the data to plot.
    Erange:
        The energy range to plot.
        If None, the range is determined by ``bands_range``.
    E0:
        The energy reference.
    E_axis:
        Axis to plot the energies.
    bands_range:
        The bands to plot. Only used if ``Erange`` is None.
        If None, the 15 bands above and below the Fermi level are plotted.
    spin:
        Which spin channel to display. Only meaningful for spin-polarized calculations.
        If None and the calculation is spin polarized, both are plotted.
    bands_style:
        Styling attributes for bands.
    spindown_style:
        Styling attributes for the spin down bands (if present). Any missing attribute
        will be taken from ``bands_style``.
    gap:
        Whether to display the gap.
    gap_tol:
        Tolerance in k for determining whether two gaps are the same.
    gap_color:
        Color of the gap.
    gap_marker:
        Marker styles for the gap (as plotly marker's styles).
    direct_gaps_only:
        Whether to only display direct gaps.
    custom_gaps:
        List of custom gaps to display. See the showcase notebooks for examples.
    bands_mode:
        The method used to draw the band lines.
    groups:
        Orbital groups to plot. See showcase notebook for examples.
        If it is an empty list, no fatbands are drawn.
    fatbands_var:
        The variable to use from bands_data to determine the width of the fatbands.
        This variable must have as coordinates (k, band, orb, [spin]).
    fatbands_mode:
        The method used to draw the fatbands.
    fatbands_scale:
        Factor that scales the size of all fatbands.
    backend:
        The backend to use to generate the figure.

    Returns
    -------
    Figure
        The figure built by ``get_figure`` for ``backend`` from the combined
        fatbands, bands and gap plot actions.
    """
    # NOTE(review): the dict/list defaults in the signature are shared between
    # calls. This function only forwards them to helpers; whether those helpers
    # mutate them is not visible from here -- confirm before relying on it.
    # NOTE(review): this function is attached to ``FatbandsPlot`` as its
    # ``function``; the plot machinery that consumes it is defined elsewhere, so
    # check it before renaming locals or reordering the steps below.

    # Validate/convert the input into a BandsData object.
    bands_data = accept_data(bands_data, cls=BandsData, check=True)

    # Keep only the requested bands (energy window, band indices, spin channel).
    filtered_bands = filter_bands(
        bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin
    )

    # Attach the style attributes to the filtered bands.
    styled_bands = style_bands(
        filtered_bands, bands_style=bands_style, spindown_style=spindown_style
    )

    # Process fatbands.
    # The manager is built from the unfiltered data and is used below to
    # sanitize each group; a group without a truthy "color" gets a random one
    # through ``_default_random_color``.
    orbital_manager = get_orbital_queries_manager(
        bands_data,
        key_gens={
            "color": _default_random_color,
        },
    )
    # Reduce the orbital-resolved data into one entry per group along a new
    # "group" dimension. NOTE(review): ``spin_reduce=np.sum`` presumably sums
    # the spin channels and ``drop_empty=True`` presumably discards groups that
    # match no orbital -- confirm in ``reduce_orbital_data``.
    fatbands_data = reduce_orbital_data(
        filtered_bands,
        groups=groups,
        orb_dim="orb",
        spin_dim="spin",
        sanitize_group=orbital_manager,
        group_vars=("color", "dash"),
        groups_dim="group",
        drop_empty=True,
        spin_reduce=np.sum,
    )
    # Apply the user scale factor to the width variable. NOTE(review):
    # ``allow_not_present``/``default_value`` presumably cover data where
    # ``fatbands_var`` is absent -- confirm in ``scale_variable``.
    scaled_fatbands_data = scale_variable(
        fatbands_data,
        var=fatbands_var,
        scale=fatbands_scale,
        default_value=1,
        allow_not_present=True,
    )

    # Determine what goes on each axis: the energy goes on ``E_axis`` and k on
    # the other one.
    x = "E" if E_axis == "x" else "k"
    y = "E" if E_axis == "y" else "k"

    # With no groups there is nothing to draw as fatbands, so the drawing mode
    # is switched to "none". Only a value equal to an empty list triggers this.
    sanitized_fatbands_mode = "none" if groups == [] else fatbands_mode

    # Get the actions to plot the fatbands, named after their group.
    fatbands_plottings = draw_xarray_xy(
        data=scaled_fatbands_data,
        x=x,
        y=y,
        color="color",
        width=fatbands_var,
        what=sanitized_fatbands_mode,
        dependent_axis=E_axis,
        name="group",
    )
    # Share one legend entry among all the traces of the same group.
    grouped_fatbands_plottings = _group_traces(fatbands_plottings)
    # Get the actions to plot the plain band lines.
    bands_plottings = draw_xarray_xy(
        data=styled_bands,
        x=x,
        y=y,
        set_axrange=True,
        what=bands_mode,
        dependent_axis=E_axis,
    )

    # Gap calculation. It is computed on the filtered bands regardless of the
    # value of ``gap``.
    gap_info = calculate_gap(filtered_bands)
    # Build the gap drawing actions. ``gap`` is forwarded to ``draw_gaps``,
    # which presumably skips the computed gap when it is False -- confirm in the
    # helper. Note that the unfiltered ``bands_data`` is what gets passed here.
    gaps_plottings = draw_gaps(
        bands_data,
        gap,
        gap_info,
        gap_tol,
        gap_color,
        gap_marker,
        direct_gaps_only,
        custom_gaps,
        E_axis=E_axis,
    )

    # Merge everything: fatbands first, then band lines, then gaps.
    # NOTE(review): the order presumably controls what is drawn on top --
    # confirm against the backends.
    all_plottings = combined(
        grouped_fatbands_plottings,
        bands_plottings,
        gaps_plottings,
        composite_method=None,
    )

    return get_figure(backend=backend, plot_actions=all_plottings)


class BandsPlot(Plot):
    # Plot class whose underlying function is ``bands_plot``. It is stored as a
    # staticmethod so that it is not bound to the instance on attribute access.
    function = staticmethod(bands_plot)


class FatbandsPlot(OrbitalGroupsPlot):
    # Plot class whose underlying function is ``fatbands_plot``. It is stored as
    # a staticmethod so that it is not bound to the instance on attribute access.
    function = staticmethod(fatbands_plot)