Skip to content

Commit

Permalink
Remove text background and fix z-order in structure_viz (#139)
Browse files Browse the repository at this point in the history
* make background transparent

* update MP API entrance

* update structure by mp_id

* update matbench structure

* add mp_api to data-src dep

* update matbench fig

* tweak comments

* add zorder

* fix zorder for cell and make occlude default

* update figures

* fix occlusion order

* update figures

* drop bbox completely

* breaking: remove `site_labels_bbox`

* remove `site_labels_bbox` from unit test

* move ExperimentalWarning to pymatviz/utils.py

* restore oxi states in matbench-phonons-structures-2d.svg

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
DanielYang59 and janosh committed May 10, 2024
1 parent 6ecf68a commit 39fd71e
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 75 deletions.
2 changes: 1 addition & 1 deletion assets/matbench-phonons-structures-2d.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 7 additions & 1 deletion assets/struct-2d-mp-12712-Hf9Zr9Pd24-disordered.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 9 additions & 3 deletions examples/_generate_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from matminer.datasets import load_dataset
from monty.io import zopen
from monty.json import MontyDecoder
from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element
from pymatgen.ext.matproj import MPRester
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine as PhononBands
from pymatgen.phonon.dos import PhononDos
from tqdm import tqdm
Expand Down Expand Up @@ -61,6 +62,8 @@
px.defaults.template = "pymatviz_white"
pio.templates.default = "pymatviz_white"

struct: Structure # for type hinting

# Random classification data
np.random.seed(42)
rand_clf_size = 100
Expand Down Expand Up @@ -338,8 +341,11 @@
title = f"{len(axs.flat)} Matbench phonon structures"
fig.suptitle(title, fontweight="bold", fontsize=20)

for row, ax in zip(df_phonons.itertuples(), axs.flat):
idx, struct, *_, spg_num = row
for idx, (row, ax) in enumerate(zip(df_phonons.itertuples(), axs.flat), start=1):
struct = row.structure
spg_num = struct.get_space_group_info()[1]
struct.add_oxidation_state_by_guess()

plot_structure_2d(
struct,
ax=ax,
Expand Down
132 changes: 71 additions & 61 deletions pymatviz/structure_viz.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,34 @@
"""2D plots of pymatgen structures with matplotlib."""
"""2D plots of pymatgen structures with matplotlib.
plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were
inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib.
"""

from __future__ import annotations

import math
import warnings
from itertools import product
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import PathPatch, Wedge
from matplotlib.path import Path
from pymatgen.analysis.local_env import CrystalNN, NearNeighbors
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from pymatviz.utils import covalent_radii, jmol_colors
from pymatviz.utils import ExperimentalWarning, covalent_radii, jmol_colors


if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any, Literal

from numpy.typing import ArrayLike
from pymatgen.core import Structure


class ExperimentalWarning(Warning):
"""Used for experimental show_bonds feature."""


warnings.simplefilter("once", ExperimentalWarning)


# plot_structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were
# inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib


def _angles_to_rotation_matrix(
angles: str, rotation: ArrayLike | None = None
) -> ArrayLike:
Expand All @@ -52,8 +47,10 @@ def _angles_to_rotation_matrix(
"""
if rotation is None:
rotation = np.eye(3)

# Return initial rotation matrix if no angles
if not angles:
return rotation.copy() # return initial rotation matrix if no angles
return rotation.copy()

for angle in angles.split(","):
radians = math.radians(float(angle[:-1]))
Expand Down Expand Up @@ -82,28 +79,27 @@ def unit_cell_to_lines(cell: ArrayLike) -> tuple[ArrayLike, ArrayLike, ArrayLike
- z-indices that sort plot elements into out-of-plane layers
- lines used to plot the unit cell
"""
n_lines = 0
n_lines = n1 = 0
segments = []
for c in range(3):
norm = math.sqrt(sum(cell[c] ** 2))
for idx in range(3):
norm = math.sqrt(sum(cell[idx] ** 2))
segment = max(2, int(norm / 0.3))
segments.append(segment)
n_lines += 4 * segment

lines = np.empty((n_lines, 3))
z_indices = np.empty(n_lines, int)
z_indices = np.empty(n_lines, dtype=int)
unit_cell_lines = np.zeros((3, 3))

n1 = 0
for c in range(3):
segment = segments[c]
dd = cell[c] / (4 * segment - 2)
unit_cell_lines[c] = dd
for idx in range(3):
segment = segments[idx]
dd = cell[idx] / (4 * segment - 2)
unit_cell_lines[idx] = dd
P = np.arange(1, 4 * segment + 1, 4)[:, None] * dd
z_indices[n1:] = c
z_indices[n1:] = idx
for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]:
n2 = n1 + segment
lines[n1:n2] = P + i * cell[c - 2] + j * cell[c - 1]
lines[n1:n2] = P + i * cell[idx - 2] + j * cell[idx - 1]
n1 = n2

return lines, z_indices, unit_cell_lines
Expand All @@ -122,13 +118,12 @@ def plot_structure_2d(
| Literal["symbol", "species"]
| dict[str, str | float]
| Sequence[str | float] = True,
site_labels_bbox: dict[str, Any] | None = None,
label_kwargs: dict[str, Any] | None = None,
bond_kwargs: dict[str, Any] | None = None,
standardize_struct: bool | None = None,
axis: bool | str = "off",
) -> plt.Axes:
"""Plot pymatgen structures in 2d with matplotlib.
"""Plot pymatgen structures in 2D with matplotlib.
Inspired by ASE's ase.visualize.plot.plot_atoms()
https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib
Expand All @@ -137,7 +132,7 @@ def plot_structure_2d(
For example, these two snippets should give very similar output:
```py
```python
from pymatgen.ext.matproj import MPRester
mp_19017 = MPRester().get_structure_by_material_id("mp-19017")
Expand Down Expand Up @@ -182,10 +177,10 @@ def plot_structure_2d(
colors, either a named color (str) or rgb(a) values like (0.2, 0.3, 0.6).
Defaults to JMol colors (https://jmol.sourceforge.net/jscolors).
scale (float, optional): Scaling of the plotted atoms and lines. Defaults to 1.
show_unit_cell (bool, optional): Whether to draw unit cell. Defaults to True.
show_bonds (bool | NearNeighbors, optional): Whether to draw bonds. If True, use
show_unit_cell (bool, optional): Whether to plot unit cell. Defaults to True.
show_bonds (bool | NearNeighbors, optional): Whether to plot bonds. If True, use
pymatgen.analysis.local_env.CrystalNN to infer the structure's connectivity.
If False, don't draw bonds. If a subclass of
If False, don't plot bonds. If a subclass of
pymatgen.analysis.local_env.NearNeighbors, use that to determine
connectivity. Options include VoronoiNN, MinimumDistanceNN, OpenBabelNN,
CovalentBondNN, dtc. Defaults to True.
Expand All @@ -197,12 +192,10 @@ def plot_structure_2d(
number of sites in the crystal. If a string, must be "symbol" or
"species". "symbol" hides the oxidation state, "species" shows it
(equivalent to True). Defaults to True.
site_labels_bbox (dict, optional): Keyword arguments for matplotlib.text.Text
bbox like {"facecolor": "white", "alpha": 0.5}. Defaults to None.
label_kwargs (dict, optional): Keyword arguments for matplotlib.text.Text like
{"fontsize": 14}. Defaults to None.
bond_kwargs (dict, optional): Keyword arguments for the matplotlib.path.Path
class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
class used to plot chemical bonds. Allowed are edgecolor, facecolor, color,
linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle.
Defaults to None.
standardize_struct (bool, optional): Whether to standardize the structure using
Expand All @@ -228,7 +221,7 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
f" the number of sites in the crystal ({len(struct)=})"
)

# default behavior in case of no user input is to standardize if any fractional
# Default behavior in case of no user input: standardize if any fractional
# coordinates are negative
has_sites_outside_unit_cell = any(any(site.frac_coords < 0) for site in struct)
if standardize_struct is False and has_sites_outside_unit_cell:
Expand All @@ -240,9 +233,9 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
elif standardize_struct is None:
standardize_struct = has_sites_outside_unit_cell
if standardize_struct:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

struct = SpacegroupAnalyzer(struct).get_conventional_standard_structure()

# Get default colors
if colors is None:
colors = jmol_colors

Expand All @@ -257,15 +250,14 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
else:
# atomic_radii is assumed to be a map from element symbols to atomic radii
# make sure all present elements are assigned a radius
missing = set(elements_at_sites) - set(atomic_radii)
if missing:
if missing := set(elements_at_sites) - set(atomic_radii):
raise ValueError(f"atomic_radii is missing keys: {missing}")

radii_at_sites = np.array(
[atomic_radii[el] for el in elements_at_sites] # type: ignore[index]
)

n_atoms = len(struct)
# Generate lines for unit cell
rotation_matrix = _angles_to_rotation_matrix(rotation)
unit_cell = struct.lattice.matrix

Expand All @@ -280,28 +272,35 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
unit_cell_lines = None
cell_vertices = None

# Zip atoms and unit cell lines together
n_atoms = len(struct)
n_lines = len(lines)

positions = np.empty((n_atoms + n_lines, 3))
site_coords = np.array([site.coords for site in struct])
positions[:n_atoms] = site_coords
positions[n_atoms:] = lines

# determine which lines should be hidden behind other objects
# Determine which unit cell line should be hidden behind other objects
for idx in range(n_lines):
this_layer = unit_cell_lines[z_indices[idx]]

occluded_top = ((site_coords - lines[idx] + this_layer) ** 2).sum(
1
) < radii_at_sites**2

occluded_bottom = ((site_coords - lines[idx] - this_layer) ** 2).sum(
1
) < radii_at_sites**2

if any(occluded_top & occluded_bottom):
z_indices[idx] = -1

# Apply rotation matrix
positions = np.dot(positions, rotation_matrix)
rotated_site_coords = positions[:n_atoms]

# Normalize wedge positions
min_coords = (rotated_site_coords - radii_at_sites[:, None]).min(0)
max_coords = (rotated_site_coords + radii_at_sites[:, None]).max(0)

Expand All @@ -316,19 +315,23 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
positions *= scale
positions -= offset

# Rotate and scale unit cell lines
if n_lines > 0:
unit_cell_lines = np.dot(unit_cell_lines, rotation_matrix)[:, :2] * scale

special_site_labels = ("symbol", "species")
# sort positions by 3rd dim so we draw from back to front in z-axis (out-of-plane)
# Sort positions by 3rd dim to plot from back to front along z-axis (out-of-plane)
for idx in positions[:, 2].argsort():
xy = positions[idx, :2]
start = 0
zorder = positions[idx][2]

if idx < n_atoms:
# loop over all species on a site (usually just 1 for ordered sites)
for specie, occupancy in struct[idx].species.items():
# strip oxidation state from element symbol (e.g. Ta5+ to Ta)
elem_symbol = specie.symbol
# Loop over all species on a site (usually just 1 for ordered sites)
for species, occupancy in struct[idx].species.items():
# Strip oxidation state from element symbol (e.g. Ta5+ to Ta)
elem_symbol = species.symbol

radius = atomic_radii[elem_symbol] * scale # type: ignore[index]
face_color = colors[elem_symbol]
wedge = Wedge(
Expand All @@ -338,23 +341,25 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
360 * (start + occupancy),
facecolor=face_color,
edgecolor="black",
zorder=zorder,
)
ax.add_patch(wedge)

# Generate labels
if site_labels == "symbol":
txt = elem_symbol
elif site_labels in ("species", True):
txt = specie
txt = species
elif site_labels is False:
txt = ""
elif isinstance(site_labels, dict):
# try element incl. oxidation state as dict key first (e.g. Na+),
# Try element incl. oxidation state as dict key first (e.g. Na+),
# then just element as fallback
txt = site_labels.get(
repr(specie), site_labels.get(elem_symbol, "")
repr(species), site_labels.get(elem_symbol, "")
)
if txt in special_site_labels:
txt = specie if txt == "species" else elem_symbol
txt = species if txt == "species" else elem_symbol
elif isinstance(site_labels, (list, tuple)):
txt = site_labels[idx] # idx runs from 0 to n_atoms
else:
Expand All @@ -363,29 +368,33 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
f"{', '.join(special_site_labels)}, dict, list)"
)

# Add labels
if site_labels:
# place element symbol half way along outer wedge edge for
# Place element symbol half way along outer wedge edge for
# disordered sites
half_way = 2 * np.pi * (start + occupancy / 2)
direction = np.array([math.cos(half_way), math.sin(half_way)])
text_offset = (
(0.5 * radius) * direction if occupancy < 1 else (0, 0)
)

bbox = dict(facecolor="none", edgecolor="none", pad=1)
bbox.update(site_labels_bbox or {})
txt_kwds = dict(
ha="center", va="center", bbox=bbox, **(label_kwargs or {})
ha="center",
va="center",
zorder=zorder,
**(label_kwargs or {}),
)
ax.text(*(xy + text_offset), txt, **txt_kwds)

start += occupancy
else: # draw unit cell

# Plot unit cell
else:
cell_idx = idx - n_atoms
# only draw line if not obstructed by an atom
# Only plot lines not obstructed by an atom
if z_indices[cell_idx] != -1:
hxy = unit_cell_lines[z_indices[cell_idx]]
path = PathPatch(Path((xy + hxy, xy - hxy)))
path = PathPatch(Path((xy + hxy, xy - hxy)), zorder=zorder)
ax.add_patch(path)

if show_bonds:
Expand All @@ -404,16 +413,17 @@ class used to draw chemical bonds. Allowed are edgecolor, facecolor, color,
)

# If structure doesn't have any oxidation states yet, guess them from chemical
# composition. Helps CrystalNN and other strategies to estimate better bond
# connectivity. Uses getattr on site.specie since it's often a pymatgen Element
# composition. Use CrystalNN and other strategies to better estimate bond
# connectivity. Use getattr on site.specie since it's often a pymatgen Element
# which has no oxi_state
if not any(
hasattr(getattr(site, "specie", None), "oxi_state") for site in struct
):
try:
struct.add_oxidation_state_by_guess()
except ValueError: # fails for disordered structures
"Charge balance analysis requires integer values in Composition"
# Charge balance analysis requires integer values in Composition
pass

structure_graph = neighbor_strategy_cls().get_bonded_structure(struct)

Expand Down
Loading

0 comments on commit 39fd71e

Please sign in to comment.