from typing import Callable, Literal, Optional, Sequence, Tuple, TypeVar, Union
import numpy as np
from sisl import BrillouinZone, Geometry
from sisl.typing import AtomsArgument
from sisl.viz.figure import Figure, get_figure
from sisl.viz.plotters import plot_actions as plot_actions
from sisl.viz.types import AtomArrowSpec, AtomsStyleSpec, Axes, StyleSpec
from ..plot import Plot
from ..plotters.cell import cell_plot_actions, get_ndim, get_z
from ..plotters.xarray import draw_xarray_xy
from ..processors.axes import sanitize_axes
from ..processors.coords import project_to_axes
from ..processors.geometry import (
add_xyz_to_bonds_dataset,
add_xyz_to_dataset,
bonds_to_lines,
find_all_bonds,
get_sites_units,
parse_atoms_style,
sanitize_arrows,
sanitize_atoms,
sanitize_bonds_selection,
sites_obj_to_geometry,
stack_sc_data,
style_bonds,
tile_data_sc,
)
from ..processors.logic import matches, switch
from ..processors.xarray import scale_variable, select
def _get_atom_mode(drawing_mode, ndim):
if drawing_mode is None:
if ndim == 3:
return "balls"
else:
return "scatter"
return drawing_mode
def _get_arrow_plottings(atoms_data, arrows, nsc=[1, 1, 1]):
reps = np.prod(nsc)
actions = []
atoms_data = atoms_data.unstack("sc_atom")
for arrows_spec in arrows:
filtered = atoms_data.sel(atom=arrows_spec["atoms"])
dxy = arrows_spec["data"][arrows_spec["atoms"]]
dxy = np.tile(np.ravel(dxy), reps).reshape(-1, arrows_spec["data"].shape[-1])
# If it is a 1D plot, make sure that the arrows have two coordinates, being 0 the second one.
if dxy.shape[-1] == 1:
dxy = np.array([dxy[:, 0], np.zeros_like(dxy[:, 0])]).T
kwargs = {}
kwargs["line"] = {
"color": arrows_spec["color"],
"width": arrows_spec["width"],
"opacity": arrows_spec.get("opacity", 1),
}
kwargs["name"] = arrows_spec["name"]
kwargs["arrowhead_scale"] = arrows_spec["arrowhead_scale"]
kwargs["arrowhead_angle"] = arrows_spec["arrowhead_angle"]
kwargs["annotate"] = arrows_spec.get("annotate", False)
kwargs["scale"] = arrows_spec["scale"]
if dxy.shape[-1] < 3:
action = plot_actions.draw_arrows(
x=np.ravel(filtered.x), y=np.ravel(filtered.y), dxy=dxy, **kwargs
)
else:
action = plot_actions.draw_arrows_3D(
x=np.ravel(filtered.x),
y=np.ravel(filtered.y),
z=np.ravel(filtered.z),
dxyz=dxy,
**kwargs,
)
actions.append(action)
return actions
def _sanitize_scale(
scale: float, ndim: int, ndim_scale: Tuple[float, float, float] = (16, 16, 1)
):
return ndim_scale[ndim - 1] * scale
def geometry_plot(
geometry: Geometry,
axes: Axes = ["x", "y", "z"],
atoms: AtomsArgument = None,
atoms_style: Sequence[AtomsStyleSpec] = [],
atoms_scale: float = 1.0,
atoms_colorscale: Optional[str] = None,
drawing_mode: Literal["scatter", "balls", None] = None,
bind_bonds_to_ats: bool = True,
points_per_bond: int = 20,
bonds_style: StyleSpec = {},
bonds_scale: float = 1.0,
bonds_colorscale: Optional[str] = None,
show_atoms: bool = True,
show_bonds: bool = True,
show_cell: Literal["box", "axes", False] = "box",
cell_style: StyleSpec = {},
nsc: Tuple[int, int, int] = (1, 1, 1),
atoms_ndim_scale: Tuple[float, float, float] = (16, 16, 1),
bonds_ndim_scale: Tuple[float, float, float] = (1, 1, 10),
dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None,
arrows: Sequence[AtomArrowSpec] = (),
backend="plotly",
) -> Figure:
"""Plots a geometry structure, with plentiful of customization options.
Parameters
----------
geometry:
The geometry to plot.
axes:
The axes to project the geometry to.
atoms:
The atoms to plot. If None, all atoms are plotted.
atoms_style:
List of style specifications for the atoms. See the showcase notebooks for examples.
atoms_scale:
Scaling factor for the size of all atoms.
atoms_colorscale:
Colorscale to use for the atoms in case the color attribute is an array of values.
If None, the default colorscale is used for each backend.
drawing_mode:
The method used to draw the atoms.
bind_bonds_to_ats:
Whether to display only bonds between atoms that are being displayed.
points_per_bond:
When the points are drawn using points instead of lines (e.g. in some frameworks
to draw multicolor bonds), the number of points used per bond.
bonds_style:
Style specification for the bonds. See the showcase notebooks for examples.
bonds_scale:
Scaling factor for the width of all bonds.
bonds_colorscale:
Colorscale to use for the bonds in case the color attribute is an array of values.
If None, the default colorscale is used for each backend.
show_atoms:
Whether to display the atoms.
show_bonds:
Whether to display the bonds.
show_cell:
Mode to display the cell. If False, the cell is not displayed.
cell_style:
Style specification for the cell. See the showcase notebooks for examples.
nsc:
Number of unit cells to display in each direction.
atoms_ndim_scale:
Scaling factor for the size of the atoms for different dimensionalities (1D, 2D, 3D).
bonds_ndim_scale:
Scaling factor for the width of the bonds for different dimensionalities (1D, 2D, 3D).
dataaxis_1d:
Only meaningful for 1D plots. The data to plot on the Y axis.
arrows:
List of arrow specifications to display. See the showcase notebooks for examples.
backend:
The backend to use to generate the figure.
"""
# INPUTS ARE NOT GETTING PARSED BECAUSE WORKFLOWS RUN GET ON FINAL NODE
# SO PARSING IS DELEGATED TO NODES.
axes = sanitize_axes(axes)
sanitized_atoms = sanitize_atoms(geometry, atoms=atoms)
ndim = get_ndim(axes)
z = get_z(ndim)
# Atoms and bonds are processed in parallel paths, which means that one is able
# to update without requiring the other. This means: 1) Faster updates if only one
# of them needs to update; 2) It should be possible to run each path in a different
# thread/process, potentially increasing speed.
parsed_atom_style = parse_atoms_style(geometry, atoms_style=atoms_style)
atoms_dataset = add_xyz_to_dataset(parsed_atom_style)
atoms_filter = switch(show_atoms, sanitized_atoms, [])
filtered_atoms = select(atoms_dataset, "atom", atoms_filter)
tiled_atoms = tile_data_sc(filtered_atoms, nsc=nsc)
sc_atoms = stack_sc_data(tiled_atoms, newname="sc_atom", dims=["atom"])
projected_atoms = project_to_axes(
sc_atoms, axes=axes, sort_by_depth=True, dataaxis_1d=dataaxis_1d
)
atoms_scale = _sanitize_scale(atoms_scale, ndim, atoms_ndim_scale)
final_atoms = scale_variable(projected_atoms, "size", scale=atoms_scale)
atom_mode = _get_atom_mode(drawing_mode, ndim)
atom_plottings = draw_xarray_xy(
data=final_atoms,
x="x",
y="y",
z=z,
width="size",
what=atom_mode,
colorscale=atoms_colorscale,
set_axequal=True,
name="Atoms",
)
# Here we start to process bonds
bonds = find_all_bonds(geometry)
show_bonds = matches(ndim, 1, False, show_bonds)
styled_bonds = style_bonds(bonds, bonds_style)
bonds_dataset = add_xyz_to_bonds_dataset(styled_bonds)
bonds_filter = sanitize_bonds_selection(
bonds_dataset, sanitized_atoms, bind_bonds_to_ats, show_bonds
)
filtered_bonds = select(bonds_dataset, "bond_index", bonds_filter)
tiled_bonds = tile_data_sc(filtered_bonds, nsc=nsc)
projected_bonds = project_to_axes(tiled_bonds, axes=axes)
bond_lines = bonds_to_lines(projected_bonds, points_per_bond=points_per_bond)
bonds_scale = _sanitize_scale(bonds_scale, ndim, bonds_ndim_scale)
final_bonds = scale_variable(bond_lines, "width", scale=bonds_scale)
bond_plottings = draw_xarray_xy(
data=final_bonds,
x="x",
y="y",
z=z,
set_axequal=True,
name="Bonds",
colorscale=bonds_colorscale,
)
# And now the cell
show_cell = matches(ndim, 1, False, show_cell)
cell_plottings = cell_plot_actions(
cell=geometry,
show_cell=show_cell,
cell_style=cell_style,
axes=axes,
dataaxis_1d=dataaxis_1d,
)
# And the arrows
arrow_data = sanitize_arrows(
geometry, arrows, atoms=sanitized_atoms, ndim=ndim, axes=axes
)
arrow_plottings = _get_arrow_plottings(projected_atoms, arrow_data, nsc=nsc)
all_actions = plot_actions.combined(
bond_plottings,
atom_plottings,
cell_plottings,
arrow_plottings,
composite_method=None,
)
return get_figure(backend=backend, plot_actions=all_actions)
[docs]class GeometryPlot(Plot):
function = staticmethod(geometry_plot)
@property
def geometry(self):
return self.nodes.inputs["geometry"]._output
_T = TypeVar("_T", list, tuple, dict)
def _sites_specs_to_atoms_specs(sites_specs: _T) -> _T:
if isinstance(sites_specs, dict):
if "sites" in sites_specs:
sites_specs = sites_specs.copy()
sites_specs["atoms"] = sites_specs.pop("sites")
return sites_specs
else:
return type(sites_specs)(
_sites_specs_to_atoms_specs(style_spec) for style_spec in sites_specs
)
def sites_plot(
sites_obj: BrillouinZone,
axes: Axes = ["x", "y", "z"],
sites: AtomsArgument = None,
sites_style: Sequence[AtomsStyleSpec] = [],
sites_scale: float = 1.0,
sites_name: str = "Sites",
sites_colorscale: Optional[str] = None,
drawing_mode: Literal["scatter", "balls", "line", None] = None,
show_cell: Literal["box", "axes", False] = False,
cell_style: StyleSpec = {},
nsc: Tuple[int, int, int] = (1, 1, 1),
sites_ndim_scale: Tuple[float, float, float] = (1, 1, 1),
dataaxis_1d: Optional[Union[np.ndarray, Callable]] = None,
arrows: Sequence[AtomArrowSpec] = (),
backend="plotly",
) -> Figure:
"""Plots sites from an object that can be parsed into a geometry.
The only differences between this plot and a geometry plot is the naming of the inputs
and the fact that there are no options to plot bonds.
Parameters
----------
sites_obj:
The object to be converted to sites.
axes:
The axes to project the sites to.
sites:
The sites to plot. If None, all sites are plotted.
sites_style:
List of style specifications for the sites. See the showcase notebooks for examples.
sites_scale:
Scaling factor for the size of all sites.
sites_name:
Name to give to the trace that draws the sites.
sites_colorscale:
Colorscale to use for the sites in case the color attribute is an array of values.
If None, the default colorscale is used for each backend.
drawing_mode:
The method used to draw the sites.
show_cell:
Mode to display the reciprocal cell. If False, the cell is not displayed.
cell_style:
Style specification for the reciprocal cell. See the showcase notebooks for examples.
nsc:
Number of unit cells to display in each direction.
sites_ndim_scale:
Scaling factor for the size of the sites for different dimensionalities (1D, 2D, 3D).
dataaxis_1d:
Only meaningful for 1D plots. The data to plot on the Y axis.
arrows:
List of arrow specifications to display. See the showcase notebooks for examples.
backend:
The backend to use to generate the figure.
"""
# INPUTS ARE NOT GETTING PARSED BECAUSE WORKFLOWS RUN GET ON FINAL NODE
# SO PARSING IS DELEGATED TO NODES.
axes = sanitize_axes(axes)
fake_geometry = sites_obj_to_geometry(sites_obj)
sanitized_sites = sanitize_atoms(fake_geometry, atoms=sites)
ndim = get_ndim(axes)
z = get_z(ndim)
# Process sites
atoms_style = _sites_specs_to_atoms_specs(sites_style)
parsed_sites_style = parse_atoms_style(fake_geometry, atoms_style=atoms_style)
sites_dataset = add_xyz_to_dataset(parsed_sites_style)
filtered_sites = select(sites_dataset, "atom", sanitized_sites)
tiled_sites = tile_data_sc(filtered_sites, nsc=nsc)
sc_sites = stack_sc_data(tiled_sites, newname="sc_atom", dims=["atom"])
sites_units = get_sites_units(sites_obj)
projected_sites = project_to_axes(
sc_sites,
axes=axes,
sort_by_depth=True,
dataaxis_1d=dataaxis_1d,
cartesian_units=sites_units,
)
sites_scale = _sanitize_scale(sites_scale, ndim, sites_ndim_scale)
final_sites = scale_variable(projected_sites, "size", scale=sites_scale)
sites_mode = _get_atom_mode(drawing_mode, ndim)
site_plottings = draw_xarray_xy(
data=final_sites,
x="x",
y="y",
z=z,
width="size",
what=sites_mode,
colorscale=sites_colorscale,
set_axequal=True,
name=sites_name,
)
# And now the cell
show_cell = matches(ndim, 1, False, show_cell)
cell_plottings = cell_plot_actions(
cell=fake_geometry,
show_cell=show_cell,
cell_style=cell_style,
axes=axes,
dataaxis_1d=dataaxis_1d,
)
# And the arrows
atom_arrows = _sites_specs_to_atoms_specs(arrows)
arrow_data = sanitize_arrows(
fake_geometry, atom_arrows, atoms=sanitized_sites, ndim=ndim, axes=axes
)
arrow_plottings = _get_arrow_plottings(projected_sites, arrow_data, nsc=nsc)
all_actions = plot_actions.combined(
site_plottings, cell_plottings, arrow_plottings, composite_method=None
)
return get_figure(backend=backend, plot_actions=all_actions)
[docs]class SitesPlot(Plot):
function = staticmethod(sites_plot)