# Source code for sisl.viz.plots.grid

from typing import Callable, ChainMap, Literal, Optional, Sequence, Tuple, Union

from sisl.geometry import Geometry
from sisl.grid import Grid

from ..data import EigenstateData
from ..figure import Figure, get_figure
from ..plot import Plot
from ..plotters.cell import cell_plot_actions
from ..plotters.grid import draw_grid
from ..plotters.plot_actions import combined
from ..processors.axes import sanitize_axes
from ..processors.eigenstate import (
    eigenstate_geometry,
    get_eigenstate,
    get_grid_nsc,
    project_wavefunction,
    tile_if_k,
)
from ..processors.grid import (
    apply_transforms,
    get_grid_axes,
    get_grid_representation,
    grid_geometry,
    grid_to_dataarray,
    interpolate_grid,
    orthogonalize_grid_if_needed,
    reduce_grid,
    sub_grid,
    tile_grid,
)
from ..types import Axes
from .geometry import geometry_plot


def _get_structure_plottings(
    plot_geom,
    geometry,
    axes,
    nsc,
    geom_kwargs=None,
):
    """Return the plot actions that draw the geometry, if requested.

    Parameters
    ----------
    plot_geom:
        Whether the geometry should be plotted. If falsy, no geometry plot
        is built and an empty list is returned.
    geometry:
        The geometry to plot.
    axes:
        The axes passed to the geometry plot.
    nsc:
        The number of unit cells passed to the geometry plot.
    geom_kwargs:
        Extra keyword arguments for ``geometry_plot``. Keys given here take
        precedence over the ``axes``, ``geometry``, ``nsc`` and ``show_cell``
        values set by this function. ``None`` (the default) means no extra
        arguments.

    Returns
    -------
    An empty list if ``plot_geom`` is falsy, otherwise the ``plot_actions``
    of the figure returned by ``geometry_plot``.
    """
    if not plot_geom:
        return []

    # Use None instead of a shared mutable ``{}`` default argument.
    if geom_kwargs is None:
        geom_kwargs = {}

    # ChainMap looks keys up in order, so user-provided kwargs override the
    # defaults below. ``show_cell`` is False because the grid plots in this
    # module draw the cell separately with ``cell_plot_actions``.
    merged_kwargs = ChainMap(
        geom_kwargs,
        {"axes": axes, "geometry": geometry, "nsc": nsc, "show_cell": False},
    )
    return geometry_plot(**merged_kwargs).plot_actions


def grid_plot(
    grid: Optional[Grid] = None,
    axes: Axes = ["z"],
    represent: Literal[
        "real", "imag", "mod", "phase", "deg_phase", "rad_phase"
    ] = "real",
    transforms: Sequence[Union[str, Callable]] = (),
    reduce_method: Literal["average", "sum"] = "average",
    boundary_mode: str = "grid-wrap",
    nsc: Tuple[int, int, int] = (1, 1, 1),
    interp: Tuple[int, int, int] = (1, 1, 1),
    isos: Sequence[dict] = [],
    smooth: bool = False,
    colorscale: Optional[str] = None,
    crange: Optional[Tuple[float, float]] = None,
    cmid: Optional[float] = None,
    show_cell: Literal["box", "axes", False] = "box",
    cell_style: dict = {},
    x_range: Optional[Sequence[float]] = None,
    y_range: Optional[Sequence[float]] = None,
    z_range: Optional[Sequence[float]] = None,
    plot_geom: bool = False,
    geom_kwargs: dict = {},
    backend: str = "plotly",
) -> Figure:
    """Plots a grid, with plenty of customization options.

    Before being drawn, the grid goes through this pipeline, in order:
    selection of the representation (``represent``), tiling (``nsc``),
    orthogonalization if needed (``boundary_mode``), ``transforms``,
    restriction to ``x_range``/``y_range``/``z_range``, reduction of the
    non-displayed axes (``reduce_method``) and interpolation (``interp``).
    The unit cell and, optionally, the grid's geometry are drawn together
    with the grid.

    Parameters
    ----------
    grid:
        The grid to plot. Although annotated as Optional, it is passed
        directly to the processing helpers; presumably a Grid must be
        provided (TODO confirm).
    axes:
        The axes to project the grid to.
    represent:
        The representation of the grid to plot.
    transforms:
        List of transforms to apply to the grid before plotting.
    reduce_method:
        The method used to reduce the grid axes that are not displayed.
    boundary_mode:
        The method used to deal with the boundary conditions.
        Only used if the grid is to be orthogonalized.
        See scipy docs for more info on the possible values.
    nsc:
        The number of unit cells to display in each direction.
    interp:
        The interpolation factor to use for each axis to make the grid smoother.
    isos:
        List of isosurfaces or isocontours to plot. See the showcase notebooks for examples.
    smooth:
        Whether to ask the plotting backend to make an attempt at smoothing the grid display.
    colorscale:
        Colorscale to use for the grid display in the 2D representation.
        If None, the default colorscale is used for each backend.
    crange:
        Min and max values for the colorscale.
    cmid:
        The value at which the colorscale is centered.
    show_cell:
        Method used to display the unit cell. If False, the cell is not displayed.
    cell_style:
        Style specification for the cell. See the showcase notebooks for examples.
    x_range:
        The range of the x axis to take into account.
        Even if the X axis is not displayed! This is important because the reducing
        operation will only be applied on this range.
    y_range:
        The range of the y axis to take into account.
        Even if the Y axis is not displayed! This is important because the reducing
        operation will only be applied on this range.
    z_range:
        The range of the z axis to take into account.
        Even if the Z axis is not displayed! This is important because the reducing
        operation will only be applied on this range.
    plot_geom:
        Whether to plot the associated geometry (if any).
    geom_kwargs:
        Keyword arguments to pass to the geometry plot of the associated geometry.
        They take precedence over the values this function sets.
    backend:
        The backend to use to generate the figure.

    See also
    --------
    scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed.
    """
    # NOTE(review): this function is used as ``GridPlot.function`` (a ``Plot``
    # subclass). If ``Plot`` builds its workflow by inspecting this body, the
    # intermediate variable names below may be exposed publicly -- keep them
    # stable unless confirmed otherwise.
    # NOTE(review): ``axes``, ``isos``, ``cell_style`` and ``geom_kwargs`` have
    # mutable defaults. Nothing in this body mutates them; confirm the helpers
    # they are passed to don't either.

    axes = sanitize_axes(axes)

    # Geometry associated with the grid; it is only drawn if ``plot_geom`` is True.
    geometry = grid_geometry(grid, geometry=None)

    # Pick the requested representation (real, imag, mod, phase...) of the values.
    grid_repr = get_grid_representation(grid, represent=represent)

    # Tile the grid according to ``nsc`` (unit cells per direction).
    tiled_grid = tile_grid(grid_repr, nsc=nsc)

    # Orthogonalize the grid if needed (the helper decides, based on the
    # requested axes); ``boundary_mode`` only matters in that case.
    ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode)

    # Grid axes that correspond to the displayed axes; they are the ones
    # kept when reducing below.
    grid_axes = get_grid_axes(ort_grid, axes=axes)

    transformed_grid = apply_transforms(ort_grid, transforms)

    # Restrict to the requested ranges *before* reducing, so the reduction
    # only takes into account values inside these ranges.
    subbed_grid = sub_grid(
        transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range
    )

    # Reduce (average or sum) every axis that is not displayed.
    reduced_grid = reduce_grid(
        subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes
    )

    interp_grid = interpolate_grid(reduced_grid, interp=interp)

    # Finally, here comes the plotting!
    grid_ds = grid_to_dataarray(interp_grid, axes=axes, grid_axes=grid_axes, nsc=nsc)
    grid_plottings = draw_grid(
        data=grid_ds,
        isos=isos,
        colorscale=colorscale,
        crange=crange,
        cmid=cmid,
        smooth=smooth,
    )

    # Process the cell as well
    cell_plottings = cell_plot_actions(
        cell=grid,
        show_cell=show_cell,
        cell_style=cell_style,
        axes=axes,
    )

    # And maybe plot the structure (returns no actions if plot_geom is False)
    geom_plottings = _get_structure_plottings(
        plot_geom=plot_geom,
        geometry=geometry,
        geom_kwargs=geom_kwargs,
        axes=axes,
        nsc=nsc,
    )

    # Merge grid, cell and geometry actions into a single figure.
    all_plottings = combined(
        grid_plottings, cell_plottings, geom_plottings, composite_method=None
    )

    return get_figure(backend=backend, plot_actions=all_plottings)


def wavefunction_plot(
    eigenstate: EigenstateData,
    i: int = 0,
    geometry: Optional[Geometry] = None,
    grid_prec: float = 0.2,
    # All grid inputs.
    grid: Optional[Grid] = None,
    axes: Axes = ["z"],
    represent: Literal[
        "real", "imag", "mod", "phase", "deg_phase", "rad_phase"
    ] = "real",
    transforms: Sequence[Union[str, Callable]] = (),
    reduce_method: Literal["average", "sum"] = "average",
    boundary_mode: str = "grid-wrap",
    nsc: Tuple[int, int, int] = (1, 1, 1),
    interp: Tuple[int, int, int] = (1, 1, 1),
    isos: Sequence[dict] = [],
    smooth: bool = False,
    colorscale: Optional[str] = None,
    crange: Optional[Tuple[float, float]] = None,
    cmid: Optional[float] = None,
    show_cell: Literal["box", "axes", False] = "box",
    cell_style: dict = {},
    x_range: Optional[Sequence[float]] = None,
    y_range: Optional[Sequence[float]] = None,
    z_range: Optional[Sequence[float]] = None,
    plot_geom: bool = False,
    geom_kwargs: dict = {},
    backend: str = "plotly",
) -> Figure:
    """Plots a wavefunction in real space.

    The selected eigenstate is first projected onto a real-space grid. That
    grid then goes through the same processing pipeline as in ``grid_plot``
    (representation, tiling, orthogonalization if needed, transforms, range
    restriction, reduction and interpolation) before being drawn.

    Parameters
    ----------
    eigenstate:
        The eigenstate object containing information about eigenstates.
    i:
        The index of the eigenstate to plot.
    geometry:
        Geometry to use to project the eigenstate to real space.
        If None, the geometry associated with the eigenstate is used.
    grid_prec:
        The precision of the grid where the wavefunction is projected.
    grid:
        Grid passed to the wavefunction projection. Presumably the grid onto
        which the wavefunction is projected, with one built from
        ``grid_prec`` when None (TODO confirm in ``project_wavefunction``).
    axes:
        The axes to project the grid to.
    represent:
        The representation of the grid to plot.
    transforms:
        List of transforms to apply to the grid before plotting.
    reduce_method:
        The method used to reduce the grid axes that are not displayed.
    boundary_mode:
        The method used to deal with the boundary conditions.
        Only used if the grid is to be orthogonalized.
        See scipy docs for more info on the possible values.
    nsc:
        The number of unit cells to display in each direction.
        Depending on the eigenstate, the tiling may be applied to the
        geometry before projecting and/or to the projected grid.
    interp:
        The interpolation factor to use for each axis to make the grid smoother.
    isos:
        List of isosurfaces or isocontours to plot. See the showcase notebooks for examples.
    smooth:
        Whether to ask the plotting backend to make an attempt at smoothing the grid display.
    colorscale:
        Colorscale to use for the grid display in the 2D representation.
        If None, the default colorscale is used for each backend.
    crange:
        Min and max values for the colorscale.
    cmid:
        The value at which the colorscale is centered.
    show_cell:
        Method used to display the unit cell. If False, the cell is not displayed.
    cell_style:
        Style specification for the cell. See the showcase notebooks for examples.
    x_range:
        The range of the x axis to take into account.
        Even if the X axis is not displayed! This is important because the reducing
        operation will only be applied on this range.
    y_range:
        The range of the y axis to take into account.
        Even if the Y axis is not displayed! This is important because the reducing
        operation will only be applied on this range.
    z_range:
        The range of the z axis to take into account.
        Even if the Z axis is not displayed! This is important because the reducing
        operation will only be applied on this range.
    plot_geom:
        Whether to plot the associated geometry (if any).
        The geometry used for the projection (possibly tiled) is the one plotted.
    geom_kwargs:
        Keyword arguments to pass to the geometry plot of the associated geometry.
        They take precedence over the values this function sets.
    backend:
        The backend to use to generate the figure.

    See also
    --------
    scipy.ndimage.affine_transform : method used to orthogonalize the grid if needed.
    """
    # NOTE(review): this function is used as ``WavefunctionPlot.function`` (a
    # ``Plot`` subclass). If ``Plot`` builds its workflow by inspecting this
    # body, the intermediate variable names below may be exposed publicly --
    # keep them stable unless confirmed otherwise.

    # Create a grid with the wavefunction in it.
    i_eigenstate = get_eigenstate(eigenstate, i)
    geometry = eigenstate_geometry(eigenstate, geometry=geometry)

    # The requested ``nsc`` is split in two: ``tile_if_k`` may tile the
    # geometry used for the projection, and ``get_grid_nsc`` returns the
    # tiling that is still to be applied to the projected grid.
    # NOTE(review): presumably this depends on whether the eigenstate has a
    # non-Gamma k-point (from the helper names) -- confirm.
    tiled_geometry = tile_if_k(geometry=geometry, nsc=nsc, eigenstate=i_eigenstate)
    grid_nsc = get_grid_nsc(nsc=nsc, eigenstate=i_eigenstate)
    grid = project_wavefunction(
        eigenstate=i_eigenstate, grid_prec=grid_prec, grid=grid, geometry=tiled_geometry
    )

    # Grid processing (same pipeline as grid_plot, but tiling with grid_nsc)
    axes = sanitize_axes(axes)

    grid_repr = get_grid_representation(grid, represent=represent)

    tiled_grid = tile_grid(grid_repr, nsc=grid_nsc)

    ort_grid = orthogonalize_grid_if_needed(tiled_grid, axes=axes, mode=boundary_mode)

    # Grid axes that correspond to the displayed axes; kept when reducing.
    grid_axes = get_grid_axes(ort_grid, axes=axes)

    transformed_grid = apply_transforms(ort_grid, transforms)

    # Restrict to the requested ranges before reducing the hidden axes.
    subbed_grid = sub_grid(
        transformed_grid, x_range=x_range, y_range=y_range, z_range=z_range
    )

    reduced_grid = reduce_grid(
        subbed_grid, reduce_method=reduce_method, keep_axes=grid_axes
    )

    interp_grid = interpolate_grid(reduced_grid, interp=interp)

    # Finally, here comes the plotting!
    grid_ds = grid_to_dataarray(
        interp_grid, axes=axes, grid_axes=grid_axes, nsc=grid_nsc
    )
    grid_plottings = draw_grid(
        data=grid_ds,
        isos=isos,
        colorscale=colorscale,
        crange=crange,
        cmid=cmid,
        smooth=smooth,
    )

    # Process the cell as well (cell of the projected grid)
    cell_plottings = cell_plot_actions(
        cell=grid,
        show_cell=show_cell,
        cell_style=cell_style,
        axes=axes,
    )

    # And maybe plot the structure: the (possibly tiled) geometry used for the
    # projection, repeated with the same nsc as the grid.
    geom_plottings = _get_structure_plottings(
        plot_geom=plot_geom,
        geometry=tiled_geometry,
        geom_kwargs=geom_kwargs,
        axes=axes,
        nsc=grid_nsc,
    )

    # Merge grid, cell and geometry actions into a single figure.
    all_plottings = combined(
        grid_plottings, cell_plottings, geom_plottings, composite_method=None
    )

    return get_figure(backend=backend, plot_actions=all_plottings)


# Plot class whose computation is ``grid_plot``.
# (A ``[docs]`` Sphinx-viewcode residue that made this line a SyntaxError was removed.)
class GridPlot(Plot):
    function = staticmethod(grid_plot)
# Same as GridPlot, but its computation is ``wavefunction_plot``, which first
# projects an eigenstate onto a real-space grid.
# (A ``[docs]`` Sphinx-viewcode residue that made this line a SyntaxError was removed.)
class WavefunctionPlot(GridPlot):
    function = staticmethod(wavefunction_plot)
# The following commented code is from the old viz module, where the GridPlot had a scan method.
# It looks very nice, but it should probably be reimplemented as a standalone function that plots
# a grid slice; those grid slices could then be merged to create a scan.

# def scan(self, along, start=None, stop=None, step=None, num=None, breakpoints=None, mode="moving_slice", animation_kwargs=None, **kwargs):
#     """
#     Returns an animation containing multiple frames scanning along an axis.
#
#     Parameters
#     -----------
#     along: {"x", "y", "z"}
#         the axis along which the scan is performed. If not provided, it will scan along the axes that are not displayed.
#     start: float, optional
#         the starting value for the scan (in Angstrom).
#         Make sure this value is inside the range of the unit cell, otherwise it will fail.
#     stop: float, optional
#         the last value of the scan (in Angstrom).
#         Make sure this value is inside the range of the unit cell, otherwise it will fail.
#     step: float, optional
#         the distance between steps in Angstrom.
#         If not provided and `num` is also not provided, it will default to 1 Ang.
#     num: int, optional
#         the number of steps that you want the scan to consist of.
#         If `step` is passed, this argument is ignored.
#         Note that the grid is only stored once, so having a big number of steps is not that big of a deal.
#     breakpoints: array-like, optional
#         the discrete points of the scan. To be used if you don't want regular steps.
#         If the last step is exactly the length of the cell, it will be moved one dcell back to avoid errors.
#
#         Note that if this parameter is passed, both `step` and `num` are ignored.
#     mode: {"moving_slice", "as_is"}, optional
#         the type of scan you want to see.
#         "moving_slice" renders a volumetric scan where a slice moves through the grid.
#         "as_is" renders each part of the scan as an animation frame.
#         (therefore, "as_is" SUPPORTS SCANNING 1D, 2D AND 3D REPRESENTATIONS OF THE GRID, e.g. display the volume data for different ranges of z)
#     animation_kwargs: dict, optional
#         dictionary whose keys and values are directly passed to the animated method as kwargs and therefore
#         end up being passed to animation initialization.
#     **kwargs:
#         the rest of settings that you want to apply to overwrite the existing ones.
#
#         These settings apply to each plot and go directly to their initialization.
#
#     Returns
#     -------
#     sisl.viz.plotly.Animation
#         An animation representation of the scan
#     """
#     # Do some checks on the args provided
#     if sum(1 for arg in (step, num, breakpoints) if arg is not None) > 1:
#         raise ValueError(f"Only one of ('step', 'num', 'breakpoints') should be passed.")
#
#     axes = self.inputs['axes']
#     if mode == "as_is" and set(axes) - set(["x", "y", "z"]):
#         raise ValueError("To perform a scan, the axes need to be cartesian. Please set the axes to a combination of 'x', 'y' and 'z'.")
#
#     if self.grid.lattice.is_cartesian():
#         grid = self.grid
#     else:
#         transform_bc = kwargs.pop("transform_bc", self.get_setting("transform_bc"))
#         grid, transform_offset = self._transform_grid_cell(
#             self.grid, mode=transform_bc, output_shape=self.grid.shape, cval=np.nan
#         )
#
#         kwargs["offset"] = transform_offset + kwargs.get("offset", self.get_setting("offset"))
#
#     # We get the key that needs to be animated (we will divide the full range in frames)
#     range_key = f"{along}_range"
#     along_i = {"x": 0, "y": 1, "z": 2}[along]
#
#     # Get the full range
#     if start is not None and stop is not None:
#         along_range = [start, stop]
#     else:
#         along_range = self.get_setting(range_key)
#         if along_range is None:
#             range_param = self.get_param(range_key)
#             along_range = [range_param[f"inputField.params.{lim}"] for lim in ["min", "max"]]
#         if start is not None:
#             along_range[0] = start
#         if stop is not None:
#             along_range[1] = stop
#
#     if breakpoints is None:
#         if step is None and num is None:
#             step = 1.0
#         if step is None:
#             step = (along_range[1] - along_range[0]) / num
#         else:
#             num = (along_range[1] - along_range[0]) // step
#
#         # np.linspace will use the last point as a step (and we don't want it)
#         # therefore we will add an extra step
#         breakpoints = np.linspace(*along_range, int(num) + 1)
#
#     if breakpoints[-1] == grid.cell[along_i, along_i]:
#         breakpoints[-1] = grid.cell[along_i, along_i] - grid.dcell[along_i, along_i]
#
#     if mode == "moving_slice":
#         return self._moving_slice_scan(grid, along_i, breakpoints)
#     elif mode == "as_is":
#         return self._asis_scan(grid, range_key, breakpoints, animation_kwargs=animation_kwargs, **kwargs)

# def _asis_scan(self, grid, range_key, breakpoints, animation_kwargs=None, **kwargs):
#     """
#     Returns an animation containing multiple frames scanning along an axis.
#
#     Parameters
#     -----------
#     range_key: {'x_range', 'y_range', 'z_range'}
#         the key of the setting that is to be animated through the scan.
#     breakpoints: array-like
#         the discrete points of the scan
#     animation_kwargs: dict, optional
#         dictionary whose keys and values are directly passed to the animated method as kwargs and therefore
#         end up being passed to animation initialization.
#     **kwargs:
#         the rest of settings that you want to apply to overwrite the existing ones.
#
#         These settings apply to each plot and go directly to their initialization.
#
#     Returns
#     ----------
#     scan: sisl Animation
#         An animation representation of the scan
#     """
#     # Generate the plot using self as a template so that plots don't need
#     # to read data, just process it and show it differently.
#     # (If each plot read the grid, the memory requirements would be HUGE)
#     scan = self.animated(
#         {
#             range_key: [[bp, breakpoints[i+1]] for i, bp in enumerate(breakpoints[:-1])]
#         },
#         fixed={**{key: val for key, val in self.settings.items() if key != range_key}, **kwargs, "grid": grid},
#         frame_names=[f'{bp:2f}' for bp in breakpoints],
#         **(animation_kwargs or {})
#     )
#
#     # Set all frames to the same colorscale, if it's a 2d representation
#     if len(self.get_setting("axes")) == 2:
#         cmin = 10**6; cmax = -10**6
#         for scan_im in scan:
#             c = getattr(scan_im.data[0], "value", scan_im.data[0].z)
#             cmin = min(cmin, np.min(c))
#             cmax = max(cmax, np.max(c))
#         for scan_im in scan:
#             scan_im.update_settings(crange=[cmin, cmax])
#
#     scan.get_figure()
#
#     scan.layout = self.layout
#
#     return scan

# def _moving_slice_scan(self, grid, along_i, breakpoints):
#     import plotly.graph_objs as go
#     ax = along_i
#     displayed_axes = [i for i in range(3) if i != ax]
#     shape = np.array(grid.shape)[displayed_axes]
#     cmin = np.min(grid.grid)
#     cmax = np.max(grid.grid)
#     x_ax, y_ax = displayed_axes
#     x = np.linspace(0, grid.cell[x_ax, x_ax], grid.shape[x_ax])
#     y = np.linspace(0, grid.cell[y_ax, y_ax], grid.shape[y_ax])
#
#     fig = go.Figure(frames=[go.Frame(data=go.Surface(
#         x=x, y=y,
#         z=(bp * np.ones(shape)).T,
#         surfacecolor=np.squeeze(grid.cross_section(grid.index(bp, ax), ax).grid).T,
#         cmin=cmin, cmax=cmax,
#     ),
#         name=f'{bp:.2f}'
#     )
#         for bp in breakpoints])
#
#     # Add data to be displayed before animation starts
#     fig.add_traces(fig.frames[0].data)
#
#     def frame_args(duration):
#         return {
#             "frame": {"duration": duration},
#             "mode": "immediate",
#             "fromcurrent": True,
#             "transition": {"duration": duration, "easing": "linear"},
#         }
#
#     sliders = [
#         {
#             "pad": {"b": 10, "t": 60},
#             "len": 0.9,
#             "x": 0.1,
#             "y": 0,
#             "steps": [
#                 {
#                     "args": [[f.name], frame_args(0)],
#                     "label": str(k),
#                     "method": "animate",
#                 }
#                 for k, f in enumerate(fig.frames)
#             ],
#         }
#     ]
#
#     def ax_title(ax): return f'{["X", "Y", "Z"][ax]} axis [Ang]'
#
#     # Layout
#     fig.update_layout(
#         title=f'Grid scan along {["X", "Y", "Z"][ax]} axis',
#         width=600,
#         height=600,
#         scene=dict(
#             xaxis=dict(title=ax_title(x_ax)),
#             yaxis=dict(title=ax_title(y_ax)),
#             zaxis=dict(autorange=True, title=ax_title(ax)),
#             aspectmode="data",
#         ),
#         updatemenus = [
#             {
#                 "buttons": [
#                     {
#                         "args": [None, frame_args(50)],
#                         "label": "▶",  # play symbol
#                         "method": "animate",
#                     },
#                     {
#                         "args": [[None], frame_args(0)],
#                         "label": "◼",  # pause symbol
#                         "method": "animate",
#                     },
#                 ],
#                 "direction": "left",
#                 "pad": {"r": 10, "t": 70},
#                 "type": "buttons",
#                 "x": 0.1,
#                 "y": 0,
#             }
#         ],
#         sliders=sliders
#     )
#
#     # We need to add an invisible trace so that the z axis stays with the correct range
#     fig.add_trace({"type": "scatter3d", "mode": "markers", "marker_size": 0.001, "x": [0, 0], "y": [0, 0], "z": [0, grid.cell[ax, ax]]})
#
#     return fig