# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations
from typing import Callable, ChainMap, Literal, Optional, Sequence, Tuple, Union
from sisl._core import Geometry, Grid
from ..data import EigenstateData
from ..figure import Figure, get_figure
from ..plot import Plot
from ..plotters.cell import cell_plot_actions
from ..plotters.grid import draw_grid
from ..plotters.plot_actions import combined
from ..processors.axes import sanitize_axes
from ..processors.eigenstate import (
eigenstate_geometry,
get_eigenstate,
get_grid_nsc,
project_wavefunction,
tile_if_k,
)
from ..processors.grid import (
apply_transforms,
get_grid_axes,
get_grid_representation,
grid_geometry,
grid_to_dataarray,
interpolate_grid,
orthogonalize_grid_if_needed,
reduce_grid,
sub_grid,
tile_grid,
)
from ..types import Axes
from .geometry import geometry_plot
def _get_structure_plottings(
plot_geom,
geometry,
axes,
nsc,
geom_kwargs={},
):
if plot_geom:
geom_kwargs = ChainMap(
geom_kwargs,
{"axes": axes, "geometry": geometry, "nsc": nsc, "show_cell": False},
)
plot_actions = geometry_plot(**geom_kwargs).plot_actions
else:
plot_actions = []
return plot_actions
def grid_plot(
grid: Optional[Grid] = None,
axes: Axes = ["z"],
represent: Literal[
"real", "imag", "mod", "phase", "deg_phase", "rad_phase"
] = "real",
transforms: Sequence[Union[str, Callable]] = (),
reduce_method: Literal["average", "sum"] = "average",
boundary_mode: str = "grid-wrap",
nsc: Tuple[int, int, int] = (1, 1, 1),
interp: Tuple[int, int, int] = (1, 1, 1),
isos: Sequence[dict] = [],
smooth: bool = False,
colorscale: Optional[str] = None,
crange: Optional[Tuple[float, float]] = None,
cmid: Optional[float] = None,
show_cell: Literal["box", "axes", False] = "box",
cell_style: dict = {},
x_range: Optional[Sequence[float]] = None,
y_range: Optional[Sequence[float]] = None,
z_range: Optional[Sequence[float]] = None,
plot_geom: bool = False,
geom_kwargs: dict = {},
backend: str = "plotly",
) -> Figure:
"""Plots a grid, with plentiful of customization options.
Parameters
----------
grid:
The grid to plot.
axes:
The axes to project the grid to.
represent:
The representation of the grid to plot.
transforms:
List of transforms to apply to the grid before plotting.
reduce_method:
The method used to reduce the grid axes that are not displayed.
boundary_mode:
The method used to deal with the boundary conditions.
Only used if the grid is to be orthogonalized.
See scipy docs for more info on the possible values.
nsc:
The number of unit cells to display in each direction.
interp:
The interpolation factor to use for each axis to make the grid smoother.
isos:
List of isosurfaces or isocontours to plot. See the showcase notebooks for examples.
smooth:
Whether to ask the plotting backend to make an attempt at smoothing the grid display.
colorscale:
Colorscale to use for the grid display in the 2D representation.
If None, the default colorscale is used for each backend.
crange:
Min and max values for the colorscale.
cmid:
The value at which the colorscale is centered.
show_cell:
Method used to display the unit cell. If False, the cell is not displayed.
cell_style:
Style specification for the cell. See the showcase notebooks for examples.
x_range:
The range of the x axis to take into account.
Even if the X axis is not displayed! This is important because the reducing
operation will only be applied on this range.
y_range:
The range of the y axis to take into account.
Even if the Y axis is not displayed! This is important because the reducing
operation will only be applied on this range.
z_range:
The range of the z axis to take into account.
Even if the Z axis is not displayed! This is important because the reducing
operation will only be applied on this range.
plot_geom:
Whether to plot the associated geometry (if any).
geom_kwargs:
Keyword arguments to pass to the geometry plot of the associated geometry.
backend:
The backend to use to generate the figure.
See also
----------
scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed.
"""
axes = sanitize_axes(axes)
geometry = grid_geometry(grid, geometry=None)
grid_repr = get_grid_representation(grid, represent=represent)
tiled_grid = tile_grid(grid_repr, nsc=nsc)
ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode)
grid_axes = get_grid_axes(ort_grid, axes=axes)
transformed_grid = apply_transforms(ort_grid, transforms)
subbed_grid = sub_grid(
transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range
)
reduced_grid = reduce_grid(
subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes
)
interp_grid = interpolate_grid(reduced_grid, interp=interp)
# Finally, here comes the plotting!
grid_ds = grid_to_dataarray(interp_grid, axes=axes, grid_axes=grid_axes, nsc=nsc)
grid_plottings = draw_grid(
data=grid_ds,
isos=isos,
colorscale=colorscale,
crange=crange,
cmid=cmid,
smooth=smooth,
)
# Process the cell as well
cell_plottings = cell_plot_actions(
cell=grid,
show_cell=show_cell,
cell_style=cell_style,
axes=axes,
)
# And maybe plot the strucuture
geom_plottings = _get_structure_plottings(
plot_geom=plot_geom,
geometry=geometry,
geom_kwargs=geom_kwargs,
axes=axes,
nsc=nsc,
)
all_plottings = combined(
grid_plottings, cell_plottings, geom_plottings, composite_method=None
)
return get_figure(backend=backend, plot_actions=all_plottings)
def wavefunction_plot(
eigenstate: EigenstateData,
i: int = 0,
geometry: Optional[Geometry] = None,
grid_prec: float = 0.2,
# All grid inputs.
grid: Optional[Grid] = None,
axes: Axes = ["z"],
represent: Literal[
"real", "imag", "mod", "phase", "deg_phase", "rad_phase"
] = "real",
transforms: Sequence[Union[str, Callable]] = (),
reduce_method: Literal["average", "sum"] = "average",
boundary_mode: str = "grid-wrap",
nsc: Tuple[int, int, int] = (1, 1, 1),
interp: Tuple[int, int, int] = (1, 1, 1),
isos: Sequence[dict] = [],
smooth: bool = False,
colorscale: Optional[str] = None,
crange: Optional[Tuple[float, float]] = None,
cmid: Optional[float] = None,
show_cell: Literal["box", "axes", False] = "box",
cell_style: dict = {},
x_range: Optional[Sequence[float]] = None,
y_range: Optional[Sequence[float]] = None,
z_range: Optional[Sequence[float]] = None,
plot_geom: bool = False,
geom_kwargs: dict = {},
backend: str = "plotly",
) -> Figure:
"""Plots a wavefunction in real space.
Parameters
----------
eigenstate:
The eigenstate object containing information about eigenstates.
i:
The index of the eigenstate to plot.
geometry:
Geometry to use to project the eigenstate to real space.
If None, the geometry associated with the eigenstate is used.
grid_prec:
The precision of the grid where the wavefunction is projected.
grid:
The grid to plot.
axes:
The axes to project the grid to.
represent:
The representation of the grid to plot.
transforms:
List of transforms to apply to the grid before plotting.
reduce_method:
The method used to reduce the grid axes that are not displayed.
boundary_mode:
The method used to deal with the boundary conditions.
Only used if the grid is to be orthogonalized.
See scipy docs for more info on the possible values.
nsc:
The number of unit cells to display in each direction.
interp:
The interpolation factor to use for each axis to make the grid smoother.
isos:
List of isosurfaces or isocontours to plot. See the showcase notebooks for examples.
smooth:
Whether to ask the plotting backend to make an attempt at smoothing the grid display.
colorscale:
Colorscale to use for the grid display in the 2D representation.
If None, the default colorscale is used for each backend.
crange:
Min and max values for the colorscale.
cmid:
The value at which the colorscale is centered.
show_cell:
Method used to display the unit cell. If False, the cell is not displayed.
cell_style:
Style specification for the cell. See the showcase notebooks for examples.
x_range:
The range of the x axis to take into account.
Even if the X axis is not displayed! This is important because the reducing
operation will only be applied on this range.
y_range:
The range of the y axis to take into account.
Even if the Y axis is not displayed! This is important because the reducing
operation will only be applied on this range.
z_range:
The range of the z axis to take into account.
Even if the Z axis is not displayed! This is important because the reducing
operation will only be applied on this range.
plot_geom:
Whether to plot the associated geometry (if any).
geom_kwargs:
Keyword arguments to pass to the geometry plot of the associated geometry.
backend:
The backend to use to generate the figure.
See also
----------
scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed.
"""
# Create a grid with the wavefunction in it.
i_eigenstate = get_eigenstate(eigenstate, i)
geometry = eigenstate_geometry(eigenstate, geometry=geometry)
tiled_geometry = tile_if_k(geometry=geometry, nsc=nsc, eigenstate=i_eigenstate)
grid_nsc = get_grid_nsc(nsc=nsc, eigenstate=i_eigenstate)
grid = project_wavefunction(
eigenstate=i_eigenstate, grid_prec=grid_prec, grid=grid, geometry=tiled_geometry
)
# Grid processing
axes = sanitize_axes(axes)
grid_repr = get_grid_representation(grid, represent=represent)
tiled_grid = tile_grid(grid_repr, nsc=grid_nsc)
ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode)
grid_axes = get_grid_axes(ort_grid, axes=axes)
transformed_grid = apply_transforms(ort_grid, transforms)
subbed_grid = sub_grid(
transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range
)
reduced_grid = reduce_grid(
subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes
)
interp_grid = interpolate_grid(reduced_grid, interp=interp)
# Finally, here comes the plotting!
grid_ds = grid_to_dataarray(
interp_grid, axes=axes, grid_axes=grid_axes, nsc=grid_nsc
)
grid_plottings = draw_grid(
data=grid_ds,
isos=isos,
colorscale=colorscale,
crange=crange,
cmid=cmid,
smooth=smooth,
)
# Process the cell as well
cell_plottings = cell_plot_actions(
cell=grid,
show_cell=show_cell,
cell_style=cell_style,
axes=axes,
)
# And maybe plot the strucuture
geom_plottings = _get_structure_plottings(
plot_geom=plot_geom,
geometry=tiled_geometry,
geom_kwargs=geom_kwargs,
axes=axes,
nsc=grid_nsc,
)
all_plottings = combined(
grid_plottings, cell_plottings, geom_plottings, composite_method=None
)
return get_figure(backend=backend, plot_actions=all_plottings)
[docs]
class GridPlot(Plot):
function = staticmethod(grid_plot)
[docs]
class WavefunctionPlot(GridPlot):
function = staticmethod(wavefunction_plot)
# The following commented code is from the old viz module, where the GridPlot had a scan method.
# It looks very nice, but probably should be reimplemented as a standalone function that plots a grid slice,
# and then merge those grid slices to create a scan.
# def scan(self, along, start=None, stop=None, step=None, num=None, breakpoints=None, mode="moving_slice", animation_kwargs=None, **kwargs):
# """
# Returns an animation containing multiple frames scaning along an axis.
# Parameters
# -----------
# along: {"x", "y", "z"}
# the axis along which the scan is performed. If not provided, it will scan along the axes that are not displayed.
# start: float, optional
# the starting value for the scan (in Angstrom).
# Make sure this value is inside the range of the unit cell, otherwise it will fail.
# stop: float, optional
# the last value of the scan (in Angstrom).
# Make sure this value is inside the range of the unit cell, otherwise it will fail.
# step: float, optional
# the distance between steps in Angstrom.
# If not provided and `num` is also not provided, it will default to 1 Ang.
# num: int , optional
# the number of steps that you want the scan to consist of.
# If `step` is passed, this argument is ignored.
# Note that the grid is only stored once, so having a big number of steps is not that big of a deal.
# breakpoints: array-like, optional
# the discrete points of the scan. To be used if you don't want regular steps.
# If the last step is exactly the length of the cell, it will be moved one dcell back to avoid errors.
# Note that if this parameter is passed, both `step` and `num` are ignored.
# mode: {"moving_slice", "as_is"}, optional
# the type of scan you want to see.
# "moving_slice" renders a volumetric scan where a slice moves through the grid.
# "as_is" renders each part of the scan as an animation frame.
# (therefore, "as_is" SUPPORTS SCANNING 1D, 2D AND 3D REPRESENTATIONS OF THE GRID, e.g. display the volume data for different ranges of z)
# animation_kwargs: dict, optional
# dictionary whose keys and values are directly passed to the animated method as kwargs and therefore
# end up being passed to animation initialization.
# **kwargs:
# the rest of settings that you want to apply to overwrite the existing ones.
# This settings apply to each plot and go directly to their initialization.
# Returns
# -------
# sisl.viz.plotly.Animation
# An animation representation of the scan
# """
# # Do some checks on the args provided
# if sum(1 for arg in (step, num, breakpoints) if arg is not None) > 1:
# raise ValueError(f"Only one of ('step', 'num', 'breakpoints') should be passed.")
# axes = self.inputs['axes']
# if mode == "as_is" and set(axes) - set(["x", "y", "z"]):
# raise ValueError("To perform a scan, the axes need to be cartesian. Please set the axes to a combination of 'x', 'y' and 'z'.")
# if self.grid.lattice.is_cartesian():
# grid = self.grid
# else:
# transform_bc = kwargs.pop("transform_bc", self.get_setting("transform_bc"))
# grid, transform_offset = self._transform_grid_cell(
# self.grid, mode=transform_bc, output_shape=self.grid.shape, cval=np.nan
# )
# kwargs["offset"] = transform_offset + kwargs.get("offset", self.get_setting("offset"))
# # We get the key that needs to be animated (we will divide the full range in frames)
# range_key = f"{along}_range"
# along_i = {"x": 0, "y": 1, "z": 2}[along]
# # Get the full range
# if start is not None and stop is not None:
# along_range = [start, stop]
# else:
# along_range = self.get_setting(range_key)
# if along_range is None:
# range_param = self.get_param(range_key)
# along_range = [range_param[f"inputField.params.{lim}"] for lim in ["min", "max"]]
# if start is not None:
# along_range[0] = start
# if stop is not None:
# along_range[1] = stop
# if breakpoints is None:
# if step is None and num is None:
# step = 1.0
# if step is None:
# step = (along_range[1] - along_range[0]) / num
# else:
# num = (along_range[1] - along_range[0]) // step
# # np.linspace will use the last point as a step (and we don't want it)
# # therefore we will add an extra step
# breakpoints = np.linspace(*along_range, int(num) + 1)
# if breakpoints[-1] == grid.cell[along_i, along_i]:
# breakpoints[-1] = grid.cell[along_i, along_i] - grid.dcell[along_i, along_i]
# if mode == "moving_slice":
# return self._moving_slice_scan(grid, along_i, breakpoints)
# elif mode == "as_is":
# return self._asis_scan(grid, range_key, breakpoints, animation_kwargs=animation_kwargs, **kwargs)
# def _asis_scan(self, grid, range_key, breakpoints, animation_kwargs=None, **kwargs):
# """
# Returns an animation containing multiple frames scaning along an axis.
# Parameters
# -----------
# range_key: {'x_range', 'y_range', 'z_range'}
# the key of the setting that is to be animated through the scan.
# breakpoints: array-like
# the discrete points of the scan
# animation_kwargs: dict, optional
# dictionary whose keys and values are directly passed to the animated method as kwargs and therefore
# end up being passed to animation initialization.
# **kwargs:
# the rest of settings that you want to apply to overwrite the existing ones.
# This settings apply to each plot and go directly to their initialization.
# Returns
# ----------
# scan: sisl Animation
# An animation representation of the scan
# """
# # Generate the plot using self as a template so that plots don't need
# # to read data, just process it and show it differently.
# # (If each plot read the grid, the memory requirements would be HUGE)
# scan = self.animated(
# {
# range_key: [[bp, breakpoints[i+1]] for i, bp in enumerate(breakpoints[:-1])]
# },
# fixed={**{key: val for key, val in self.settings.items() if key != range_key}, **kwargs, "grid": grid},
# frame_names=[f'{bp:2f}' for bp in breakpoints],
# **(animation_kwargs or {})
# )
# # Set all frames to the same colorscale, if it's a 2d representation
# if len(self.get_setting("axes")) == 2:
# cmin = 10**6; cmax = -10**6
# for scan_im in scan:
# c = getattr(scan_im.data[0], "value", scan_im.data[0].z)
# cmin = min(cmin, np.min(c))
# cmax = max(cmax, np.max(c))
# for scan_im in scan:
# scan_im.update_settings(crange=[cmin, cmax])
# scan.get_figure()
# scan.layout = self.layout
# return scan
# def _moving_slice_scan(self, grid, along_i, breakpoints):
# import plotly.graph_objs as go
# ax = along_i
# displayed_axes = [i for i in range(3) if i != ax]
# shape = np.array(grid.shape)[displayed_axes]
# cmin = np.min(grid.grid)
# cmax = np.max(grid.grid)
# x_ax, y_ax = displayed_axes
# x = np.linspace(0, grid.cell[x_ax, x_ax], grid.shape[x_ax])
# y = np.linspace(0, grid.cell[y_ax, y_ax], grid.shape[y_ax])
# fig = go.Figure(frames=[go.Frame(data=go.Surface(
# x=x, y=y,
# z=(bp * np.ones(shape)).T,
# surfacecolor=np.squeeze(grid.cross_section(grid.index(bp, ax), ax).grid).T,
# cmin=cmin, cmax=cmax,
# ),
# name=f'{bp:.2f}'
# )
# for bp in breakpoints])
# # Add data to be displayed before animation starts
# fig.add_traces(fig.frames[0].data)
# def frame_args(duration):
# return {
# "frame": {"duration": duration},
# "mode": "immediate",
# "fromcurrent": True,
# "transition": {"duration": duration, "easing": "linear"},
# }
# sliders = [
# {
# "pad": {"b": 10, "t": 60},
# "len": 0.9,
# "x": 0.1,
# "y": 0,
# "steps": [
# {
# "args": [[f.name], frame_args(0)],
# "label": str(k),
# "method": "animate",
# }
# for k, f in enumerate(fig.frames)
# ],
# }
# ]
# def ax_title(ax): return f'{["X", "Y", "Z"][ax]} axis [Ang]'
# # Layout
# fig.update_layout(
# title=f'Grid scan along {["X", "Y", "Z"][ax]} axis',
# width=600,
# height=600,
# scene=dict(
# xaxis=dict(title=ax_title(x_ax)),
# yaxis=dict(title=ax_title(y_ax)),
# zaxis=dict(autorange=True, title=ax_title(ax)),
# aspectmode="data",
# ),
# updatemenus = [
# {
# "buttons": [
# {
# "args": [None, frame_args(50)],
# "label": "▶", # play symbol
# "method": "animate",
# },
# {
# "args": [[None], frame_args(0)],
# "label": "◼", # pause symbol
# "method": "animate",
# },
# ],
# "direction": "left",
# "pad": {"r": 10, "t": 70},
# "type": "buttons",
# "x": 0.1,
# "y": 0,
# }
# ],
# sliders=sliders
# )
# # We need to add an invisible trace so that the z axis stays with the correct range
# fig.add_trace({"type": "scatter3d", "mode": "markers", "marker_size": 0.001, "x": [0, 0], "y": [0, 0], "z": [0, grid.cell[ax, ax]]})
# return fig