# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Dict, Literal, Optional, Sequence, Tuple
import numpy as np
from sisl.viz.types import OrbitalQueries, StyleSpec
from ..data.bands import BandsData
from ..figure import Figure, get_figure
from ..plot import Plot
from ..plotters.plot_actions import combined
from ..plotters.xarray import draw_xarray_xy
from ..plotutils import random_color
from ..processors.bands import calculate_gap, draw_gaps, filter_bands, style_bands
from ..processors.data import accept_data
from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data
from ..processors.xarray import scale_variable
from .orbital_groups_plot import OrbitalGroupsPlot
def bands_plot(
    bands_data: BandsData,
    Erange: Optional[Tuple[float, float]] = None,
    E0: float = 0.0,
    E_axis: Literal["x", "y"] = "y",
    bands_range: Optional[Tuple[int, int]] = None,
    spin: Optional[Literal[0, 1]] = None,
    # NOTE(review): the dict/list defaults below (bands_style, spindown_style,
    # gap_marker, custom_gaps) are evaluated once and shared between calls.
    # This is only safe as long as no downstream helper mutates them.
    bands_style: StyleSpec = {"color": "black", "width": 1, "opacity": 1},
    spindown_style: StyleSpec = {"color": "blue", "width": 1},
    colorscale: Optional[str] = None,
    gap: bool = False,
    gap_tol: float = 0.01,
    gap_color: str = "red",
    gap_marker: dict = {"size": 7},
    direct_gaps_only: bool = False,
    custom_gaps: Sequence[Dict] = [],
    line_mode: Literal["line", "scatter", "area_line"] = "line",
    backend: str = "plotly",
) -> Figure:
    """Plots band structure energies, with plenty of customization options.

    Parameters
    ----------
    bands_data:
        The object containing the data to plot.
    Erange:
        The energy range to plot.
        If None, the range is determined by ``bands_range``.
    E0:
        The energy reference.
    E_axis:
        Axis to plot the energies.
    bands_range:
        The bands to plot. Only used if ``Erange`` is None.
        If None, the 15 bands above and below the Fermi level are plotted.
    spin:
        Which spin channel to display. Only meaningful for spin-polarized calculations.
        If None and the calculation is spin polarized, both are plotted.
    bands_style:
        Styling attributes for bands.
    spindown_style:
        Styling attributes for the spin down bands (if present). Any missing attribute
        will be taken from ``bands_style``.
    colorscale:
        Colorscale to use for the bands in case the color attribute is an array of values.
        If None, the default colorscale is used for each backend.
    gap:
        Whether to display the gap.
    gap_tol:
        Tolerance in k for determining whether two gaps are the same.
    gap_color:
        Color of the gap.
    gap_marker:
        Marker styles for the gap (as plotly marker's styles).
    direct_gaps_only:
        Whether to only display direct gaps.
    custom_gaps:
        List of custom gaps to display. See the showcase notebooks for examples.
    line_mode:
        The method used to draw the band lines.
    backend:
        The backend to use to generate the figure.

    Returns
    -------
    Figure
        The figure produced by ``get_figure`` for the requested ``backend``.
    """
    # NOTE(review): this function is wrapped by ``BandsPlot`` (a ``Plot``
    # subclass). Its body is presumably turned into a workflow, so keep the
    # statement structure and local variable names stable. Confirm before
    # refactoring.

    # Accept the input as ``BandsData`` (``accept_data`` is defined elsewhere;
    # presumably it converts/validates the input since ``check=True``).
    bands_data = accept_data(bands_data, cls=BandsData, check=True)

    # Filter the bands
    filtered_bands = filter_bands(
        bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin
    )

    # Add the styles
    styled_bands = style_bands(
        filtered_bands, bands_style=bands_style, spindown_style=spindown_style
    )

    # Determine what goes on each axis: energies ("E") go on ``E_axis`` and
    # k goes on the other one.
    x = "E" if E_axis == "x" else "k"
    y = "E" if E_axis == "y" else "k"

    # Get the actions to plot lines
    bands_plottings = draw_xarray_xy(
        data=styled_bands,
        x=x,
        y=y,
        set_axrange=True,
        what=line_mode,
        colorscale=colorscale,
        dependent_axis=E_axis,
    )

    # Gap calculation
    gap_info = calculate_gap(filtered_bands)

    # Plot it if the user has asked for it.
    # NOTE(review): the gap is computed from ``filtered_bands``, but
    # ``draw_gaps`` receives the unfiltered ``bands_data``. Presumably this
    # lets custom gaps refer to any band. Confirm.
    gaps_plottings = draw_gaps(
        bands_data,
        gap,
        gap_info,
        gap_tol,
        gap_color,
        gap_marker,
        direct_gaps_only,
        custom_gaps,
        E_axis=E_axis,
    )

    # Merge band-line actions and gap actions (in that order) into one list
    # of plot actions and build the backend figure from them.
    all_plottings = combined(bands_plottings, gaps_plottings, composite_method=None)

    return get_figure(backend=backend, plot_actions=all_plottings)
def _default_random_color(x):
return x.get("color") or random_color()
def _group_traces(actions):
seen_groups = []
new_actions = []
for action in actions:
if action["method"].startswith("draw_"):
group = action["kwargs"].get("name")
action = action.copy()
action["kwargs"]["legendgroup"] = group
if group in seen_groups:
action["kwargs"]["showlegend"] = False
else:
seen_groups.append(group)
new_actions.append(action)
return new_actions
# I keep the fatbands plot here so that one can see how similar they are.
# I am yet to find a nice solution for extending workflows.
def fatbands_plot(
    bands_data: BandsData,
    Erange: Optional[Tuple[float, float]] = None,
    E0: float = 0.0,
    E_axis: Literal["x", "y"] = "y",
    bands_range: Optional[Tuple[int, int]] = None,
    spin: Optional[Literal[0, 1]] = None,
    # NOTE(review): the dict/list defaults below (bands_style, spindown_style,
    # gap_marker, custom_gaps, groups) are evaluated once and shared between
    # calls. This is only safe as long as no downstream helper mutates them.
    bands_style: StyleSpec = {"color": "black", "width": 1, "opacity": 1},
    spindown_style: StyleSpec = {"color": "blue", "width": 1},
    gap: bool = False,
    gap_tol: float = 0.01,
    gap_color: str = "red",
    gap_marker: dict = {"size": 7},
    direct_gaps_only: bool = False,
    custom_gaps: Sequence[Dict] = [],
    bands_mode: Literal["line", "scatter", "area_line"] = "line",
    # Fatbands inputs
    groups: OrbitalQueries = [],
    fatbands_var: str = "norm2",
    fatbands_mode: Literal["line", "scatter", "area_line"] = "area_line",
    fatbands_scale: float = 1.0,
    backend: str = "plotly",
) -> Figure:
    """Plots band structure energies showing the contribution of orbitals to each state.

    Parameters
    ----------
    bands_data:
        The object containing the data to plot.
    Erange:
        The energy range to plot.
        If None, the range is determined by ``bands_range``.
    E0:
        The energy reference.
    E_axis:
        Axis to plot the energies.
    bands_range:
        The bands to plot. Only used if ``Erange`` is None.
        If None, the 15 bands above and below the Fermi level are plotted.
    spin:
        Which spin channel to display. Only meaningful for spin-polarized calculations.
        If None and the calculation is spin polarized, both are plotted.
    bands_style:
        Styling attributes for bands.
    spindown_style:
        Styling attributes for the spin down bands (if present). Any missing attribute
        will be taken from ``bands_style``.
    gap:
        Whether to display the gap.
    gap_tol:
        Tolerance in k for determining whether two gaps are the same.
    gap_color:
        Color of the gap.
    gap_marker:
        Marker styles for the gap (as plotly marker's styles).
    direct_gaps_only:
        Whether to only display direct gaps.
    custom_gaps:
        List of custom gaps to display. See the showcase notebooks for examples.
    bands_mode:
        The method used to draw the band lines.
    groups:
        Orbital groups to plot. See showcase notebook for examples.
    fatbands_var:
        The variable to use from bands_data to determine the width of the fatbands.
        This variable must have as coordinates (k, band, orb, [spin]).
    fatbands_mode:
        The method used to draw the fatbands.
    fatbands_scale:
        Factor that scales the size of all fatbands.
    backend:
        The backend to use to generate the figure.

    Returns
    -------
    Figure
        The figure produced by ``get_figure`` for the requested ``backend``.
    """
    # NOTE(review): this function is wrapped by ``FatbandsPlot``. Its body is
    # presumably turned into a workflow, so keep the statement structure and
    # local variable names stable. Confirm before refactoring.

    # Accept the input as ``BandsData`` (presumably converted/validated
    # by ``accept_data`` since ``check=True``).
    bands_data = accept_data(bands_data, cls=BandsData, check=True)

    # Filter the bands
    filtered_bands = filter_bands(
        bands_data, Erange=Erange, E0=E0, bands_range=bands_range, spin=spin
    )

    # Add the styles
    styled_bands = style_bands(
        filtered_bands, bands_style=bands_style, spindown_style=spindown_style
    )

    # Process fatbands.
    # Groups without an explicit "color" get one from _default_random_color.
    orbital_manager = get_orbital_queries_manager(
        bands_data,
        key_gens={
            "color": _default_random_color,
        },
    )
    # Reduce the per-orbital data into one entry per requested group along a
    # new "group" dimension. ``spin_reduce=np.sum`` presumably sums over the
    # spin dimension; confirm in reduce_orbital_data.
    fatbands_data = reduce_orbital_data(
        filtered_bands,
        groups=groups,
        orb_dim="orb",
        spin_dim="spin",
        sanitize_group=orbital_manager,
        group_vars=("color", "dash"),
        groups_dim="group",
        drop_empty=True,
        spin_reduce=np.sum,
    )
    # Apply ``fatbands_scale`` to the width variable. ``default_value=1`` and
    # ``allow_not_present=True`` presumably cover the case where
    # ``fatbands_var`` is missing; confirm in scale_variable.
    scaled_fatbands_data = scale_variable(
        fatbands_data,
        var=fatbands_var,
        scale=fatbands_scale,
        default_value=1,
        allow_not_present=True,
    )

    # Determine what goes on each axis: energies ("E") go on ``E_axis`` and
    # k goes on the other one.
    x = "E" if E_axis == "x" else "k"
    y = "E" if E_axis == "y" else "k"

    # With no groups requested, ask draw_xarray_xy for what="none"
    # (presumably: draw nothing).
    # NOTE(review): ``groups == []`` only matches a list; an empty tuple
    # would keep ``fatbands_mode``.
    sanitized_fatbands_mode = "none" if groups == [] else fatbands_mode

    # Get the actions to plot lines
    fatbands_plottings = draw_xarray_xy(
        data=scaled_fatbands_data,
        x=x,
        y=y,
        color="color",
        width=fatbands_var,
        what=sanitized_fatbands_mode,
        dependent_axis=E_axis,
        name="group",
    )
    # One legend entry per group, even if it is drawn as several traces.
    grouped_fatbands_plottings = _group_traces(fatbands_plottings)

    bands_plottings = draw_xarray_xy(
        data=styled_bands,
        x=x,
        y=y,
        set_axrange=True,
        what=bands_mode,
        dependent_axis=E_axis,
    )

    # Gap calculation
    gap_info = calculate_gap(filtered_bands)

    # Plot it if the user has asked for it.
    # NOTE(review): as in bands_plot, draw_gaps gets the unfiltered
    # ``bands_data`` while the gap comes from ``filtered_bands``.
    gaps_plottings = draw_gaps(
        bands_data,
        gap,
        gap_info,
        gap_tol,
        gap_color,
        gap_marker,
        direct_gaps_only,
        custom_gaps,
        E_axis=E_axis,
    )

    # Fatbands actions come first, then band lines, then gaps (presumably so
    # the band lines are drawn on top of the fatbands).
    all_plottings = combined(
        grouped_fatbands_plottings,
        bands_plottings,
        gaps_plottings,
        composite_method=None,
    )

    return get_figure(backend=backend, plot_actions=all_plottings)
# ``Plot`` subclass that wraps the ``bands_plot`` function.
class BandsPlot(Plot):
    function = staticmethod(bands_plot)
# ``OrbitalGroupsPlot`` subclass that wraps the ``fatbands_plot`` function.
class FatbandsPlot(OrbitalGroupsPlot):
    function = staticmethod(fatbands_plot)