API
module bar
Bar plots.
Global Variables
- TYPE_CHECKING
- PLOTLY
function spacegroup_bar
spacegroup_bar(
data: 'Sequence[int | str | Structure] | Series',
show_counts: 'bool' = True,
xticks: "Literal['all', 'crys_sys_edges'] | int" = 20,
show_empty_bins: 'bool' = False,
ax: 'Axes | None' = None,
backend: 'Backend' = 'plotly',
text_kwargs: 'dict[str, Any] | None' = None,
log: 'bool' = False,
**kwargs: 'Any'
) → Axes | Figure
Plot a histogram of spacegroups shaded by crystal system.
Args:
- data (list[int | str | Structure] | pd.Series): Space group strings or numbers (from 1 to 230) or pymatgen structures.
- show_counts (bool, optional): Whether to count the number of items in each crystal system. Defaults to True.
- xticks ("all" | "crys_sys_edges" | int, optional): Where to add x-ticks. An integer adds ticks below that number of tallest bars. "all" shows ticks below all bars, "crys_sys_edges" only at the edge from one crystal system to another. Defaults to 20.
- show_empty_bins (bool, optional): Whether to include a 0-height bar for space groups missing from the data. Currently only implemented for numbers, not symbols. Defaults to False.
- ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
- backend ("matplotlib" | "plotly", optional): Which backend to use for plotting. Defaults to "plotly".
- text_kwargs (dict, optional): Keyword arguments passed to matplotlib.Axes.text(). Defaults to None. Has no effect if backend is "plotly".
- log (bool, optional): Whether to log scale the y-axis. Defaults to False.
- kwargs: Keywords passed to pd.Series.plot.bar() or plotly.express.bar().
Returns:
- plt.Axes | go.Figure: matplotlib Axes or plotly Figure depending on backend.
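To illustrate the shading by crystal system, the following sketch maps space group numbers to crystal systems using the standard International Tables ranges and tallies them. The names CRYSTAL_SYSTEM_UPPER_BOUNDS, crystal_system and crystal_system_counts are hypothetical helpers for this example, not part of pymatviz, and the library may organize this differently.

```python
from collections import Counter

# Inclusive upper bound of the space group numbers in each crystal system
CRYSTAL_SYSTEM_UPPER_BOUNDS = [
    (2, "triclinic"),
    (15, "monoclinic"),
    (74, "orthorhombic"),
    (142, "tetragonal"),
    (167, "trigonal"),
    (194, "hexagonal"),
    (230, "cubic"),
]


def crystal_system(spg_num: int) -> str:
    """Return the crystal system for an international space group number."""
    if not 1 <= spg_num <= 230:
        raise ValueError(f"Invalid space group number {spg_num}, must be 1-230")
    for upper_bound, name in CRYSTAL_SYSTEM_UPPER_BOUNDS:
        if spg_num <= upper_bound:
            return name
    raise AssertionError("unreachable")


def crystal_system_counts(spg_nums):
    """Tally how many space group numbers fall into each crystal system."""
    return Counter(crystal_system(num) for num in spg_nums)


counts = crystal_system_counts([1, 14, 62, 62, 139, 166, 194, 225, 227])
```

Counts like these are what show_counts=True would plausibly summarize per crystal system.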
module colors
Colors used in pymatviz.
Global Variables
- TYPE_CHECKING
- ELEM_TYPE_COLORS
- ELEM_COLORS_JMOL
- ELEM_COLORS_VESTA
module coordination.helpers
Helper functions for calculating coordination numbers.
Global Variables
- TYPE_CHECKING
function create_hover_text
create_hover_text(
struct_key: 'str',
elem_symbol: 'str',
cn: 'int',
count: 'int',
hover_data: 'dict[str, str]',
data: 'dict[str, Any]',
is_single_structure: 'bool'
) → str
Create hover text for a single bar in the histogram.
function normalize_get_neighbors
normalize_get_neighbors(
strategy: 'float | NearNeighbors | type[NearNeighbors]'
) → Callable[[PeriodicSite, Structure], list[dict[str, Any]]]
Normalize get_neighbors function.
function calculate_average_cn
calculate_average_cn(
structure: 'Structure',
element: 'str',
get_neighbors: 'Callable[[PeriodicSite, Structure], list[dict[str, Any]]]'
) → float
Calculate the average coordination number for a given element in a structure.
function coordination_nums_in_structure
coordination_nums_in_structure(
structure: 'Structure',
strategy: 'float | NearNeighbors | type[NearNeighbors]' = 3.0,
group_by: "Literal['element', 'specie', 'site']" = 'element'
) → dict[str, list[int]]
Get coordination numbers (CN) for each element in a structure.
Args:
- structure: A pymatgen Structure object.
- strategy: Neighbor-finding strategy. Defaults to 3.0 (Angstrom cutoff). Can be one of:
  - float: Cutoff distance for neighbor search in Angstroms
  - NearNeighbors: An instance of a NearNeighbors subclass
  - Type[NearNeighbors]: A NearNeighbors subclass (will be instantiated)
- group_by: How to group the coordination numbers. Can be one of:
  - "element": Group by element symbol
  - "site": Group by site
  - "specie": Group by specie
Returns:
- dict[str, list[int]]: Map of element symbols to lists of coordination numbers, e.g. {"Si": [4, 4, 4], "O": [2, 2, 2, 2, 2, 2]} for SiO2. Each number represents the CN of one atom of that element.
Example:
>>> from pymatgen.core import Structure
>>> structure = Structure.from_file("SiO2.cif")
>>> cns = coordination_nums_in_structure(structure)
>>> print(cns)
{"Si": [4, 4, 4], "O": [2, 2, 2, 2, 2, 2]}
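To make the float strategy concrete, here is a minimal sketch of cutoff-based neighbor counting on plain Cartesian coordinates. It assumes an isolated cluster without periodic images, so it simplifies what a periodic Structure requires. The function name and the (element, xyz) input format are made up for this example.

```python
import math
from collections import defaultdict


def coordination_nums_by_cutoff(sites, cutoff=3.0):
    """Count neighbors within `cutoff` (Angstrom) per site, grouped by element.

    sites: list of (element_symbol, (x, y, z)) tuples in Cartesian Angstrom.
    Assumption: no periodic boundary conditions are applied.
    """
    cns = defaultdict(list)
    for idx, (elem, xyz) in enumerate(sites):
        n_neighbors = sum(
            1
            for jdx, (_, other_xyz) in enumerate(sites)
            if jdx != idx and math.dist(xyz, other_xyz) <= cutoff
        )
        cns[elem].append(n_neighbors)
    return dict(cns)


# Linear CO2-like molecule with 1.16 Angstrom C-O bonds
co2 = [("O", (-1.16, 0, 0)), ("C", (0, 0, 0)), ("O", (1.16, 0, 0))]
cns = coordination_nums_by_cutoff(co2, cutoff=1.5)
```

With a 1.5 Angstrom cutoff, the two oxygens do not count each other as neighbors. Raising the cutoff to 3.0 does include them, which shows how sensitive the result is to the chosen cutoff.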
class CnSplitMode
How to split the coordination number histogram into subplots.
module coordination
Coordination number analysis and visualization.
module coordination.plotly
Plotly plots of coordination numbers distributions.
Global Variables
- TYPE_CHECKING
- ELEM_COLORS_JMOL
- ELEM_COLORS_VESTA
function coordination_hist
coordination_hist(
structures: 'Structure | dict[str, Structure] | Sequence[Structure]',
strategy: 'float | NearNeighbors | type[NearNeighbors]' = 3.0,
split_mode: 'CnSplitMode | str' = CnSplitMode.by_element,
bar_mode: "Literal['group', 'stack']" = 'stack',
hover_data: 'Sequence[str] | dict[str, str] | None' = None,
element_color_scheme: 'ElemColorScheme | dict[str, str]' = Jmol,
annotate_bars: 'bool | dict[str, Any]' = False,
bar_kwargs: 'dict[str, Any] | None' = None
) → Figure
Create a plotly histogram of coordination numbers for given structure(s).
Args:
- structures: A single structure or a dictionary or sequence of structures.
- strategy: Neighbor-finding strategy. Defaults to 3.0 (Angstrom cutoff). Can be one of:
  - float: Cutoff distance for neighbor search in Angstroms.
  - NearNeighbors: An instance of a NearNeighbors subclass.
  - Type[NearNeighbors]: A NearNeighbors subclass (will be instantiated).
- split_mode: How to split the data into subplots or color groups.
  - "none": Single plot with all data. All elements of all structures (if multiple were passed) will be shown in the same plot.
  - "by element": Split into subplots by element. Matching colors across subplots for different elements indicate those elements belong to the same structure.
  - "by structure": Split into subplots by structure, i.e. each structure gets its own subplot with coordination numbers for all sites plotted in the same color.
  - "by structure and element": Like "by structure", each structure gets its own subplot, but elements are colored differently within each structure.
- bar_mode: How to arrange bars at the same coordination number.
  - "group": Bars are grouped side by side.
  - "stack": Bars are stacked on top of each other.
- hover_data: Sequence of keys or dict mapping keys to pretty labels for additional data to be shown in the hover tooltip. The keys must exist in the site properties or properties dict of the structure.
- element_color_scheme: Color scheme for elements. Can be "jmol", "vesta", or a custom dict.
- annotate_bars: If True, annotate bars with element symbols when split_mode is 'by_element' or 'by_structure_and_element'. If a dict, used as keywords for bar annotations, e.g. {"font_size": 12, "font_color": "red"}.
- bar_kwargs: Dictionary of keyword arguments to customize bar appearance. These will be passed to go.Bar().
Returns: A plotly Figure object containing the histogram.
function coordination_vs_cutoff_line
coordination_vs_cutoff_line(
structures: 'Structure | dict[str, Structure] | Sequence[Structure]',
strategy: 'tuple[float, float] | NearNeighbors | type[NearNeighbors]' = (1, 5),
num_points: 'int' = 50,
element_color_scheme: 'ElemColorScheme | dict[str, str]' = Jmol,
subplot_kwargs: 'dict[str, Any] | None' = None
) → Figure
Create a plotly line plot of cumulative coordination numbers vs cutoff distance.
Args:
- structures: A single structure or a dictionary or sequence of structures.
- strategy: Neighbor-finding strategy. Defaults to (1, 5) Angstrom range. Can be one of:
  - float: Single cutoff distance for neighbor search in Angstroms.
  - tuple[float, float]: (min_cutoff, max_cutoff) range in Angstroms.
  - NearNeighbors: An instance of a NearNeighbors subclass.
  - Type[NearNeighbors]: A NearNeighbors subclass (will be instantiated).
- num_points: Number of points to calculate between min and max cutoff.
- element_color_scheme: Color scheme for elements. Can be "jmol", "vesta", or a custom dict.
- subplot_kwargs: Additional keyword arguments to pass to make_subplots().
Returns: A plotly Figure object containing the line plot.
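The (min_cutoff, max_cutoff) strategy together with num_points can be pictured as scanning evenly spaced cutoffs and counting how many neighbors fall within each. The sketch below does this for a single site given a precomputed list of neighbor distances. Both the function name cn_vs_cutoff and the assumption of evenly spaced cutoffs are illustrative, not taken from the pymatviz source.

```python
def cn_vs_cutoff(neighbor_dists, cutoff_range=(1.0, 5.0), num_points=50):
    """Return (cutoffs, counts) with counts[i] = neighbors within cutoffs[i].

    neighbor_dists: distances (Angstrom) from one site to its neighbors.
    Assumption: cutoffs are evenly spaced and include both end points.
    """
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    min_cutoff, max_cutoff = cutoff_range
    step = (max_cutoff - min_cutoff) / (num_points - 1)
    cutoffs = [min_cutoff + idx * step for idx in range(num_points)]
    counts = [sum(dist <= cutoff for dist in neighbor_dists) for cutoff in cutoffs]
    return cutoffs, counts


cutoffs, counts = cn_vs_cutoff([1.5, 1.5, 2.5, 4.2], (1.0, 5.0), num_points=5)
```

The counts never decrease as the cutoff grows, which is why the resulting curve is cumulative.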
module cumulative
Plot the cumulative distribution of residuals and absolute errors.
Global Variables
- TYPE_CHECKING
function cumulative_residual
cumulative_residual(
res: 'ArrayLike',
ax: 'Axes | None' = None,
**kwargs: 'Any'
) → Axes
Plot the empirical cumulative distribution for the residuals (y - mu).
Args:
- res (array): Residuals between y_true and y_pred, i.e. targets - model predictions.
- ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
- **kwargs: Additional keyword arguments passed to ax.fill_between().
Returns:
- plt.Axes: matplotlib Axes object
function cumulative_error
cumulative_error(
abs_err: 'ArrayLike',
ax: 'Axes | None' = None,
**kwargs: 'Any'
) → Axes
Plot the empirical cumulative distribution of abs(y_true - y_pred).
Args:
- abs_err (array): Absolute error between y_true and y_pred, i.e. abs(targets - model predictions).
- ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
- **kwargs: Additional keyword arguments passed to ax.plot().
Returns:
- plt.Axes: matplotlib Axes object
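Both functions in this module plot an empirical cumulative distribution. As a rough sketch of the quantity on the axes (not the plotting itself), the snippet below sorts absolute errors and assigns each the fraction of samples at or below it. The helper name empirical_cdf and the sample values are invented for illustration.

```python
def empirical_cdf(values):
    """Return sorted values and the fraction of samples <= each value."""
    xs = sorted(values)
    n_samples = len(xs)
    ys = [(idx + 1) / n_samples for idx in range(n_samples)]
    return xs, ys


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.0, 6.0]
abs_err = [abs(target - pred) for target, pred in zip(y_true, y_pred)]
xs, ys = empirical_cdf(abs_err)
```

Here xs would go on the x-axis and ys on the y-axis, so reading the curve at an error threshold gives the share of predictions within that error.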
module data
Generate dummy data for testing and prototyping.
function regression
regression(
n_samples: int = 500,
true_mean: float = 5,
true_std: float = 4,
pred_slope: float = 1.2,
pred_intercept: float = -2,
seed: int = 0
) → RegressionData
Generate dummy regression data for testing and prototyping.
This function creates synthetic data to simulate a regression task:
- y_true: Sampled from a normal distribution with mean 5 and standard deviation 4.
- y_pred: Linearly related to y_true with a slope of 1.2 and additional Gaussian noise.
- y_std: Residuals scaled by random noise, representing variability in predictions.
Parameters:
- n_samples (int): Number of samples to generate. Default is 500.
Returns:
- RegressionData: A named tuple containing y_true, y_pred, and y_std.
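The recipe described above can be sketched with the standard library alone. This is an approximation under stated assumptions: the unit-variance noise terms and the names DummyRegressionData and dummy_regression are choices made for this example, and the actual pymatviz implementation may differ.

```python
import random
from typing import NamedTuple


class DummyRegressionData(NamedTuple):
    y_true: list
    y_pred: list
    y_std: list


def dummy_regression(
    n_samples=500, true_mean=5, true_std=4, pred_slope=1.2, pred_intercept=-2, seed=0
):
    """Generate synthetic targets, predictions and uncertainties."""
    rng = random.Random(seed)  # seeded for reproducibility
    y_true = [rng.gauss(true_mean, true_std) for _ in range(n_samples)]
    # Predictions are a linear function of the targets plus Gaussian noise
    # (noise standard deviation of 1 is an assumption)
    y_pred = [pred_slope * y + pred_intercept + rng.gauss(0, 1) for y in y_true]
    # Uncertainties are residuals scaled by random noise
    y_std = [(t - p) * rng.gauss(0, 1) for t, p in zip(y_true, y_pred)]
    return DummyRegressionData(y_true, y_pred, y_std)


data = dummy_regression(n_samples=10, seed=1)
```

Fixing the seed makes repeated calls return identical data, which is convenient in tests.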
class RegressionData
Regression data containing: y_true, y_pred and y_std.
module enums
Enums used as keys/accessors for dicts and dataframes across Matbench Discovery.
Global Variables
- TYPE_CHECKING
- eV_per_atom
- eV
- eV_per_angstrom
- eV_per_kelvin
- angstrom
- angstrom_per_atom
- cubic_angstrom
- degrees
- gram_per_cm3
- kelvin
- pascal
- giga_pascal
- joule
- joule_per_mol
- joule_per_m2
class LabelEnum
StrEnum with optional label and description attributes plus dict() methods.
Simply add label and description as a tuple starting with the key's value.
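The tuple convention can be sketched with the standard enum module. This is one way such an enum could be built, not the pymatviz implementation. LabeledStrEnum and Unit are hypothetical names, and the dict() helpers mentioned above are omitted.

```python
from enum import Enum


class LabeledStrEnum(str, Enum):
    """String enum whose members may carry a label and a description."""

    def __new__(cls, value, label=None, description=None):
        member = str.__new__(cls, value)
        member._value_ = value  # only the first tuple item is the enum value
        member.label = label
        member.description = description
        return member


class Unit(LabeledStrEnum):
    ev = "eV", "electron volt", "Energy in electron volts"
    kelvin = "K"  # label and description are optional


label = Unit.ev.label
```

Because members are also str instances, Unit.ev compares equal to "eV" and can be used directly as a dict key or dataframe column name.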
class Key
Keys used to access dataframes columns, organized by semantic groups.
class Task
What kind of task is being performed.
class Model
Model names.
class ElemCountMode
Mode of counting elements in a chemical formula.
class ElemColorMode
Mode of coloring elements in structure visualizations or periodic table plots.
class ElemColorScheme
Names of element color palettes.
Used e.g. in structure visualizations and periodic table plots.
class SiteCoords
Site coordinate representations.
module histogram
Histograms.
Global Variables
- TYPE_CHECKING
- BACKENDS
- MATPLOTLIB
- PLOTLY
function spacegroup_hist
spacegroup_hist(*args: 'Any', **kwargs: 'Any') → Axes | Figure
Alias for spacegroup_bar.
function elements_hist
elements_hist(
formulas: 'ElemValues',
count_mode: 'ElemCountMode' = Composition,
log: 'bool' = False,
keep_top: 'int | None' = None,
ax: 'Axes | None' = None,
bar_values: "Literal['percent', 'count'] | None" = 'percent',
h_offset: 'int' = 0,
v_offset: 'int' = 10,
rotation: 'int' = 45,
fontsize: 'int' = 12,
**kwargs: 'Any'
) → Axes
Plot a histogram of elements (e.g. to show occurrence in a dataset).
Adapted from https://github.com/kaaiian/ML_figures (https://git.io/JmbaI).
Args:
- formulas (list[str]): compositional strings, e.g. ["Fe2O3", "Bi2Te3"].
- count_mode ("composition" | "fractional_composition" | "reduced_composition"): Reduce or normalize compositions before counting. See count_elements for details. Only used when formulas is list of composition strings/objects.
- log (bool, optional): Whether y-axis is log or linear. Defaults to False.
- keep_top (int | None): Display only the top n elements by prevalence.
- ax (Axes): matplotlib Axes on which to plot. Defaults to None.
- bar_values ("percent" | "count" | None): "percent" (default) annotates bars with the percentage each element makes up in the total element count. "count" displays count itself. None removes bar labels.
- h_offset (int): Horizontal offset for bar height labels. Defaults to 0.
- v_offset (int): Vertical offset for bar height labels. Defaults to 10.
- rotation (int): Bar label angle. Defaults to 45.
- fontsize (int): Font size for bar labels. Defaults to 12.
- **kwargs: Keyword arguments passed to pandas.Series.plot.bar().
Returns:
- plt.Axes: matplotlib Axes object
function histogram
histogram(
values: 'Sequence[float] | dict[str, Sequence[float]]',
bins: 'int | Sequence[float] | str' = 200,
x_range: 'tuple[float | None, float | None] | None' = None,
density: 'bool' = False,
bin_width: 'float' = 1.2,
log_y: 'bool' = False,
backend: 'Backend' = 'plotly',
fig_kwargs: 'dict[str, Any] | None' = None,
**kwargs: 'Any'
) → Figure | Figure
Get a histogram with plotly (default) or matplotlib backend but using fast numpy pre-processing before handing the data off to the plot function.
Very common use case when dealing with large datasets so worth having a dedicated function for it. Compared to the native matplotlib/plotly histograms, this is much faster and yields much smaller files (when saving plotly figs as HTML, plotly saves a complete copy of the data to disk from which it recomputes the histogram on the fly to render the figure). Speedup example:
gaussian = np.random.normal(0, 1, 1_000_000_000)
histogram(gaussian)  # takes 17s
px.histogram(gaussian)  # ran for 3m45s before crashing the Jupyter kernel
Args:
- values (Sequence[float] or dict[str, Sequence[float]]): The values to plot as a histogram. If a dict is provided, the keys are used as legend labels.
- bins (int or sequence, optional): The number of bins or the bin edges to use for the histogram. Defaults to 200.
- x_range (tuple, optional): The range of values to include in the histogram. If not provided, the whole range of values will be used. Defaults to None.
- density (bool, optional): Whether to normalize the histogram. Defaults to False.
- bin_width (float, optional): The width of the histogram bins as a fraction of distance between bin edges. Defaults to 1.2 (20% overlap).
- log_y (bool, optional): Whether to log scale the y-axis. Defaults to False.
- backend (str, optional): The plotting backend to use. Can be either 'matplotlib' or 'plotly'. Defaults to 'plotly'.
- fig_kwargs (dict, optional): Additional keyword arguments to pass to the figure creation function (plt.figure for Matplotlib or go.Figure for Plotly).
- **kwargs: Additional keyword arguments to pass to the plotting function (plt.bar for Matplotlib or go.Figure.add_bar for Plotly).
Returns:
- plt.Figure | go.Figure: The figure object containing the histogram.
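The idea behind pre-binning is that only bin edges and counts reach the plotting library, not the raw samples. The pure-Python sketch below shows that reduction for equal-width bins. In practice a vectorized routine such as numpy.histogram would play this role. The name prebin is made up here and is not the pymatviz implementation.

```python
def prebin(values, bins=10, x_range=None):
    """Reduce raw values to (bin_edges, counts) using equal-width bins."""
    lower, upper = x_range if x_range else (min(values), max(values))
    if upper <= lower:
        raise ValueError("x_range must span a positive interval")
    width = (upper - lower) / bins
    counts = [0] * bins
    for val in values:
        if not lower <= val <= upper:
            continue  # ignore values outside the requested range
        # Values equal to the upper edge go into the last bin
        bin_idx = min(int((val - lower) / width), bins - 1)
        counts[bin_idx] += 1
    edges = [lower + idx * width for idx in range(bins + 1)]
    return edges, counts


edges, counts = prebin([0, 0.1, 0.5, 0.9, 1.0], bins=2)
```

A bar chart then needs only len(counts) bars regardless of how many samples went in, which is where the smaller file sizes come from.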
module io
I/O utilities for saving figures and dataframes to various image formats.
Global Variables
- DEFAULT_BUFFER_SIZE
- SEEK_SET
- SEEK_CUR
- SEEK_END
- TYPE_CHECKING
- ROOT
- DEFAULT_DF_STYLES
- ALLOW_TABLE_SCROLL
- HIDE_SCROLL_BAR
function save_fig
save_fig(
fig: 'Figure | Figure | Axes',
path: 'str',
plotly_config: 'dict[str, Any] | None' = None,
env_disable: 'Sequence[str]' = ('CI',),
pdf_sleep: 'float' = 0.6,
style: 'str' = '',
prec: 'int | None' = None,
template: 'str | None' = None,
transparent_bg: 'bool' = True,
**kwargs: 'Any'
) → None
Write a plotly or matplotlib figure to disk (as HTML/PDF/SVG/...).
If the file has a .svelte extension, insert {...$$props} into the figure's top-level div so it can later be styled and customized from Svelte code.
Args:
- fig (go.Figure | plt.Figure | plt.Axes): Plotly or matplotlib Figure or matplotlib Axes object.
- path (str): Path to image file that will be created.
- plotly_config (dict, optional): Configuration options for fig.write_html(). Defaults to dict(showTips=False, responsive=True, modeBarButtonsToRemove=["lasso2d", "select2d", "autoScale2d", "toImage"]). See https://plotly.com/python/configuration-options.
- env_disable (list[str], optional): Do nothing if any of these environment variables are set. Defaults to ("CI",).
- pdf_sleep (float, optional): Minimum time in seconds to wait before writing a plotly figure to PDF file. Workaround for this plotly issue: https://github.com/plotly/plotly.py/issues/3469. Defaults to 0.6. Has no effect on matplotlib figures.
- style (str, optional): CSS style string to be inserted into the HTML file. Defaults to "". Only used if path ends with .svelte or .html.
- prec (int, optional): Number of significant digits to keep for any float in the exported file. Defaults to None (no rounding). Sensible values are usually 4, 5, 6.
- template (str, optional): Temporary plotly template to apply to the figure before saving. Will be reset to the original after. Defaults to "pymatviz_white" if path ends with .pdf or .pdfa, else None. Set to None to disable. Only used if fig is a plotly figure.
- transparent_bg (bool): Whether to save matplotlib figures with transparent background. Use False to show background colors.
- **kwargs: Keyword arguments passed to fig.write_html().
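Two of the options above are easy to picture in isolation: env_disable skips saving when certain environment variables are set, and prec rounds floats to a number of significant digits. The helpers below are a hedged sketch of both behaviors. The names save_disabled, round_sig and round_floats are hypothetical, and the way pymatviz applies rounding internally is not shown in this document.

```python
import os


def save_disabled(env_disable=("CI",), environ=None):
    """Return True if any variable named in env_disable is set."""
    environ = os.environ if environ is None else environ
    return any(var in environ for var in env_disable)


def round_sig(value, prec):
    """Round a float to `prec` significant digits."""
    return float(f"{value:.{prec}g}")


def round_floats(obj, prec):
    """Recursively round all floats in nested dicts and lists."""
    if isinstance(obj, float):
        return round_sig(obj, prec)
    if isinstance(obj, dict):
        return {key: round_floats(val, prec) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [round_floats(item, prec) for item in obj]
    return obj  # ints, strings, etc. are left untouched


rounded = round_floats({"x": [3.14159265, 1], "name": "trace"}, prec=4)
```

Rounding to significant digits rather than decimal places keeps very small and very large values meaningful while still shrinking the exported file.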
function save_and_compress_svg
save_and_compress_svg(
fig: 'Figure | Figure | Axes',
filename: 'str',
transparent_bg: 'bool' = True
) → None
Save Plotly figure as SVG and HTML to assets/ folder. Compresses SVG file with svgo CLI if available in PATH.
If filename does not include .svg extension and is not absolute, will be treated as relative to assets/ folder. This function is mostly meant for pymatviz internal use.
Args:
- fig (Figure): Plotly or matplotlib Figure/Axes instance.
- filename (str): Name of SVG file (w/o extension).
- transparent_bg (bool): Whether to save matplotlib figures with transparent background. Use False to show background colors.
Raises:
- ValueError: If fig is None and plt.gcf() is empty.
function df_to_pdf
df_to_pdf(
styler: 'Styler',
file_path: 'str | Path',
crop: 'bool' = True,
size: 'str | None' = None,
style: 'str' = '',
styler_css: 'bool | dict[str, str]' = True,
**kwargs: 'Any'
) → None
Export a pandas Styler to PDF with WeasyPrint.
Args:
- styler (Styler): Styler object to export.
- file_path (str): Path to save the PDF to. Requires WeasyPrint.
- crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins. Defaults to True. Be careful to set size correctly (not much too large as is the default) if you set crop=False.
- size (str): Page size. Defaults to "4cm * n_cols x 2cm * n_rows" (width x height). See https://developer.mozilla.org/@page for 'landscape' and other special values.
- style (str): CSS style string to be inserted into the HTML file. Defaults to "".
- styler_css (bool | dict[str, str]): Whether to apply some sensible default CSS to the pandas Styler. Defaults to True. If dict, keys are selectors and values CSS strings. Example: {"td, th": "border: none; padding: 4px;"}
- **kwargs: Keyword arguments passed to Styler.to_html().
function normalize_and_crop_pdf
normalize_and_crop_pdf(
file_path: 'str | Path',
on_gs_not_found: "Literal['ignore', 'warn', 'error']" = 'warn'
) → None
Normalize a PDF using Ghostscript and then crop it. Without gs normalization, pdfCropMargins sometimes corrupts the PDF.
Args:
- file_path (str | Path): Path to the PDF file.
- on_gs_not_found ("ignore" | "warn" | "error", optional): What to do if Ghostscript is not found in PATH. Defaults to "warn".
function df_to_html_table
df_to_html_table(*args: 'Any', **kwargs: 'Any') → str
function df_to_html
df_to_html(
styler: 'Styler',
file_path: 'str | Path | None' = None,
inline_props: 'str | None' = '{{...$$props}} ',
pre_table: 'str | None' = '',
styles: 'str | None' = 'table { overflow: scroll; max-width: 100%; display: block; }\ntable {\n scrollbar-width: none; /* Firefox */\n}\ntable::-webkit-scrollbar {\n display: none; /* Safari and Chrome */\n}',
styler_css: 'bool | dict[str, str]' = True,
use_sortable: 'bool' = True,
use_tooltips: 'bool' = True,
post_process: 'Callable[[str], str] | None' = None,
**kwargs: 'Any'
) → str
Convert a pandas Styler to a svelte table.
Args:
- styler (Styler): Styler object to export.
- file_path (str): Path to the file to write the svelte table to.
- inline_props (str): Inline props to pass to the table element. Example: "class='table' style='width: 100%'". Defaults to "{{...$$props}} ".
- pre_table (str): HTML string to insert above the table. Defaults to "". Will replace the opening table tag to allow passing props to it.
- styles (str): CSS rules to insert at the bottom of the style tag. Defaults to TABLE_SCROLL_CSS.
- styler_css (bool | dict[str, str]): Whether to apply some sensible default CSS to the pandas Styler. Defaults to True. If dict, keys are CSS selectors and values CSS strings. Example: {"td, th": "border: none; padding: 4px 6px;"}
- use_sortable (bool): Whether to enable sorting the table by clicking on column headers. Defaults to True. Requires npm install svelte-zoo.
- use_tooltips (bool): Whether to enable tooltips on table headers. Defaults to True. Requires npm install svelte-zoo.
- post_process (Callable[[str], str]): Function to post-process the HTML string before writing it to file. Defaults to None.
- **kwargs: Keyword arguments passed to Styler.to_html().
Returns:
- str: pandas Styler as HTML.
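When styler_css is a dict, its selector keys and rule-string values map naturally onto a CSS stylesheet. The helper below shows one plausible way to serialize such a dict. css_rules_to_str is a hypothetical name, and how pymatviz actually applies the rules to the Styler is not specified here.

```python
def css_rules_to_str(styler_css):
    """Serialize {selector: rules} pairs into CSS text, one rule set per line."""
    return "\n".join(
        f"{selector} {{ {rules} }}" for selector, rules in styler_css.items()
    )


css = css_rules_to_str(
    {
        "td, th": "border: none; padding: 4px 6px;",
        "th": "font-weight: bold;",
    }
)
```

Keeping selectors as keys means later entries can override earlier ones simply by appearing further down in the generated stylesheet.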
function df_to_svg
df_to_svg(
obj: 'DataFrame | Styler',
file_path: 'str | Path',
font_size: 'int' = 14,
compress: 'bool' = True,
**kwargs: 'Any'
) → Figure
Export a pandas DataFrame or Styler to an SVG file and optionally compress it.
TODO The SVG output has annoying margins that proved hard to remove. The goal is for this function to auto-crop the SVG viewBox to the content in the future.
Args:
- obj (DataFrame | Styler): DataFrame or Styler object to save as SVG.
- file_path (str | Path): Where to save the SVG file.
- font_size (int): Font size in points. Defaults to 14.
- compress (bool): Whether to compress the SVG file using svgo. Defaults to True. svgo must be available in PATH.
- **kwargs: Passed to matplotlib.figure.Figure.savefig().
Returns:
- Figure: Matplotlib Figure conversion of the DataFrame or Styler.
Raises:
- subprocess.CalledProcessError: If SVG compression fails.
class BufferedIOBase
Base class for buffered IO objects.
The main difference with RawIOBase is that the read() method supports omitting the size argument, and does not have a default implementation that defers to readinto().
In addition, read(), readinto() and write() may raise BlockingIOError if the underlying raw stream is in non-blocking mode and not ready; unlike their raw counterparts, they will never return None.
A typical implementation should not inherit from a RawIOBase implementation, but wrap one.
class IOBase
The abstract base class for all I/O classes.
This class provides dummy implementations for many methods that derived classes can override selectively; the default implementations represent a file that cannot be read, written or seeked.
Even though IOBase does not declare read, readinto, or write because their signatures will vary, implementations and clients should consider those methods part of the interface. Also, implementations may raise UnsupportedOperation when operations they do not support are called.
The basic type used for binary data read from or written to a file is bytes. Other bytes-like objects are accepted as method arguments too. In some cases (such as readinto), a writable object is required. Text I/O classes work with str data.
Note that calling any method (except additional calls to close(), which are ignored) on a closed stream should raise a ValueError.
IOBase (and its subclasses) support the iterator protocol, meaning that an IOBase object can be iterated over yielding the lines in a stream.
IOBase also supports the with statement. In this example, fp is closed after the suite of the with statement is complete:
with open('spam.txt', 'w') as fp:
    fp.write('Spam and eggs!')
class RawIOBase
Base class for raw binary I/O.
class TextIOBase
Base class for text I/O.
This class provides a character and line based interface to stream I/O. There is no readinto method because Python's character strings are immutable.
class UnsupportedOperation
class TqdmDownload
Progress bar for urllib.request.urlretrieve file download.
Adapted from official TqdmUpTo example. See https://github.com/tqdm/tqdm/blob/4c956c20b83be4312460fc0c4812eeb3fef5e7df/README.rst#hooks-and-callbacks
method __init__
__init__(*args: 'Any', **kwargs: 'Any') → None
Sets default values appropriate for file downloads for unit, unit_scale, unit_divisor, miniters, desc.
property format_dict
Public API for read-only member access.
method update_to
update_to(
n_blocks: 'int' = 1,
block_size: 'int' = 1,
total_size: 'int | None' = None
) → bool | None
Update hook for urlretrieve.
Args:
- n_blocks (int, optional): Number of blocks transferred so far. Default = 1.
- block_size (int, optional): Size of each block (in tqdm units). Default = 1.
- total_size (int, optional): Total size (in tqdm units). If None, remains unchanged. Defaults to None.
Returns:
- bool | None: True if tqdm.display() was triggered.
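The hook exists because urllib.request.urlretrieve reports the cumulative number of blocks transferred, whereas a progress bar expects increments. The sketch below reproduces that conversion with a minimal stand-in class instead of tqdm, so DownloadProgress and its attributes are assumptions made for illustration.

```python
class DownloadProgress:
    """Minimal stand-in for a progress bar that tracks n and total only."""

    def __init__(self, total=None):
        self.n = 0  # units processed so far
        self.total = total

    def update(self, increment):
        self.n += increment

    def update_to(self, n_blocks=1, block_size=1, total_size=None):
        """Report hook: convert cumulative block counts into an increment."""
        if total_size is not None:
            self.total = total_size
        self.update(n_blocks * block_size - self.n)


progress = DownloadProgress()
# urlretrieve would call the hook like this as blocks arrive:
progress.update_to(1, 1024, 4096)
progress.update_to(3, 1024, 4096)
# Real usage would pass it as
# urllib.request.urlretrieve(url, filename, reporthook=progress.update_to)
```

Subtracting the current position makes the hook idempotent: calling it twice with the same block count does not advance the bar twice.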
module phonons
Plotting functions for pymatgen phonon band structures and density of states.
Global Variables
- TYPE_CHECKING
function pretty_sym_point
pretty_sym_point(symbol: 'str') → str
Convert a symbol to a pretty-printed version.
function get_band_xaxis_ticks
get_band_xaxis_ticks(
band_struct: 'PhononBands',
branches: 'Sequence[str] | set[str]' = ()
) → tuple[list[float], list[str]]
Get all ticks and labels for a band structure plot.
Args:
- branches (Sequence[str]): Branches to plot. Defaults to empty tuple, meaning all branches are plotted.
Returns:
- tuple[list[float], list[str]]: Ticks and labels for the x-axis of a band structure plot.
function phonon_bands
phonon_bands(
band_structs: 'PhononBands | dict[str, PhononBands]',
line_kwargs: "dict[str, Any] | dict[Literal['acoustic', 'optical'], dict[str, Any]] | Callable[[ndarray, int], dict[str, Any]] | None" = None,
branches: 'Sequence[str]' = (),
branch_mode: 'BranchMode' = 'union',
shaded_ys: 'dict[tuple[YMin | YMax, YMin | YMax], dict[str, Any]] | bool | None' = None,
**kwargs: 'Any'
) → Figure
Plot single or multiple pymatgen band structures using Plotly.
Warning: Only tested with phonon band structures so far but plan is to extend to electronic band structures.
Args:
- band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]): Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict with labels mapped to multiple such objects.
- line_kwargs (dict | dict[str, dict] | Callable): Line style configuration. Common style options include color, width, dash. Defaults to None. Can be:
  - A single dict applied to all lines
  - A dict with keys "acoustic" and "optical" containing style dicts for each mode type
  - A callable taking (band_data, band_idx) and returning a style dict
- branches (Sequence[str]): Branches to plot. Defaults to empty tuple, meaning all branches are plotted.
- branch_mode ("union" | "intersection"): Whether to plot union or intersection of branches in case of multiple band structures with non-overlapping branches. Defaults to "union".
- shaded_ys (dict[tuple[float | str, float | str], dict]): Keys are y-ranges (min, max) tuple and values are kwargs for shaded regions created by fig.add_hrect(). Defaults to single entry (0, "y_min"): dict(fillcolor="gray", opacity=0.07). "y_min" and "y_max" will be replaced with the figure's y-axis range. dict(layer="below", row="all", col="all") is always passed to add_hrect but can be overridden by the user. Set to False to disable.
- **kwargs: Passed to Plotly's Figure.add_scatter method.
Returns:
- go.Figure: Plotly figure object.
function phonon_dos
phonon_dos(
doses: 'PhononDos | dict[str, PhononDos]',
stack: 'bool' = False,
sigma: 'float' = 0,
units: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz',
normalize: "Literal['max', 'sum', 'integral'] | None" = None,
last_peak_anno: 'str | None' = None,
**kwargs: 'Any'
) → Figure
Plot phonon DOS using Plotly.
Args:
- doses (PhononDos | dict[str, PhononDos]): PhononDos or dict of multiple.
- stack (bool): Whether to plot the DOS as a stacked area graph. Defaults to False.
- sigma (float): Standard deviation for Gaussian smearing. Defaults to 0.
- units (str): Units for the frequencies. One of "THz", "eV", "meV", "Ha", "cm-1". Defaults to "THz".
- legend (dict): Legend configuration.
- normalize ("max" | "sum" | "integral" | None): How to normalize the DOS. Defaults to None.
- last_peak_anno (str): Annotation for last DOS peak with f-string placeholders for key (of dict containing multiple DOSes), last_peak frequency and units. Defaults to None, meaning last peak annotation is disabled. Set to "" to enable with a sensible default string.
- **kwargs: Passed to Plotly's Figure.add_scatter method.
Returns:
- go.Figure: Plotly figure object.
function convert_frequencies
convert_frequencies(
frequencies: 'ndarray',
unit: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz'
) → ndarray
Convert frequencies from THz to specified units.
Args:
- frequencies (np.ndarray): Frequencies in THz.
- unit (str): Target units. One of 'THz', 'eV', 'meV', 'Ha', 'cm-1'.
Returns:
- np.ndarray: Converted frequencies.
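The conversion from THz amounts to multiplying by a unit-specific factor derived from physical constants. The sketch below uses the Planck constant in eV s, the speed of light in cm/s and the Hartree energy in eV. The table THZ_TO and the function convert_thz are illustrative names, and pymatviz may obtain its factors from a different source.

```python
PLANCK_EV_S = 4.135667696e-15  # Planck constant in eV s
SPEED_OF_LIGHT_CM_S = 2.99792458e10  # speed of light in cm/s
HARTREE_EV = 27.211386245988  # Hartree energy in eV

# Multiplicative factors converting a frequency in THz to each target unit
THZ_TO = {
    "THz": 1.0,
    "eV": PLANCK_EV_S * 1e12,
    "meV": PLANCK_EV_S * 1e15,
    "Ha": PLANCK_EV_S * 1e12 / HARTREE_EV,
    "cm-1": 1e12 / SPEED_OF_LIGHT_CM_S,
}


def convert_thz(frequencies, unit="THz"):
    """Convert a list of frequencies in THz to the requested unit."""
    if unit not in THZ_TO:
        raise ValueError(f"Unknown unit {unit!r}, must be one of {list(THZ_TO)}")
    return [freq * THZ_TO[unit] for freq in frequencies]


wavenumbers = convert_thz([1.0, 10.0], unit="cm-1")
```

Energy units (eV, meV, Ha) follow from E = h f, while wavenumbers follow from dividing the frequency by the speed of light.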
function phonon_bands_and_dos
phonon_bands_and_dos(
band_structs: 'PhononBands | dict[str, PhononBands]',
doses: 'PhononDos | dict[str, PhononDos]',
bands_kwargs: 'dict[str, Any] | None' = None,
dos_kwargs: 'dict[str, Any] | None' = None,
subplot_kwargs: 'dict[str, Any] | None' = None,
all_line_kwargs: 'dict[str, Any] | None' = None,
per_line_kwargs: 'dict[str, dict[str, Any]] | None' = None,
**kwargs: 'Any'
) → Figure
Plot phonon DOS and band structure using Plotly.
Args:
- doses (PhononDos | dict[str, PhononDos]): PhononDos or dict of multiple.
- band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]): Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict with labels mapped to multiple such objects.
- bands_kwargs (dict[str, Any]): Passed to Plotly's Figure.add_scatter method.
- dos_kwargs (dict[str, Any]): Passed to Plotly's Figure.add_scatter method.
- subplot_kwargs (dict[str, Any]): Passed to Plotly's make_subplots method. Defaults to dict(shared_yaxes=True, column_widths=(0.8, 0.2), horizontal_spacing=0.01).
- all_line_kwargs (dict[str, Any]): Passed to trace.update for each trace in fig.data. Modifies line appearance for all traces. Defaults to None.
- per_line_kwargs (dict[str, dict[str, Any]]): Map of line labels to kwargs for trace.update. Modifies line appearance for specific traces. Defaults to None.
- **kwargs: Passed to Plotly's Figure.add_scatter method.
Returns:
- go.Figure: Plotly figure object.
class PhononDBDoc
Dataclass for phonon DB docs.
method __init__
__init__(
structure: 'Structure',
phonon_bandstructure: 'PhononBands',
phonon_dos: 'PhononDos',
free_energies: 'list[float]',
internal_energies: 'list[float]',
heat_capacities: 'list[float]',
entropies: 'list[float]',
temps: 'list[float] | None' = None,
has_imaginary_modes: 'bool | None' = None,
primitive: 'Structure | None' = None,
supercell: 'list[list[int]] | None' = None,
nac_params: 'dict[str, Any] | None' = None,
thermal_displacement_data: 'dict[str, Any] | None' = None,
mp_id: 'str | None' = None,
formula: 'str | None' = None
) → None
module powerups.both
Powerups that can be applied to both matplotlib and plotly figures.
Global Variables
- TYPE_CHECKING
- BACKENDS
- MATPLOTLIB
- PLOTLY
- VALID_FIG_NAMES
- VALID_FIG_TYPES
function annotate_metrics
annotate_metrics(
xs: 'ArrayLike',
ys: 'ArrayLike',
fig: 'AxOrFig | None' = None,
metrics: 'dict[str, float] | Sequence[str]' = ('MAE', 'R2'),
prefix: 'str' = '',
suffix: 'str' = '',
fmt: 'str' = '.3',
**kwargs: 'Any'
) → AnchoredText
Provide a set of x and y values of equal length and an optional Axes object on which to print the values' mean absolute error and R^2 coefficient of determination.
Args:
xs
(array): x values.
ys
(array): y values.
fig
(plt.Axes | plt.Figure | go.Figure | None, optional): matplotlib Axes or Figure or plotly Figure on which to add the annotation. Defaults to None.
metrics
(dict[str, float] | Sequence[str], optional): Metrics to show. Can be a subset of the recognized keys MAE, R2, R2_adj, RMSE, MSE, MAPE, the names of sklearn.metrics.regression functions, or any dict of metric names and values. Defaults to ("MAE", "R2").
prefix
(str, optional): Title or other string to prepend to metrics. Defaults to "".
suffix
(str, optional): Text to append after metrics. Defaults to "".
fmt
(str, optional): f-string float format for metrics. Defaults to '.3'.
**kwargs
: Additional arguments to pass to annotate().
Returns:
plt.Axes | plt.Figure | go.Figure
: The annotated figure.
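As a rough illustration of the two default metrics and of how the fmt string applies to them, here is a minimal pure-Python sketch. The helper names (mean_absolute_error, r2_score, format_metrics) are hypothetical and not part of pymatviz, and treating xs as the reference values for R^2 is an assumption.

```python
def mean_absolute_error(xs, ys):
    """Mean of the absolute differences between paired values."""
    return sum(abs(x - y) for x, y in zip(xs, ys)) / len(xs)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


def format_metrics(xs, ys, prefix="", suffix="", fmt=".3"):
    """Build the annotation text (hypothetical helper, not pymatviz code)."""
    # Assumption: xs are treated as the reference values when computing R2.
    metrics = {"MAE": mean_absolute_error(xs, ys), "R2": r2_score(xs, ys)}
    lines = [f"{name} = {value:{fmt}}" for name, value in metrics.items()]
    return prefix + "\n".join(lines) + suffix


text = format_metrics([1, 2, 3, 4], [1.1, 1.9, 3.2, 3.8], prefix="Test set\n")
print(text)
```

The fmt value is inserted into an f-string format spec, so '.3' yields three significant digits.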
function add_identity_line
add_identity_line(
fig: 'Figure | Figure | Axes',
line_kwargs: 'dict[str, Any] | None' = None,
trace_idx: 'int' = 0,
retain_xy_limits: 'bool' = False,
**kwargs: 'Any'
) → Figure
Add a line shape to the background layer of a plotly figure spanning from smallest to largest x/y values in the trace specified by trace_idx.
Args:
fig
(go.Figure | plt.Figure | plt.Axes): plotly/matplotlib figure or axes to add the identity line to.
line_kwargs
(dict[str, Any], optional): Keyword arguments for customizing the line shape will be passed to fig.add_shape(line=line_kwargs). Defaults to dict(color="gray", width=1, dash="dash").
trace_idx
(int, optional): Index of the trace to use for measuring x/y limits. Defaults to 0. Unused if kaleido package is installed and the figure's actual x/y-range can be obtained from fig.full_figure_for_development(). Applies only to plotly figures.
retain_xy_limits
(bool, optional): If True, the x/y-axis limits will be retained after adding the identity line. Defaults to False.
**kwargs
: Additional arguments are passed to fig.add_shape().
Raises:
TypeError
: If fig is neither a plotly nor a matplotlib figure or axes.
ValueError
: If fig is a plotly figure and kaleido is not installed and trace_idx is out of range.
Returns:
Figure
: Figure with added identity line.
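The idea of spanning the line from the smallest to the largest x/y value of a trace can be sketched as below. This is only an illustration in plain Python: the helper name identity_line_shape is hypothetical, and both the way the endpoints are chosen and the merging of user-supplied line_kwargs over the documented defaults are assumptions rather than pymatviz's actual logic.

```python
def identity_line_shape(xs, ys, line_kwargs=None):
    """Return a dict describing a y = x line covering the data range.

    Hypothetical helper: the endpoint choice below is an assumption.
    """
    lo = min(min(xs), min(ys))  # smallest value across both axes
    hi = max(max(xs), max(ys))  # largest value across both axes
    # Documented default line style, with user overrides applied on top.
    line = {"color": "gray", "width": 1, "dash": "dash", **(line_kwargs or {})}
    return {"type": "line", "x0": lo, "y0": lo, "x1": hi, "y1": hi, "line": line}


shape = identity_line_shape([0.5, 2.0, 3.5], [1.0, 2.5, 3.0], {"color": "red"})
print(shape)
```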
function add_best_fit_line
add_best_fit_line(
fig: 'Figure | Figure | Axes',
xs: 'ArrayLike' = (),
ys: 'ArrayLike' = (),
trace_idx: 'int | None' = None,
line_kwargs: 'dict[str, Any] | None' = None,
annotate_params: 'bool | dict[str, Any]' = True,
warn: 'bool' = True,
**kwargs: 'Any'
) → Figure
Add line of best fit according to least squares to a plotly or matplotlib figure.
Args:
fig
(go.Figure | plt.Figure | plt.Axes): plotly/matplotlib figure or axes to add the best fit line to.
xs
(array, optional): x-values to use for fitting. Defaults to () which means use the x-values of trace at trace_idx in fig.
ys
(array, optional): y-values to use for fitting. Defaults to () which means use the y-values of trace at trace_idx in fig.
trace_idx
(int | None, optional): Index of the trace to use for measuring x/y values for fitting if xs and ys are not provided. Defaults to None, in which case the first trace (index 0) is used.
line_kwargs
(dict[str, Any], optional): Keyword arguments for customizing the line shape. For plotly, will be passed to fig.add_shape(line=line_kwargs). For matplotlib, will be passed to ax.plot(). Defaults to None.
annotate_params
(bool | dict[str, Any], optional): Pass dict to customize the annotation of the best fit line. Set to False to disable annotation. Defaults to True.
warn
(bool, optional): If True, print a warning if trace_idx is unspecified and the figure has multiple traces. Defaults to True.
**kwargs
: Additional arguments are passed to fig.add_shape() for plotly or ax.plot() for matplotlib.
Raises:
TypeError
: If fig is neither a plotly nor a matplotlib figure or axes.
ValueError
: If fig is a plotly figure and xs and ys are not provided and trace_idx is out of range.
Returns:
Figure
: Figure with added best fit line.
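For reference, an ordinary least-squares fit of a straight line can be written in a few lines of plain Python. This is a sketch of the general technique only; the helper name least_squares_fit is hypothetical and the source does not state which routine pymatviz uses internally.

```python
def least_squares_fit(xs, ys):
    """Return (slope, intercept) minimizing the squared residuals of y = a*x + b."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope is the covariance of x and y divided by the variance of x.
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    return slope, intercept


slope, intercept = least_squares_fit([0, 1, 2, 3], [1, 3, 5, 7])
print(f"y = {slope:.3g} x + {intercept:.3g}")
```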
module powerups.matplotlib
Powerups for matplotlib figures.
Global Variables
- TYPE_CHECKING
function with_marginal_hist
with_marginal_hist(
xs: 'ArrayLike',
ys: 'ArrayLike',
cell: 'GridSpec | None' = None,
bins: 'int' = 100,
fig: 'Figure | Axes | None' = None,
**kwargs: 'Any'
) → Axes
Call before creating a matplotlib figure and use the returned ax_main for all subsequent plotting ops to create a grid of plots with the main plot in the lower left and narrow histograms along its x- and/or y-axes displayed above and near the right edge.
Args:
xs
(array): Marginal histogram values along x-axis.
ys
(array): Marginal histogram values along y-axis.
cell
(GridSpec, optional): Cell of a plt GridSpec at which to add the grid of plots. Defaults to None.
bins
(int, optional): Resolution/bin count of the histograms. Defaults to 100.
fig
(Figure, optional): matplotlib Figure or Axes to add the marginal histograms to. Defaults to None.
**kwargs
: Additional keywords passed to ax.hist().
Returns:
plt.Axes
: The matplotlib Axes to be used for the main plot.
function annotate_bars
annotate_bars(
ax: 'Axes | None' = None,
v_offset: 'float' = 10,
h_offset: 'float' = 0,
labels: 'Sequence[str | int | float] | None' = None,
fontsize: 'int' = 14,
y_max_headroom: 'float' = 1.2,
adjust_test_pos: 'bool' = False,
**kwargs: 'Any'
) → None
Annotate each bar in bar plot with a label.
Args:
ax
(Axes): The matplotlib axes to annotate.
v_offset
(int): Vertical offset between the labels and the bars.
h_offset
(int): Horizontal offset between the labels and the bars.
labels
(list[str]): Labels used for annotating bars. If not provided, defaults to the y-value of each bar.
fontsize
(int): Annotated text size in pts. Defaults to 14.
y_max_headroom
(float): Will be multiplied with the y-value of the tallest bar to increase the y-max of the plot, thereby making room for text above all bars. Defaults to 1.2.
adjust_test_pos
(bool): If True, use adjustText to prevent overlapping labels. Defaults to False.
**kwargs
: Additional arguments (rotation, arrowprops, etc.) are passed to ax.annotate().
module powerups
Powerups such as parity lines, annotations, marginals, menu buttons, etc for matplotlib and plotly figures.
Global Variables
- select_colorscale
- select_marker_mode
- toggle_grid
- toggle_log_linear_x_axis
- toggle_log_linear_y_axis
module powerups.plotly
Powerups for plotly figures.
Global Variables
- TYPE_CHECKING
- toggle_log_linear_y_axis
- toggle_log_linear_x_axis
- toggle_grid
- select_colorscale
- select_marker_mode
function add_ecdf_line
add_ecdf_line(
fig: 'Figure',
values: 'ArrayLike' = (),
trace_idx: 'int' = 0,
trace_kwargs: 'dict[str, Any] | None' = None,
**kwargs: 'Any'
) → Figure
Add an empirical cumulative distribution function (ECDF) line to a plotly figure.
Support for matplotlib planned but not implemented. PRs welcome.
Args:
fig
(go.Figure): plotly figure to add the ECDF line to.
values
(array, optional): Values to compute the ECDF from. Defaults to () which means use the x-values of trace at trace_idx in fig.
trace_idx
(int, optional): Index of the trace whose x-values to use for computing the ECDF. Defaults to 0. Unused if values is not empty.
trace_kwargs
(dict[str, Any], optional): Passed to trace_ecdf.update(). Defaults to None. Use e.g. to set trace name (default "Cumulative") or line_color (default "gray").
**kwargs
: Passed to fig.add_trace().
Returns:
Figure
: Figure with added ECDF line.
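An empirical cumulative distribution function assigns to each sorted value the fraction of samples less than or equal to it. The sketch below shows that computation in plain Python; the helper name ecdf is hypothetical and this is not necessarily how pymatviz computes the line it adds.

```python
def ecdf(values):
    """Return sorted values and their cumulative fractions (i / n for the i-th value)."""
    xs = sorted(values)
    n = len(xs)
    ys = [(i + 1) / n for i in range(n)]
    return xs, ys


xs, ys = ecdf([3, 1, 2, 2])
for x, y in zip(xs, ys):
    print(x, y)
```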
module process_data
pymatviz utility functions.
Global Variables
- TYPE_CHECKING
function count_elements
count_elements(
values: 'ElemValues',
count_mode: 'ElemCountMode' = Composition,
exclude_elements: 'Sequence[str]' = (),
fill_value: 'float | None' = None
) → Series
Count element occurrence in list of formula strings or dict-like compositions.
If passed values are already a map from element symbol to counts, ensure the data is a pd.Series filled with "fill_value" for missing element.
Provided as standalone function for external use or to cache long computations. Caching long element counts is done by refactoring
    ptable_heatmap(long_list_of_formulas)  # slow
to
    elem_counts = count_elements(long_list_of_formulas)  # slow
    ptable_heatmap(elem_counts)  # fast, only rerun this line to update the plot
Args:
values
(dict[str, int | float] | pd.Series | list[str]): Iterable of composition strings/objects or map from element symbols to heatmap values.
count_mode
('(element|fractional|reduced)_composition'): Only used when values is a list of composition strings/objects.
- composition (default): Count elements in each composition as is, i.e. without reduction or normalization.
- fractional_composition: Convert to normalized compositions in which the amounts of each species sum to 1 before counting. Example: Fe2 O3 -> Fe0.4 O0.6.
- reduced_composition: Convert to reduced compositions (i.e. amounts normalized by greatest common denominator) before counting. Example: Fe4 P4 O16 -> Fe P O4.
- occurrence: Count the number of times each element occurs in a list of formulas irrespective of compositions. E.g. [Fe2 O3, Fe O, Fe4 P4 O16] counts to {Fe: 3, O: 3, P: 1}.
exclude_elements
(Sequence[str]): Elements to exclude from the count. Defaults to ().
fill_value
(float | None): Value to fill in for missing elements. Defaults to None for NaN.
Returns:
pd.Series
: Map element symbols to heatmap values.
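The occurrence and fractional_composition modes described above can be mimicked with a small standard-library sketch. The regex-based parser and the helper names (parse_formula, count_occurrence, fractional_composition) are hypothetical simplifications that ignore parentheses and charges; pymatviz itself works with composition objects rather than this parser.

```python
import re
from collections import Counter

# Element symbol followed by an optional amount, e.g. "Fe2" or "O".
ELEM_AMOUNT = re.compile(r"([A-Z][a-z]?)(\d*\.?\d*)")


def parse_formula(formula):
    """Map element symbols to amounts for a simple formula (no parentheses)."""
    amounts = Counter()
    for symbol, amount in ELEM_AMOUNT.findall(formula):
        amounts[symbol] += float(amount) if amount else 1.0
    return dict(amounts)


def count_occurrence(formulas):
    """Count in how many formulas each element appears, ignoring amounts."""
    counts = Counter()
    for formula in formulas:
        for symbol in parse_formula(formula):
            counts[symbol] += 1
    return dict(counts)


def fractional_composition(formula):
    """Normalize amounts so they sum to 1."""
    amounts = parse_formula(formula)
    total = sum(amounts.values())
    return {symbol: amount / total for symbol, amount in amounts.items()}


occurrences = count_occurrence(["Fe2O3", "FeO", "Fe4P4O16"])
fractions = fractional_composition("Fe2O3")
print(occurrences, fractions)
```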
module ptable
matplotlib and plotly periodic table figures.
module ptable.ptable_matplotlib
Various periodic table heatmaps with matplotlib and plotly.
Global Variables
- TYPE_CHECKING
function ptable_heatmap
ptable_heatmap(
data: 'DataFrame | Series | dict[str, list[list[float]]] | PTableData',
colormap: 'str' = 'viridis',
exclude_elements: 'Sequence[str]' = (),
overwrite_tiles: 'dict[ElemStr, OverwriteTileValueColor] | None' = None,
infty_color: 'ColorType' = 'lightskyblue',
nan_color: 'ColorType' = 'lightgrey',
log: 'bool' = False,
sci_notation: 'bool' = False,
tile_size: 'tuple[float, float]' = (0.75, 0.75),
on_empty: "Literal['hide', 'show']" = 'show',
hide_f_block: "bool | Literal['auto']" = 'auto',
f_block_voffset: 'float' = 0,
plot_kwargs: 'dict[str, Any] | None' = None,
ax_kwargs: 'dict[str, Any] | None' = None,
text_colors: "Literal['auto'] | ColorType | dict[ElemStr, ColorType]" = 'auto',
symbol_pos: 'tuple[float, float] | None' = None,
symbol_kwargs: 'dict[str, Any] | None' = None,
anno_pos: 'tuple[float, float]' = (0.75, 0.75),
anno_text: 'dict[ElemStr, str] | None' = None,
anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
anno_kwargs: 'dict[str, Any] | None' = None,
value_show_mode: "Literal['value', 'fraction', 'percent', 'off']" = 'value',
value_pos: 'tuple[float, float] | None' = None,
value_fmt: 'str' = 'auto',
value_kwargs: 'dict[str, Any] | None' = None,
show_cbar: 'bool' = True,
cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.05),
cbar_range: 'tuple[float | None, float | None]' = (None, None),
cbar_label_fmt: 'str' = 'auto',
cbar_title: 'str' = 'Element Count',
cbar_title_kwargs: 'dict[str, Any] | None' = None,
cbar_kwargs: 'dict[str, Any] | None' = None,
return_type: "Literal['figure', 'axes']" = 'axes',
colorscale: 'str | None' = None,
heat_mode: "Literal['value', 'fraction', 'percent'] | None" = None,
show_values: 'bool | None' = None,
fmt: 'str | None' = None,
cbar_fmt: 'str | None' = None,
show_scale: 'bool | None' = None
) → Axes
Plot a heatmap across the periodic table.
Args:
- data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted on the lower-left corner and the 2nd on the upper-right. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols, plots are created from each column.
--- Heatmap ---
- colormap (str): The colormap to use.
- exclude_elements (Sequence[str]): Elements to exclude.
- overwrite_tiles (dict[ElemStr, OverwriteTileValueColor]): Force overwrite value or color for element tiles.
- infty_color (ColorType): The color to use for infinity.
- nan_color (ColorType): The color to use for missing value (NaN).
- log (bool): Whether to show colorbar in log scale.
- sci_notation (bool): Whether to use scientific notation for values and colorbar tick labels.
- tile_size (tuple[float, float]): The relative height and width of the tile.
--- Figure ---
- on_empty ("hide" | "show"): Whether to show or hide tiles for elements without data. Defaults to "show".
- hide_f_block (bool | "auto"): Hide f-block (lanthanide and actinide series). Defaults to "auto", meaning hide if no data present.
- f_block_voffset (float): The vertical offset of f-block elements.
- plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
--- Axis ---
- ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options.
- text_colors: Colors for element symbols and values.
  - "auto": Auto pick "black" or "white" based on the contrast of tile color for each element.
  - ColorType: Use the same ColorType for each element.
  - dict[ElemStr, ColorType]: Element to color mapping.
--- Symbol ---
- symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.5). (1, 1) is the upper right corner.
- symbol_kwargs (dict): Keyword arguments passed to plt.text() for element symbols. Defaults to None.
--- Annotation ---
- anno_pos (tuple[float, float]): Position of annotation relative to the lower left corner of each tile. Defaults to (0.75, 0.75). (1, 1) is the upper right corner.
- anno_text (dict[ElemStr, str]): Annotation to display for each element tile. Defaults to None for not displaying.
- anno_text_color (ColorType | dict[ElemStr, ColorType]): Text colors for annotations.
- anno_kwargs (dict): Keyword arguments passed to ax.text() for annotation. Defaults to None.
--- Value ---
- value_show_mode (str): The values display mode:
  - "off": Hide values.
  - "value": Display values as is.
  - "fraction": As a fraction of the total (0.10).
  - "percent": As a percentage of the total (10%).
  "fraction" and "percent" can be used to make the colors in different plots comparable.
- value_pos (tuple[float, float]): The position of values inside the tile.
- value_fmt (str | "auto"): f-string format for values. Defaults to ".1%" (1 decimal place) if value_show_mode is "percent", else ".3g".
- value_color (str | "auto"): The font color of values. Use "auto" to automatically switch between black/white depending on the background.
- value_kwargs (dict): Keyword arguments passed to plt.text() for values. Defaults to None.
--- Colorbar ---
- show_cbar (bool): Whether to show colorbar.
- cbar_coords (tuple[float, float, float, float]): Colorbar position and size: [x, y, width, height] anchored at lower left corner of the bar. Defaults to (0.18, 0.8, 0.42, 0.05).
- cbar_range (tuple): Colorbar values display range, use None for auto detection for the corresponding boundary.
- cbar_label_fmt (str): f-string format option for color tick labels.
- cbar_title (str): Colorbar title. Defaults to "Element Count".
- cbar_title_kwargs (dict): Keyword arguments passed to cbar.ax.set_title(). Defaults to dict(fontsize=12, pad=10).
- cbar_kwargs (dict): Keyword arguments passed to fig.colorbar().
--- Migration --- TODO: remove after 2025-07-01
- return_type ("figure" | "axes"): Whether to return plt.Figure or plt.Axes. We encourage you to migrate to "figure".
--- Deprecated args, don't use --- TODO: remove after 2025-04-01
- colorscale: Use "colormap" instead.
- heat_mode: Use "value_show_mode" instead.
- show_values: Use "value_show_mode" instead.
- fmt: Use "value_fmt" instead.
- cbar_fmt: Use "cbar_label_fmt" instead.
- show_scale: Use "show_cbar" instead.
Returns:
plt.Axes
: matplotlib axes with the heatmap, or
plt.Figure
: matplotlib figure with the heatmap, depending on return_type.
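To make the value_show_mode and value_fmt options more concrete, the following plain-Python sketch converts raw element values into the text that each mode would display. The helper name format_tile_values is hypothetical and the code only mirrors the behaviour described in the Args list; it is not taken from pymatviz.

```python
def format_tile_values(values, value_show_mode="value", value_fmt="auto"):
    """Return the display string per element for a given show mode (illustrative only)."""
    if value_show_mode == "off":
        return {}
    if value_fmt == "auto":
        # Documented defaults: ".1%" for percent, otherwise ".3g".
        value_fmt = ".1%" if value_show_mode == "percent" else ".3g"
    total = sum(values.values())
    shown = {}
    for elem, val in values.items():
        if value_show_mode in ("fraction", "percent"):
            val = val / total  # share of the total across all elements
        shown[elem] = f"{val:{value_fmt}}"
    return shown


counts = {"Fe": 2, "O": 6, "H": 2}
as_values = format_tile_values(counts, "value")
as_fractions = format_tile_values(counts, "fraction")
as_percent = format_tile_values(counts, "percent")
print(as_values, as_fractions, as_percent)
```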
function ptable_heatmap_ratio
ptable_heatmap_ratio(
values_num: 'ElemValues',
values_denom: 'ElemValues',
count_mode: 'ElemCountMode' = Composition,
normalize: 'bool' = False,
infty_color: 'ColorType' = 'lightskyblue',
zero_color: 'ColorType' = 'lightgrey',
zero_tol: 'float' = 1e-06,
zero_symbol: 'str' = 'ZERO',
not_in_numerator: 'tuple[str, str] | None' = ('lightgray', 'gray: not in 1st list'),
not_in_denominator: 'tuple[str, str] | None' = ('lightskyblue', 'blue: not in 2nd list'),
not_in_either: 'tuple[str, str] | None' = ('white', 'white: not in either'),
anno_pos: 'tuple[float, float]' = (0.75, 0.75),
anno_text: 'dict[ElemStr, str] | None' = None,
anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
anno_kwargs: 'dict[str, Any] | None' = None,
cbar_title: 'str' = 'Element Ratio',
**kwargs: 'Any'
) → figure
Display the ratio of two maps from element symbols to heat values or of two sets of compositions.
Args:
--- Ratio heatmap specific ---
- values_num (dict[ElemStr, int | float] | pd.Series | list[ElemStr]): Map from element symbols to heatmap values or iterable of composition strings/objects in the numerator.
- values_denom (dict[ElemStr, int | float] | pd.Series | list[ElemStr]): Map from element symbols to heatmap values or iterable of composition strings/objects in the denominator.
--- Data preprocessing ---
- count_mode ("composition" | "fractional_composition" | "reduced_composition"): Reduce or normalize compositions before counting. See count_elements for details. Only used when values is list of composition strings/objects.
- normalize (bool): Whether to normalize heatmap values so they sum to 1. Makes different ptable_heatmap_ratio plots comparable. Defaults to False.
--- Infinity and zero handling ---
- infty_color (ColorType): Color for infinity.
- zero_color (ColorType): Color for (near) zero element tiles.
- zero_tol (float): Absolute tolerance to consider a value zero.
- zero_symbol (str): Value to display for (near) zero element tiles.
--- Colors and legends for special cases ---
- not_in_numerator (tuple[str, str]): Color and legend description used for elements missing from numerator. Defaults to ("lightgray", "gray: not in 1st list").
- not_in_denominator (tuple[str, str]): See not_in_numerator. Defaults to ("lightskyblue", "blue: not in 2nd list").
- not_in_either (tuple[str, str]): See not_in_numerator. Defaults to ("white", "white: not in either").
--- Annotation ---
- anno_pos (tuple[float, float]): Position of annotation relative to the lower left corner of each tile. Defaults to (0.75, 0.75). (1, 1) is the upper right corner.
- anno_text (dict[ElemStr, str]): Annotation to display for each element tile. Defaults to None for not displaying.
- anno_text_color (ColorType | dict[ElemStr, ColorType]): Text colors for annotations.
- anno_kwargs (dict): Keyword arguments passed to ax.text() for annotation. Defaults to None.
--- Colorbar ---
- cbar_title (str): Title for the colorbar. Defaults to "Element Ratio".
--- Others ---
- **kwargs: Additional keyword arguments passed to ptable_heatmap().
Returns:
plt.Figure
: matplotlib Figure object.
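The special cases listed above (missing from numerator, missing from denominator, near-zero ratio, infinite ratio) can be illustrated with a small classifier over two element-to-value maps. This is a hedged sketch: the helper name classify_ratio_tiles, the string labels, and the order in which the cases are checked are assumptions, and the not_in_either case is omitted because it requires the full list of elements.

```python
def classify_ratio_tiles(num, denom, zero_tol=1e-6, zero_symbol="ZERO"):
    """Return per-element ratio or a label for special cases (illustrative only)."""
    tiles = {}
    for elem in sorted(set(num) | set(denom)):
        if elem not in num:
            tiles[elem] = "not_in_numerator"
        elif elem not in denom:
            tiles[elem] = "not_in_denominator"
        elif denom[elem] == 0:
            tiles[elem] = float("inf")  # would be drawn with infty_color
        else:
            ratio = num[elem] / denom[elem]
            # Values within zero_tol of zero get the zero symbol instead of a number.
            tiles[elem] = zero_symbol if abs(ratio) < zero_tol else ratio
    return tiles


tiles = classify_ratio_tiles(
    {"Fe": 4, "O": 0.0, "H": 1, "N": 2},
    {"Fe": 2, "O": 5, "C": 3, "N": 0},
)
print(tiles)
```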
function ptable_heatmap_splits
ptable_heatmap_splits(
data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
start_angle: 'float' = 135,
colormap: 'str | Colormap' = 'viridis',
on_empty: "Literal['hide', 'show']" = 'hide',
hide_f_block: "bool | Literal['auto']" = 'auto',
plot_kwargs: 'dict[str, Any] | None' = None,
ax_kwargs: 'dict[str, Any] | None' = None,
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
symbol_pos: 'tuple[float, float]' = (0.5, 0.5),
symbol_kwargs: 'dict[str, Any] | None' = None,
anno_pos: 'tuple[float, float]' = (0.75, 0.75),
anno_text: 'dict[ElemStr, str] | None' = None,
anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
anno_kwargs: 'dict[str, Any] | None' = None,
cbar_title: 'str' = 'Values',
cbar_title_kwargs: 'dict[str, Any] | None' = None,
cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
cbar_kwargs: 'dict[str, Any] | None' = None
) → Figure
Plot evenly-split heatmaps, nested inside a periodic table.
Args:
- data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted on the lower-left corner and the 2nd on the upper-right. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols, plots are created from each column.
--- Heatmap-split ---
- start_angle (float): The starting angle for the splits in degrees. The split proceeds counter-clockwise (0 refers to the x-axis). Defaults to 135 degrees.
--- Figure ---
- colormap (str): Matplotlib colormap name to use.
- on_empty ("hide" | "show"): Whether to show or hide tiles for elements without data. Defaults to "hide".
- hide_f_block (bool | "auto"): Hide f-block (lanthanide and actinide series). Defaults to "auto", meaning hide if no data present.
- plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
--- Axis ---
- ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options.
--- Symbol ---
- symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
- symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.5). (1, 1) is the upper right corner.
- symbol_kwargs (dict): Keyword arguments passed to ax.text() for element symbols. Defaults to None.
--- Annotation ---
- anno_pos (tuple[float, float]): Position of annotation relative to the lower left corner of each tile. Defaults to (0.75, 0.75). (1, 1) is the upper right corner.
- anno_text (dict[ElemStr, str]): Annotation to display for each element tile. Defaults to None for not displaying.
- anno_text_color (ColorType | dict[ElemStr, ColorType]): Text colors for annotations.
- anno_kwargs (dict): Keyword arguments passed to ax.text() for annotation. Defaults to None.
--- Colorbar ---
- cbar_title (str): Colorbar title. Defaults to "Values".
- cbar_title_kwargs (dict): Keyword arguments passed to cbar.ax.set_title(). Defaults to dict(fontsize=12, pad=10).
- cbar_coords (tuple[float, float, float, float]): Colorbar position and size: [x, y, width, height] anchored at lower left corner of the bar. Defaults to (0.18, 0.8, 0.42, 0.02).
- cbar_kwargs (dict): Keyword arguments passed to fig.colorbar().
Notes:
Default figsize is set to (0.75 * n_groups, 0.75 * n_periods).
Returns:
plt.Figure
: periodic table with a subplot in each element tile.
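One way to picture the effect of start_angle is to compute the angular range of each evenly sized split, going counter-clockwise from the starting angle. The helper name split_angles is hypothetical and the sketch only reflects the description in the Args list, not the plotting code itself.

```python
def split_angles(n_splits, start_angle=135):
    """Return (start, end) angles in degrees for each evenly sized split."""
    step = 360 / n_splits
    return [
        (start_angle + idx * step, start_angle + (idx + 1) * step)
        for idx in range(n_splits)
    ]


two_way = split_angles(2)  # default start at 135 degrees: a diagonal split
four_way = split_angles(4, start_angle=0)
print(two_way, four_way)
```

With two values per element and the default start_angle, the dividing line runs along the tile diagonal, which matches the lower-left/upper-right placement mentioned for the data argument.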
function ptable_hists
ptable_hists(
data: 'DataFrame | Series | dict[ElemStr, list[float]]',
bins: 'int' = 20,
x_range: 'tuple[float | None, float | None] | None' = None,
log: 'bool' = False,
colormap: 'str | Colormap | None' = 'viridis',
on_empty: "Literal['show', 'hide']" = 'hide',
hide_f_block: "bool | Literal['auto']" = 'auto',
plot_kwargs: 'dict[str, Any] | None' = None,
ax_kwargs: 'dict[str, Any] | None' = None,
child_kwargs: 'dict[str, Any] | None' = None,
cbar_axis: "Literal['x', 'y']" = 'x',
cbar_title: 'str' = 'Values',
cbar_title_kwargs: 'dict[str, Any] | None' = None,
cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
cbar_kwargs: 'dict[str, Any] | None' = None,
symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
symbol_kwargs: 'dict[str, Any] | None' = None,
anno_pos: 'tuple[float, float]' = (0.75, 0.75),
anno_text: 'dict[ElemStr, str] | None' = None,
anno_kwargs: 'dict[str, Any] | None' = None,
color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
elem_type_colors: 'dict[str, str] | None' = None,
add_elem_type_legend: 'bool' = False,
elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure
Plot histograms for each element laid out in a periodic table.
Args:
- data (pd.DataFrame | pd.Series | dict[ElemStr, list[float]]): Map from element symbols to histogram values. E.g. if dict, {"Fe": [1, 2, 3], "O": [4, 5]}. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols and histograms are plotted from each column.
--- Histogram-specific ---
- bins (int): Number of bins for the histograms. Defaults to 20.
- x_range (tuple[float | None, float | None]): x-axis range for all histograms. Defaults to None.
- log (bool): Whether to log scale y-axis of each histogram. Defaults to False.
--- Figure ---
- colormap (str): Matplotlib colormap name to use. Defaults to "viridis". See options at https://matplotlib.org/stable/users/explain/colors/colormaps.
- on_empty ("hide" | "show"): Whether to show or hide tiles for elements without data. Defaults to "hide".
- hide_f_block (bool | "auto"): Hide f-block (lanthanide and actinide series). Defaults to "auto", meaning hide if no data present.
- plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
--- Axis ---
- ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options: https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set.html#matplotlib-axes-axes-set
- child_kwargs (dict): Keywords passed to ax.hist() for each histogram. Defaults to None.
--- Colorbar ---
- cbar_axis ("x" | "y"): The axis the colormap is based on.
- cbar_title (str): Color bar title. Defaults to "Values".
- cbar_title_kwargs (dict): Keyword arguments passed to cbar.ax.set_title(). Defaults to dict(fontsize=12, pad=10).
- cbar_coords (tuple[float, float, float, float]): Color bar position and size: [x, y, width, height] anchored at lower left corner of the bar. Defaults to (0.18, 0.8, 0.42, 0.02).
- cbar_kwargs (dict): Keyword arguments passed to fig.colorbar().
--- Symbol ---
- symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.8). (1, 1) is the upper right corner.
- symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
- symbol_kwargs (dict): Keyword arguments passed to ax.text() for element symbols. Defaults to None.
--- Annotation ---
- anno_pos (tuple[float, float]): Position of annotation relative to the lower left corner of each tile. Defaults to (0.75, 0.75). (1, 1) is the upper right corner.
- anno_text (dict[ElemStr, str]): Annotation to display for each element tile. Defaults to None for not displaying.
- anno_kwargs (dict): Keyword arguments passed to ax.text() for annotation. Defaults to None.
--- Element type colors and legend ---
- color_elem_strategy ("symbol" | "background" | "both" | "off"): Whether to color element symbols, tile backgrounds, or both based on element type. Defaults to "background".
- elem_type_colors (dict | None): dict to map element types to colors. None to use default element type colors.
- add_elem_type_legend (bool): Whether to show a legend for element types. Defaults to False.
- elem_type_legend_kwargs (dict): kwargs to plt.legend(), e.g. to set the legend title, use {"title": "Element Types"}.
Returns:
plt.Figure
: periodic table with a histogram in each element tile.
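Because bins and x_range apply to every element tile, the per-element histograms share the same bin edges. The plain-Python sketch below shows what that means for the example data from the Args list. The helper name histogram_counts is hypothetical, it assumes the upper bound is greater than the lower bound, and it places values on the upper bound in the last bin; the actual plotting is delegated to ax.hist() according to the child_kwargs entry.

```python
def histogram_counts(values, bins=20, x_range=None):
    """Count values per equal-width bin over a shared range (illustrative only)."""
    lo, hi = x_range if x_range is not None else (None, None)
    lo = min(values) if lo is None else lo
    hi = max(values) if hi is None else hi
    width = (hi - lo) / bins  # assumes hi > lo
    counts = [0] * bins
    for val in values:
        if not lo <= val <= hi:
            continue  # values outside the shared range are ignored
        idx = min(int((val - lo) / width), bins - 1)  # upper bound goes in last bin
        counts[idx] += 1
    return counts


data = {"Fe": [1, 2, 3], "O": [4, 5]}
hists = {elem: histogram_counts(vals, bins=2, x_range=(1, 5)) for elem, vals in data.items()}
print(hists)
```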
function ptable_scatters
ptable_scatters(
data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
colormap: 'str | Colormap | None' = None,
on_empty: "Literal['hide', 'show']" = 'hide',
hide_f_block: "bool | Literal['auto']" = 'auto',
plot_kwargs: 'dict[str, Any] | None' = None,
ax_kwargs: 'dict[str, Any] | None' = None,
child_kwargs: 'dict[str, Any] | None' = None,
cbar_title: 'str' = 'Values',
cbar_title_kwargs: 'dict[str, Any] | None' = None,
cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
cbar_kwargs: 'dict[str, Any] | None' = None,
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
symbol_kwargs: 'dict[str, Any] | None' = None,
anno_pos: 'tuple[float, float]' = (0.75, 0.75),
anno_text: 'dict[ElemStr, str] | None' = None,
anno_kwargs: 'dict[str, Any] | None' = None,
color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
elem_type_colors: 'dict[str, str] | None' = None,
add_elem_type_legend: 'bool' = False,
elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure
Make scatter plots for each element, nested inside a periodic table.
Args:
- data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted on the lower-left corner and the 2nd on the upper-right. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols, plots are created from each column.
--- Figure ---
- colormap (str): Matplotlib colormap name to use. Defaults to None. See options at https://matplotlib.org/stable/users/explain/colors/colormaps.
- on_empty ("hide" | "show"): Whether to show or hide tiles for elements without data. Defaults to "hide".
- hide_f_block (bool | "auto"): Hide f-block (lanthanide and actinide series). Defaults to "auto", meaning hide if no data present.
- plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
--- Axis ---
- ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options: https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set.html#matplotlib-axes-axes-set
- child_kwargs: Arguments to pass to the child plotter call.
--- Colorbar ---
- cbar_title (str): Color bar title. Defaults to "Values".
- cbar_title_kwargs (dict): Keyword arguments passed to cbar.ax.set_title(). Defaults to dict(fontsize=12, pad=10).
- cbar_coords (tuple[float, float, float, float]): Color bar position and size: [x, y, width, height] anchored at lower left corner of the bar. Defaults to (0.18, 0.8, 0.42, 0.02).
- cbar_kwargs (dict): Keyword arguments passed to fig.colorbar().
--- Symbol ---
- symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
- symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.8). (1, 1) is the upper right corner.
- symbol_kwargs (dict): Keyword arguments passed to ax.text() for element symbols. Defaults to None.
--- Annotation ---
- anno_pos (tuple[float, float]): Position of annotation relative to the lower left corner of each tile. Defaults to (0.75, 0.75). (1, 1) is the upper right corner.
- anno_text (dict[ElemStr, str]): Annotation to display for each element tile. Defaults to None for not displaying.
- anno_kwargs (dict): Keyword arguments passed to ax.text() for annotation. Defaults to None.
--- Element type colors and legend ---
- color_elem_strategy ("symbol" | "background" | "both" | "off"): Whether to color element symbols, tile backgrounds, or both based on element type. Defaults to "background".
- elem_type_colors (dict | None): dict to map element types to colors. None to use default element type colors.
- add_elem_type_legend (bool): Whether to show a legend for element types. Defaults to False.
- elem_type_legend_kwargs (dict): kwargs to plt.legend(), e.g. to set the legend title, use {"title": "Element Types"}.
function ptable_lines
ptable_lines(
data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
on_empty: "Literal['hide', 'show']" = 'hide',
hide_f_block: "bool | Literal['auto']" = 'auto',
plot_kwargs: 'dict[str, Any] | None' = None,
ax_kwargs: 'dict[str, Any] | None' = None,
child_kwargs: 'dict[str, Any] | None' = None,
symbol_kwargs: 'dict[str, Any] | None' = None,
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
anno_pos: 'tuple[float, float]' = (0.75, 0.75),
anno_text: 'dict[ElemStr, str] | None' = None,
anno_kwargs: 'dict[str, Any] | None' = None,
color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
elem_type_colors: 'dict[str, str] | None' = None,
add_elem_type_legend: 'bool' = False,
elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure
Line plots for each element, nested inside a periodic table.
Args:
- data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted on the lower-left corner and the 2nd on the upper-right. If pd.Series, index is element symbols and values are lists. If pd.DataFrame, column names are element symbols and plots are created from each column.
--- Figure ---
- on_empty ("hide" | "show"): Whether to show or hide tiles for elements without data. Defaults to "hide".
- hide_f_block (bool | "auto"): Hide f-block (lanthanide and actinide series). Defaults to "auto", meaning hide if no data present.
- plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
--- Axis ---
- ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options: https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set.html#matplotlib-axes-axes-set
- child_kwargs: Arguments to pass to the child plotter call.
--- Symbol ---
- symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
- symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.8). (1, 1) is the upper right corner.
- symbol_kwargs (dict): Keyword arguments passed to ax.text() for element symbols. Defaults to None.
--- Annotation ---
- anno_pos (tuple[float, float]): Position of annotation relative to the lower left corner of each tile. Defaults to (0.75, 0.75). (1, 1) is the upper right corner.
- anno_text (dict[ElemStr, str]): Annotation to display for each element tile. Defaults to None for not displaying.
- anno_kwargs (dict): Keyword arguments passed to ax.text() for annotation. Defaults to None.
--- Element type colors and legend ---
- color_elem_strategy ("symbol" | "background" | "both" | "off"): Whether to color element symbols, tile backgrounds, or both based on element type. Defaults to "background".
- elem_type_colors (dict | None): dict to map element types to colors. None to use default element type colors.
- add_elem_type_legend (bool): Whether to show a legend for element types. Defaults to False.
- elem_type_legend_kwargs (dict): kwargs to plt.legend(), e.g. to set the legend title, use {"title": "Element Types"}.
module ptable.ptable_plotly
Periodic table plots powered by plotly.
Global Variables
- TYPE_CHECKING
- ELEM_TYPE_COLORS
- VALID_COLOR_ELEM_STRATEGIES
function ptable_heatmap_plotly
ptable_heatmap_plotly(
values: 'ElemValues',
count_mode: 'ElemCountMode' = Composition,
colorscale: 'str | Sequence[str] | Sequence[tuple[float, str]]' = 'viridis',
show_scale: 'bool' = True,
show_values: 'bool' = True,
heat_mode: "Literal['value', 'fraction', 'percent']" = 'value',
fmt: 'str | None' = None,
hover_props: 'Sequence[str] | dict[str, str] | None' = None,
hover_data: 'dict[str, str | int | float] | Series | None' = None,
font_colors: 'Sequence[str]' = (),
gap: 'float' = 5,
font_size: 'int | None' = None,
bg_color: 'str | None' = None,
nan_color: 'str' = '#eff',
colorbar: 'dict[str, Any] | None' = None,
cscale_range: 'tuple[float | None, float | None]' = (None, None),
exclude_elements: 'Sequence[str]' = (),
log: 'bool' = False,
fill_value: 'float | None' = None,
element_symbol_map: 'dict[str, str] | None' = None,
label_map: 'dict[str, str] | Callable[[str], str] | Literal[False] | None' = None,
border: 'dict[str, Any] | None | Literal[False]' = None,
scale: 'float' = 1.0,
**kwargs: 'Any'
) → Figure
Create a Plotly figure with an interactive heatmap of the periodic table. Supports hover tooltips with custom data or atomic reference data like electronegativity, atomic_radius, etc. See kwargs hover_data and hover_props, resp.
Args:
- values (dict[str, int | float] | pd.Series | list[str]): Map from element symbols to heatmap values e.g. dict(Fe=2, O=3) or iterable of composition strings or Pymatgen composition objects.
- count_mode ("composition" | "fractional_composition" | "reduced_composition"): Reduce or normalize compositions before counting. See count_elements for details. Only used when values is list of composition strings/objects.
- colorscale (str | list[str] | list[tuple[float, str]]): Color scale for heatmap. Defaults to "viridis". See plotly.com/python/builtin-colorscales for names of other builtin color scales. Note "YlGn" and px.colors.sequential.YlGn are equivalent. Custom scales are specified as ["blue", "red"] or [[0, "rgb(0,0,255)"], [0.5, "rgb(0,255,0)"], [1, "rgb(255,0,0)"]].
- show_scale (bool): Whether to show a bar for the color scale. Defaults to True.
- show_values (bool): Whether to show numbers on heatmap tiles. Defaults to True.
- heat_mode ("value" | "fraction" | "percent" | None): Whether to display heat values as is (value), normalized as a fraction of the total, as percentages or not at all (None). Defaults to "value". "fraction" and "percent" can be used to make the colors in different periodic table heatmap plots comparable.
- fmt (str): f-string format option for heat labels. Defaults to ".1%" (1 decimal place) if heat_mode="percent" else ".3g".
- hover_props (list[str] | dict[str, str]): Elemental properties to display in the hover tooltip. Can be a list of property names to display only the values themselves or a dict mapping names to what they should display as. E.g. dict(atomic_mass="atomic weight") will display as "atomic weight = {x}". Defaults to None. Available properties are: symbol, row, column, name, atomic_number, atomic_mass, n_neutrons, n_protons, n_electrons, period, group, phase, radioactive, natural, metal, nonmetal, metalloid, type, atomic_radius, electronegativity, first_ionization, density, melting_point, boiling_point, number_of_isotopes, discoverer, year, specific_heat, n_shells, n_valence.
- hover_data (dict[str, str | int | float] | pd.Series): Map from element symbols to additional data to display in the hover tooltip. dict(Fe="this appears in the hover tooltip on a new line below the element name"). Defaults to None.
- font_colors (list[str]): One color name or two for [min_color, max_color]. min_color is applied to annotations with heatmap values less than (max_val - min_val) / 2. Defaults to (), meaning auto-set to maximize contrast with the color scale: white text for dark background and vice versa, swapped depending on the colorscale.
- gap (float): Gap in pixels between tiles of the periodic table. Defaults to 5.
- font_size (int): Element symbol and heat label text size. Any valid CSS size allowed. Defaults to automatic font size based on plot size. Element symbols will be bold and 1.5x this size.
- bg_color (str): Plot background color. Defaults to "rgba(0, 0, 0, 0)".
- colorbar (dict[str, Any]): Plotly colorbar properties documented at https://plotly.com/python/reference#heatmap-colorbar. Defaults to dict(orientation="h"). Commonly used keys are:
  - title: colorbar title
  - titleside: "top" | "bottom" | "right" | "left"
  - tickmode: "array" | "auto" | "linear" | "log" | "date" | "category"
  - tickvals: list of tick values
  - ticktext: list of tick labels
  - tickformat: f-string format option for tick labels
  - len: fraction of plot height or width depending on orientation
  - thickness: fraction of plot height or width depending on orientation
- nan_color (str): Fill color for element tiles with NaN values. Defaults to "#eff".
- cscale_range (tuple[float | None, float | None]): Colorbar range. Defaults to (None, None) meaning the range is automatically determined from the data.
- exclude_elements (list[str]): Elements to exclude from the heatmap. E.g. if oxygen overpowers everything, you can do exclude_elements=["O"]. Defaults to ().
- log (bool): Whether to use a logarithmic color scale. Defaults to False. Piece of advice: colorscale="viridis" and log=True go well together.
- fill_value (float | None): Value to fill in for missing elements. Defaults to 0.
- element_symbol_map (dict[str, str] | None): A dictionary to map element symbols to custom strings. If provided, these custom strings will be displayed instead of the standard element symbols. Defaults to None.
- label_map (dict[str, str] | Callable[[str], str] | None): Map heat values (after string formatting) to target strings. Set to False to disable. Defaults to dict.fromkeys((np.nan, None, "nan", "nan%"), "-") so as not to display "nan" for missing values.
- border (dict[str, Any]): Border properties for element tiles. Defaults to dict(width=1, color="gray"). Other allowed keys are arguments of go.Heatmap which is (mis-)used to draw the borders as a 2nd heatmap below the main one. Pass False to disable borders.
- scale (float): Scaling factor for whole figure layout. Defaults to 1.
- **kwargs: Additional keyword arguments passed to plotly.figure_factory.create_annotated_heatmap().
Returns:
- Figure: Plotly Figure object.
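As a rough illustration of how heat_mode and fmt interact, the sketch below normalizes raw element counts and formats the tile labels with the documented defaults (".1%" for heat_mode="percent", ".3g" otherwise). The helper name format_heat_labels is hypothetical and not part of pymatviz; the library's internal implementation may differ.

```python
from __future__ import annotations


def format_heat_labels(
    values: dict[str, float], heat_mode: str = "value", fmt: str | None = None
) -> dict[str, str]:
    """Hypothetical helper: build tile labels the way the docs describe."""
    if fmt is None:
        # Documented default: percent formatting only for heat_mode="percent"
        fmt = ".1%" if heat_mode == "percent" else ".3g"
    total = sum(values.values())
    labels = {}
    for elem, val in values.items():
        # "fraction" and "percent" both normalize by the total over all elements
        heat = val / total if heat_mode in ("fraction", "percent") else val
        labels[elem] = format(heat, fmt)
    return labels


counts = {"Fe": 2, "O": 3}
print(format_heat_labels(counts, heat_mode="value"))
print(format_heat_labels(counts, heat_mode="fraction"))
print(format_heat_labels(counts, heat_mode="percent"))
```

Because "fraction" and "percent" divide by the total, two heatmaps built from datasets of different size end up on the same 0 to 1 scale, which is what makes their colors comparable.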
function ptable_hists_plotly
ptable_hists_plotly(
data: 'DataFrame | Series | dict[str, list[float]]',
bins: 'int' = 20,
x_range: 'tuple[float | None, float | None] | None' = None,
log: 'bool' = False,
colorscale: 'str' = 'RdBu',
colorbar: 'dict[str, Any] | Literal[False] | None' = None,
font_size: 'int | None' = None,
scale: 'float' = 1.0,
element_symbol_map: 'dict[str, str] | None' = None,
symbol_kwargs: 'dict[str, Any] | None' = None,
annotations: 'dict[str, str | dict[str, Any]] | Callable[[Sequence[float]], str | dict[str, Any] | list[dict[str, Any]]] | None' = None,
color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
elem_type_colors: 'dict[str, str] | None' = None,
subplot_kwargs: 'dict[str, Any] | None' = None,
x_axis_kwargs: 'dict[str, Any] | None' = None
) → Figure
Plotly figure with histograms for each element laid out in a periodic table.
Args:
- data (pd.DataFrame | pd.Series | dict[str, list[float]]): Map from element symbols to histogram values. E.g. if dict, {"Fe": [1, 2, 3], "O": [4, 5]}. If pd.Series, index is element symbols and values are lists. If pd.DataFrame, column names are element symbols and histograms are plotted from each column.
--- Histogram-specific ---
- bins (int): Number of bins for the histograms. Defaults to 20.
- x_range (tuple[float | None, float | None]): x-axis range for all histograms. Defaults to None.
- log (bool): Whether to log scale y-axis of each histogram. Defaults to False.
- colorscale (str): Color scale for histogram bars. Defaults to "RdBu" (red to blue). See plotly.com/python/builtin-colorscales for other options.
- colorbar (dict[str, Any] | None): Plotly colorbar properties. Defaults to dict(orientation="h"). See https://plotly.com/python/reference#heatmap-colorbar for available options. Set to False to hide the colorbar.
--- Layout ---
- font_size (int): Element symbol and annotation text size. Defaults to automatic font size based on plot size.
- scale (float): Scaling factor for whole figure layout. Defaults to 1.
--- Text ---
- element_symbol_map (dict[str, str] | None): A dictionary to map element symbols to custom strings. If provided, these custom strings will be displayed instead of the standard element symbols. Defaults to None.
- symbol_kwargs (dict): Additional keyword arguments for element symbol text.
- annotations (dict[str, str] | Callable[[np.ndarray], str] | None): Annotation to display for each element tile. Can be either:
  - dict mapping element symbols to annotation strings
  - callable that takes histogram values and returns annotation string
  - None for not displaying annotations (default)
--- Element type colors ---
- color_elem_strategy ("symbol" | "background" | "both" | "off"): Whether to color element symbols, tile backgrounds, or both based on element type. Defaults to "background".
- elem_type_colors (dict | None): dict to map element types to colors. None to use the default = pymatviz.colors.ELEM_TYPE_COLORS.
- subplot_kwargs (dict | None): Additional keywords passed to make_subplots(). Can be used e.g. to toggle shared x/y-axes.
- x_axis_kwargs (dict | None): Additional keywords for x-axis like tickfont, showticklabels, nticks, tickformat, tickangle.
Returns:
- go.Figure: Plotly Figure object with histograms in a periodic table layout.
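The two non-None forms of annotations can be related with a short sketch. The callable mean_annotation and the sample data are hypothetical; the dict comprehension only emulates what applying such a callable to every element's histogram values would produce and is not taken from the pymatviz source.

```python
import statistics

# Hypothetical input in the documented dict form: element symbol -> histogram values
data = {"Fe": [1, 2, 3], "O": [4, 5]}


def mean_annotation(values):
    """Callable form: receives one element's values, returns the annotation text."""
    return f"mean={statistics.mean(values):.2f}"


# Equivalent dict form: one pre-computed annotation string per element symbol
annotations = {elem: mean_annotation(vals) for elem, vals in data.items()}
print(annotations)
```

Either mean_annotation itself or the resulting annotations dict would then be passed as the annotations argument.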
function ptable_heatmap_splits_plotly
ptable_heatmap_splits_plotly(
data: 'DataFrame | Series | dict[str, list[float]]',
orientation: "Literal['diagonal', 'horizontal', 'vertical', 'grid']" = 'diagonal',
colorscale: 'str | Sequence[str] | Sequence[tuple[float, str]]' = 'viridis',
colorbar: 'dict[str, Any] | Literal[False] | None' = None,
on_empty: "Literal['hide', 'show']" = 'hide',
hide_f_block: "bool | Literal['auto']" = 'auto',
font_size: 'int | None' = None,
scale: 'float' = 1.0,
element_symbol_map: 'dict[str, str] | None' = None,
symbol_kwargs: 'dict[str, Any] | None' = None,
annotations: 'dict[str, str | dict[str, Any]] | Callable[[ndarray], str | dict[str, Any]] | None' = None,
nan_color: 'str' = '#eff',
hover_data: 'dict[str, str | int | float] | Series | None' = None,
subplot_kwargs: 'dict[str, Any] | None' = None
) → Figure
Create a Plotly figure with an interactive heatmap of the periodic table, where each element tile is split into sections representing different values.
Args:
- data (pd.DataFrame | pd.Series | dict[str, list[float]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted in the lower-left corner and the 2nd in the upper-right.
--- Figure ---
- colorscale (str | list[str] | list[tuple[float, str]]): Color scale for heatmap. Defaults to "viridis".
- colorbar (dict[str, Any] | None): Plotly colorbar properties. Defaults to dict(orientation="h"). See https://plotly.com/python/reference#heatmap-colorbar for available options. Set to False to hide the colorbar.
- on_empty ("hide" | "show"): Whether to show tiles for elements without data. Defaults to "hide".
- hide_f_block (bool | "auto"): Hide f-block (lanthanide and actinide series). Defaults to "auto", meaning hide if no data present.
- orientation (str): How to split each element tile. Defaults to "diagonal".
  - "diagonal": Split at 45° angles
  - "horizontal": Split into equal horizontal strips
  - "vertical": Split into equal vertical strips
  - "grid": Split into 2x2 grid (only valid for n_splits=4)
--- Layout ---
- font_size (int): Element symbol and annotation text size. Defaults to automatic font size based on plot size.
- scale (float): Scaling factor for whole figure layout. Defaults to 1.
--- Symbol ---
- element_symbol_map (dict[str, str] | None): A dictionary to map element symbols to custom strings. If provided, these custom strings will be displayed instead of the standard element symbols. Defaults to None.
- symbol_kwargs (dict): Additional keyword arguments for element symbol text.
--- Annotation ---
- annotations (dict[str, str] | Callable[[np.ndarray], str] | None): Annotation to display for each element tile. Can be either:
  - dict mapping element symbols to annotation strings
  - callable that takes values and returns annotation string
  - None for not displaying annotations (default)
--- Additional options ---
- nan_color (str): Color for NaN values. Defaults to "#eff".
- hover_data (dict[str, str | int | float] | pd.Series): Additional data for hover tooltip.
- subplot_kwargs (dict | None): Additional keywords passed to make_subplots().
Returns:
- go.Figure: Plotly Figure object with the periodic table heatmap splits.
Raises:
- ValueError: If n_splits not in {2, 3, 4} or orientation="grid" with n_splits != 4.
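The documented error conditions can be pre-checked before plotting with a sketch like the one below. It assumes n_splits is the number of values supplied per element, which the reference above does not state explicitly; validate_splits is a hypothetical helper, not a pymatviz function.

```python
from __future__ import annotations


def validate_splits(data: dict[str, list[float]], orientation: str = "diagonal") -> int:
    """Return the number of splits per tile or raise ValueError if unsupported."""
    # Assumption: n_splits is the largest number of values given for any element
    n_splits = max(len(values) for values in data.values())
    if n_splits not in {2, 3, 4}:
        raise ValueError(f"{n_splits=} must be 2, 3 or 4")
    if orientation == "grid" and n_splits != 4:
        raise ValueError(f"orientation='grid' requires n_splits=4, got {n_splits}")
    return n_splits


print(validate_splits({"Fe": [1, 2], "Co": [3, 4]}))  # two values per tile
print(validate_splits({"Fe": [1, 2, 3, 4]}, orientation="grid"))  # 2x2 grid
```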
module rainclouds
Raincloud plots.
Global Variables
- TYPE_CHECKING
function rainclouds
rainclouds(
data: 'dict[str, Sequence[float] | tuple[DataFrame, str]]',
orientation: "Literal['h', 'v']" = 'h',
alpha: 'float' = 0.7,
width_viol: 'float' = 0.3,
width_box: 'float' = 0.05,
jitter: 'float' = 0.01,
point_size: 'float' = 3,
bw: 'float' = 0.2,
cut: 'float' = 0.0,
scale: "Literal['area', 'count', 'width']" = 'area',
rain_offset: 'float' = -0.25,
offset: 'float | None' = None,
hover_data: 'Sequence[str] | dict[str, Sequence[str]] | None' = None,
show_violin: 'bool' = True,
show_box: 'bool' = True,
show_points: 'bool' = True
) → Figure
Create a raincloud plot for multiple datasets using Plotly.
This plot type was proposed in https://wellcomeopenresearch.org/articles/4-63/v2. It is a vertical stack of:
1. violin plot (the cloud)
2. box plot (the umbrella)
3. strip plot (the rain)
Args:
- data (dict[str, Union[Sequence[float], tuple[pd.DataFrame, str]]]): A dictionary where keys are labels and values are either sequences of float data or tuples containing a DataFrame and the column name to plot. Dataframes can hold additional columns to be used in hover tooltips.
- orientation ("h" | "v", optional): Orientation of the plot. "h" for horizontal, "v" for vertical. Defaults to "h".
- alpha (float, optional): Transparency of the violin plots. Defaults to 0.7.
- width_viol (float, optional): Width of the violin plots. Defaults to 0.3.
- width_box (float, optional): Width of the box plots. Defaults to 0.05.
- jitter (float, optional): Amount of jitter for the strip plot. Defaults to 0.01.
- point_size (float, optional): Size of the points in the strip plot. Defaults to 3.
- bw (float, optional): Bandwidth for the KDE. Defaults to 0.2.
- cut (float, optional): Distance past extreme data points to extend KDE. Defaults to 0.0.
- scale ("area" | "count" | "width", optional): Method to scale the width of each violin. Defaults to "area".
- rain_offset (float, optional): Shift the strip plot position. Defaults to -0.25.
- offset (float | None, optional): Shift the violin plot position. Defaults to None.
- hover_data (Sequence[str] | dict[str, Sequence[str]] | None, optional): Additional data to be shown in hover tooltips. Can be a list of column names or a dict with the same keys as data and different column names for each trace.
- show_violin (bool, optional): Whether to show the violin plot. Defaults to True.
- show_box (bool, optional): Whether to show the box plot. Defaults to True.
- show_points (bool, optional): Whether to show the strip plot points. Defaults to True.
- **kwargs: Additional keyword arguments to pass to the plotting functions.
Returns:
- go.Figure: The Plotly figure containing the raincloud plot.
module rdf.helpers
Helper functions for radial distribution functions (RDFs) of pymatgen structures.
Global Variables
- TYPE_CHECKING
function calculate_rdf
calculate_rdf(
structure: 'Structure',
center_species: 'str | None' = None,
neighbor_species: 'str | None' = None,
cutoff: 'float' = 15,
n_bins: 'int' = 75,
pbc: 'PbcLike' = (True, True, True)
) → tuple[ndarray, ndarray]
Calculate the radial distribution function (RDF) for a given structure.
If center_species and neighbor_species are provided, calculates the partial RDF for the specified element pair. Otherwise, calculates the full RDF.
Args:
- structure (Structure): A pymatgen Structure object.
- center_species (str, optional): Symbol of the central species. If None, all species are considered.
- neighbor_species (str, optional): Symbol of the neighbor species. If None, all species are considered.
- cutoff (float, optional): Maximum distance for RDF calculation. Default is 15 Å.
- n_bins (int, optional): Number of bins for RDF calculation. Default is 75.
- pbc (tuple[int, int, int], optional): Periodic boundary conditions as any 3-tuple of 0s/1s. Defaults to (True, True, True), i.e. (1, 1, 1).
Returns:
- tuple[np.ndarray, np.ndarray]: Arrays of (radii, g(r)) values.
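The quantity returned by calculate_rdf can be illustrated with a minimal pure-Python sketch for a cubic periodic cell. This is not pymatviz's implementation, which operates on pymatgen Structure objects; the helper rdf_cubic_box, the restriction to cubic cells and the minimum-image shortcut (valid only for cutoff <= box_length / 2) are assumptions made for brevity. Pair counts in each radial shell are divided by the count expected for an ideal gas of the same density, so g(r) = 1 indicates no correlation.

```python
from __future__ import annotations

import math


def rdf_cubic_box(
    coords: list[tuple[float, float, float]],
    box_length: float,
    cutoff: float,
    n_bins: int,
) -> tuple[list[float], list[float]]:
    """Hypothetical helper: full RDF of points in a cubic periodic box."""
    n_atoms = len(coords)
    bin_width = cutoff / n_bins
    counts = [0] * n_bins
    for i in range(n_atoms):
        for j in range(n_atoms):
            if i == j:
                continue
            delta = [a - b for a, b in zip(coords[i], coords[j])]
            # Minimum-image convention: wrap each component to the nearest image
            delta = [d - box_length * round(d / box_length) for d in delta]
            dist = math.sqrt(sum(d * d for d in delta))
            if dist < cutoff:
                counts[int(dist / bin_width)] += 1
    density = n_atoms / box_length**3
    radii, g_r = [], []
    for idx, count in enumerate(counts):
        r_lo, r_hi = idx * bin_width, (idx + 1) * bin_width
        shell_volume = 4 / 3 * math.pi * (r_hi**3 - r_lo**3)
        radii.append((r_lo + r_hi) / 2)  # bin center
        # Normalize by the ideal-gas expectation for this shell
        g_r.append(count / (n_atoms * density * shell_volume))
    return radii, g_r


radii, g_r = rdf_cubic_box([(0, 0, 0), (1, 0, 0)], box_length=4, cutoff=2, n_bins=4)
print(radii, g_r)
```

In the two-atom example only the bin containing the distance 1.0 is populated; all other bins are zero.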
module rdf
Radial distribution function plots (RDFs).
module rdf.plotly
Radial distribution functions (RDFs) of pymatgen structures using plotly.
The main function, element_pair_rdfs, generates a plotly figure with facets for each pair of elements in the given structure. It supports customization of cutoff distance, bin size, specific element pairs to plot, and an optional reference line.
Example usage:
    structure = Structure(...)  # Create or load a pymatgen Structure
    fig = element_pair_rdfs(structure, bin_size=0.1)
    fig.show()
Global Variables
- TYPE_CHECKING
function element_pair_rdfs
element_pair_rdfs(
structures: 'Structure | Sequence[Structure] | dict[str, Structure]',
cutoff: 'float | None' = None,
n_bins: 'int' = 75,
bin_size: 'float | None' = None,
element_pairs: 'list[tuple[str, str]] | None' = None,
reference_line: 'dict[str, Any] | None' = None,
colors: 'Sequence[str] | None' = None,
line_styles: 'Sequence[str] | None' = None,
subplot_kwargs: 'dict[str, Any] | None' = None
) → Figure
Generate a plotly figure of pairwise radial distribution functions (RDFs) for all (or a subset of) element pairs in one or multiple structures.
Args:
- structures: Can be one of the following:
  - single pymatgen Structure
  - list of pymatgen Structures
  - dictionary mapping labels to Structures
- cutoff (float | None, optional): Maximum distance for RDF calculation. If None, defaults to twice the longest lattice vector length across all structures (up to 15 Å). If negative, its absolute value is used as a scaling factor for the longest lattice vector length (e.g. -1.5 means 1.5x the longest lattice vector). Default is None.
- n_bins (int, optional): Number of bins for RDF calculation. Default is 75.
- bin_size (float, optional): Size of bins for RDF calculation. If specified, it overrides n_bins. Default is None.
- element_pairs (list[tuple[str, str]], optional): Element pairs to plot. If None, all pairs present in any structure are plotted.
- reference_line (dict, optional): Keywords for reference line at g(r)=1 drawn with Figure.add_hline(). If None (default), no reference line is drawn.
- colors (Sequence[str], optional): Colors for each structure's RDF line. Defaults to plotly.colors.qualitative.Plotly.
- line_styles (Sequence[str], optional): Line styles for each structure's RDF line. Will be used for all element pairs present in that structure. Defaults to ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"].
- subplot_kwargs (dict, optional): Passed to plotly.make_subplots. Use this to e.g. set subplot_titles, rows/cols or row/column spacing to customize the subplot layout.
Returns:
- go.Figure: A plotly figure with facets for each pairwise RDF, comparing one or multiple structures.
Raises:
- ValueError: If no structures are provided, if structures have no sites, if invalid element pairs are provided, or if both n_bins and bin_size are specified.
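The three cases described for cutoff (None, negative, positive) can be summarized in a small sketch. resolve_cutoff and its parameter names are hypothetical; the code only restates the rule documented for element_pair_rdfs and is not copied from the library.

```python
from __future__ import annotations


def resolve_cutoff(
    cutoff: float | None, max_lattice_length: float, max_cutoff: float = 15.0
) -> float:
    """Hypothetical helper: turn the user-facing cutoff into a distance in Å."""
    if cutoff is None:
        # Twice the longest lattice vector, capped at 15 Å
        return min(2 * max_lattice_length, max_cutoff)
    if cutoff < 0:
        # Negative values act as a scaling factor on the longest lattice vector
        return abs(cutoff) * max_lattice_length
    return cutoff


print(resolve_cutoff(None, max_lattice_length=5.0))   # small cell: 2 x 5 Å
print(resolve_cutoff(None, max_lattice_length=10.0))  # large cell: capped
print(resolve_cutoff(-1.5, max_lattice_length=4.0))   # 1.5 x longest vector
print(resolve_cutoff(8.0, max_lattice_length=4.0))    # explicit distance
```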
function full_rdf
full_rdf(
structures: 'Structure | Sequence[Structure] | dict[str, Structure]',
cutoff: 'float' = 15,
n_bins: 'int' = 75,
bin_size: 'float | None' = None,
reference_line: 'dict[str, Any] | None' = None,
colors: 'Sequence[str] | None' = None,
line_styles: 'Sequence[str] | None' = None
) → Figure
Generate a plotly figure of full radial distribution functions (RDFs) for one or multiple structures.
Args:
- structures: Can be one of the following:
  - single pymatgen Structure
  - list of pymatgen Structures
  - dictionary mapping labels to Structures
- cutoff (float, optional): Maximum distance for RDF calculation. Default is 15 Å.
- n_bins (int, optional): Number of bins for RDF calculation. Default is 75.
- bin_size (float, optional): Size of bins for RDF calculation. If specified, it overrides n_bins. Default is None.
- reference_line (dict, optional): Keywords for reference line at g(r)=1 drawn with Figure.add_hline(). If None (default), no reference line is drawn.
- colors (Sequence[str], optional): Colors for each structure's RDF line. Defaults to plotly.colors.qualitative.Plotly.
- line_styles (Sequence[str], optional): Line styles for each structure's RDF line. Defaults to ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"].
Returns:
- go.Figure: A plotly figure with full RDFs for one or multiple structures.
Raises:
- ValueError: If no structures are provided, if structures have no sites, or if both n_bins and bin_size are specified.
module relevance
Plots for evaluating classifier performance.
Global Variables
- TYPE_CHECKING
function roc_curve
roc_curve(
targets: 'ArrayLike | str',
proba_pos: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None
) → tuple[float, Axes]
Plot the receiver operating characteristic curve of a binary classifier given target labels and predicted probabilities for the positive class.
Args:
- targets (array): Ground truth targets.
- proba_pos (array): Predicted probabilities for the positive class.
- df (pd.DataFrame, optional): DataFrame with targets and proba_pos columns.
- ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
Returns:
- tuple[float, ax]: The classifier's ROC-AUC and the plot's matplotlib Axes.
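For readers unfamiliar with the returned ROC-AUC, the following pure-Python sketch computes ROC points and the area under them with the trapezoidal rule. It is only meant to make the metric concrete; roc_points and trapezoid_auc are hypothetical helpers, and the library itself most likely delegates to an established implementation rather than code like this.

```python
def roc_points(targets, proba_pos):
    """(false positive rate, true positive rate) for each distinct threshold."""
    n_pos = sum(1 for t in targets if t == 1)
    n_neg = len(targets) - n_pos
    points = [(0.0, 0.0)]
    # Lower the threshold step by step; ties share a single threshold
    for threshold in sorted(set(proba_pos), reverse=True):
        tp = sum(1 for t, p in zip(targets, proba_pos) if p >= threshold and t == 1)
        fp = sum(1 for t, p in zip(targets, proba_pos) if p >= threshold and t == 0)
        points.append((fp / n_neg, tp / n_pos))
    return points


def trapezoid_auc(points):
    """Area under a piecewise-linear curve given as sorted (x, y) points."""
    return sum(
        (x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:])
    )


targets = [0, 0, 1, 1]
proba_pos = [0.1, 0.4, 0.35, 0.8]
roc_auc = trapezoid_auc(roc_points(targets, proba_pos))
print(roc_auc)
```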
function precision_recall_curve
precision_recall_curve(
targets: 'ArrayLike | str',
proba_pos: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None
) → tuple[float, Axes]
Plot the precision recall curve of a binary classifier.
Args:
- targets (array): Ground truth targets.
- proba_pos (array): Predicted probabilities for the positive class.
- df (pd.DataFrame, optional): DataFrame with targets and proba_pos columns.
- ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
Returns:
- tuple[float, ax]: The classifier's precision score and the matplotlib Axes.
module sankey
Sankey diagram for comparing distributions in two dataframe columns.
Global Variables
- TYPE_CHECKING
function sankey_from_2_df_cols
sankey_from_2_df_cols(
df: 'DataFrame',
cols: 'list[str]',
labels_with_counts: "bool | Literal['percent']" = True,
annotate_columns: 'bool | dict[str, Any]' = True,
**kwargs: 'Any'
) → Figure
Plot two columns of a dataframe as a Plotly Sankey diagram.
Args:
- df (pd.DataFrame): Pandas dataframe.
- cols (list[str]): 2-tuple of source and target column names. Source corresponds to left, target to right side of the diagram.
- labels_with_counts (bool | "percent", optional): Whether to append value counts to node labels. Defaults to True.
- annotate_columns (bool | dict[str, Any], optional): Whether to use the column names as annotations vertically centered on the left and right sides of the diagram. If a dict, passed as **kwargs to plotly.graph_objects.Figure.add_annotation. Defaults to True.
- **kwargs: Additional keyword arguments passed to plotly.graph_objects.Sankey.
Raises:
- ValueError: If len(cols) != 2.
Returns:
- Figure: Plotly figure containing the Sankey diagram.
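A Sankey diagram of two columns boils down to counting how often each (source, target) value pair occurs. The sketch below shows that counting step with the standard library only; the rows and the column names crystal_system and stability are made-up illustration data, and the code is not the library's implementation.

```python
from collections import Counter

# Made-up rows standing in for a dataframe with two categorical columns
rows = [
    {"crystal_system": "cubic", "stability": "stable"},
    {"crystal_system": "cubic", "stability": "stable"},
    {"crystal_system": "cubic", "stability": "unstable"},
    {"crystal_system": "hexagonal", "stability": "unstable"},
]
source_col, target_col = "crystal_system", "stability"

# Link widths: number of rows flowing from each source value to each target value
flows = Counter((row[source_col], row[target_col]) for row in rows)
# Node sizes on the left side, usable for labels that carry value counts
source_counts = Counter(row[source_col] for row in rows)

print(flows)
print(source_counts)
```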
module scatter
Parity, residual and density plots.
Global Variables
- TYPE_CHECKING
function density_scatter
density_scatter(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
log_density: 'bool' = True,
hist_density_kwargs: 'dict[str, Any] | None' = None,
color_bar: 'bool | dict[str, Any]' = True,
xlabel: 'str | None' = None,
ylabel: 'str | None' = None,
identity_line: 'bool | dict[str, Any]' = True,
best_fit_line: 'bool | dict[str, Any]' = True,
stats: 'bool | dict[str, Any]' = True,
**kwargs: 'Any'
) → Axes
Scatter plot colored by density using matplotlib backend.
Args:
- x (array | str): x-values or dataframe column name.
- y (array | str): y-values or dataframe column name.
- df (pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.
- ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
- log_density (bool, optional): Whether to log the density color scale. Defaults to True.
- hist_density_kwargs (dict, optional): Passed to hist_density(). Use to change sort (by density, default True), bins (default 100), or method (for interpolation, default "nearest"). matplotlib backend only.
- color_bar (bool | dict, optional): Whether to add a color bar. Defaults to True. If dict, unpacked into ax.figure.colorbar(). E.g. dict(label="Density").
- xlabel (str, optional): x-axis label. Defaults to "Actual".
- ylabel (str, optional): y-axis label. Defaults to "Predicted".
- identity_line (bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.
- best_fit_line (bool | dict[str, Any], optional): Whether to add a best-fit line. Defaults to True. Pass a dict to customize line properties.
- stats (bool | dict[str, Any], optional): Whether to display a text box with MAE and R^2. Defaults to True. Can be dict to pass kwargs to annotate_metrics(). E.g. stats=dict(loc="upper left", prefix="Title", prop=dict(fontsize=16)).
- **kwargs: Passed to ax.scatter().
Returns:
- plt.Axes: The plot object.
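The stats text box reports MAE and R^2. A minimal sketch of both metrics follows, assuming x holds the actual values and y the predictions (matching the default axis labels) and assuming R^2 is the usual coefficient of determination; mae_and_r2 is a hypothetical helper, not the function pymatviz calls internally.

```python
def mae_and_r2(actual, predicted):
    """Mean absolute error and coefficient of determination R^2."""
    n_points = len(actual)
    mae = sum(abs(p - a) for a, p in zip(actual, predicted)) / n_points
    mean_actual = sum(actual) / n_points
    # Residual sum of squares vs. total sum of squares around the mean
    ss_res = sum((a - p) ** 2 for a, p in zip(actual, predicted))
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)
    return mae, 1 - ss_res / ss_tot


mae, r2 = mae_and_r2(actual=[1, 2, 3], predicted=[1, 2, 4])
print(f"MAE = {mae:.3f}, R^2 = {r2:.3f}")
```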
function density_scatter_plotly
density_scatter_plotly(
df: 'DataFrame',
x: 'str',
y: 'str',
density: "Literal['kde', 'empirical'] | None" = None,
log_density: 'bool | None' = None,
identity_line: 'bool | dict[str, Any]' = True,
best_fit_line: 'bool | dict[str, Any] | None' = None,
stats: 'bool | dict[str, Any]' = True,
n_bins: 'int | None | Literal[False]' = None,
bin_counts_col: 'str | None' = None,
facet_col: 'str | None' = None,
**kwargs: 'Any'
) → Figure
Scatter plot colored by density using plotly backend.
This function uses binning as implemented in bin_df_cols() to reduce the number of points plotted which enables plotting millions of data points and reduced file size for interactive plots. All outlier points will be plotted as is but overlapping points (tolerance for overlap determined by n_bins) will be merged into a single point with a new column bin_counts_col counting the number of points in that bin.
Args:
- x (str): x-values dataframe column name.
- y (str): y-values dataframe column name.
- df (pd.DataFrame): DataFrame with x and y columns.
- density ("kde" | "empirical" | None): Determines the method for calculating and displaying density. Defaults to None.
- log_density (bool | None): Whether to apply logarithmic scaling to density. If None, automatically set based on density range.
- identity_line (bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.
- best_fit_line (bool | dict[str, Any], optional): Whether to add a best-fit line. Defaults to True. Pass a dict to customize line properties.
- stats (bool | dict[str, Any], optional): Whether to display a text box with MAE and R^2. Defaults to True. Can be dict to pass kwargs to annotate_metrics(). E.g. stats=dict(loc="upper left", prefix="Title", font=dict(size=16)).
- n_bins (int | None | False, optional): Number of bins for histogram. If None, automatically enables binning mode if the number of datapoints exceeds 1000, else defaults to False (no binning). If int, uses that number of bins. If False, performs no binning. Defaults to None.
- bin_counts_col (str, optional): Column name for bin counts. Defaults to "Point Density". Will be used as color bar title.
- facet_col (str | None, optional): Column name to use for creating faceted subplots. If provided, the plot will be split into multiple subplots based on unique values in this column. Defaults to None.
- **kwargs: Passed to px.scatter().
Returns:
- go.Figure: The plot object.
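The point-merging idea behind the binning mode can be sketched with a plain 2D grid: nearby points fall into the same cell, and the cell's count plays the role of the bin_counts_col column. This is a simplified stand-in written for illustration; bin_points is hypothetical and does not reproduce bin_df_cols(), in particular not its handling of outliers.

```python
from collections import Counter


def bin_points(xs, ys, n_bins):
    """Count points per cell of an n_bins x n_bins grid spanning the data range."""

    def bin_index(val, low, high):
        if high == low:
            return 0
        # Clamp so the maximum value lands in the last bin instead of overflowing
        return min(int((val - low) / (high - low) * n_bins), n_bins - 1)

    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    return Counter(
        (bin_index(x, x_min, x_max), bin_index(y, y_min, y_max))
        for x, y in zip(xs, ys)
    )


bin_counts = bin_points(xs=[0, 0.01, 1], ys=[0, 0.02, 1], n_bins=10)
print(bin_counts)
```

A larger n_bins means smaller cells and therefore a tighter tolerance for treating points as overlapping.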
function scatter_with_err_bar
scatter_with_err_bar(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
xerr: 'ArrayLike | None' = None,
yerr: 'ArrayLike | None' = None,
ax: 'Axes | None' = None,
identity_line: 'bool | dict[str, Any]' = True,
best_fit_line: 'bool | dict[str, Any]' = True,
xlabel: 'str' = 'Actual',
ylabel: 'str' = 'Predicted',
title: 'str | None' = None,
**kwargs: 'Any'
) → Axes
Scatter plot with optional x- and/or y-error bars. Useful when passing model uncertainties as yerr=y_std for checking if uncertainty correlates with error, i.e. if points farther from the parity line have larger uncertainty.
Args:
x
(array | str): x-values or dataframe column namey
(array | str): y-values or dataframe column namedf
(pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.xerr
(array, optional): Horizontal error bars. Defaults to None.yerr
(array, optional): Vertical error bars. Defaults to None.ax
(Axes, optional): matplotlib Axes on which to plot. Defaults to None.identity_line
(bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.best_fit_line
(bool | dict[str, Any], optional): Whether to add a best-fit line. Defaults to True. Pass a dict to customize line properties.xlabel
(str, optional): x-axis label. Defaults to "Actual".ylabel
(str, optional): y-axis label. Defaults to "Predicted".title
(str, optional): Plot title. Defaults to None.**kwargs
: Additional keyword arguments to pass to ax.errorbar().
Returns:
plt.Axes
: matplotlib Axes object
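The check mentioned in the description, whether uncertainty correlates with error, can also be done numerically before plotting. The sketch below uses made-up numbers and a hand-written Pearson correlation; the `pearson` helper and the sample data are illustrative assumptions and not part of pymatviz.

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient of two equal-length sequences."""
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Made-up targets, predictions and predicted standard deviations.
y_true = [1.0, 2.0, 3.0, 4.0, 5.0]
y_pred = [1.1, 2.0, 3.5, 3.0, 5.2]
y_std = [0.1, 0.05, 0.4, 0.9, 0.2]

# Distance of each point from the parity line y = x.
abs_err = [abs(t - p) for t, p in zip(y_true, y_pred)]
corr = pearson(abs_err, y_std)
print(f"correlation between |error| and y_std: {corr:.2f}")
```

A correlation close to 1 suggests that points far from the parity line also carry large error bars, which is the pattern the plot is meant to reveal visually when `yerr=y_std` is passed.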
function density_hexbin
density_hexbin(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
weights: 'ArrayLike | None' = None,
identity_line: 'bool | dict[str, Any]' = True,
best_fit_line: 'bool | dict[str, Any]' = True,
xlabel: 'str' = 'Actual',
ylabel: 'str' = 'Predicted',
cbar_label: 'str | None' = 'Density',
cbar_coords: 'tuple[float, float, float, float]' = (0.95, 0.03, 0.03, 0.7),
**kwargs: 'Any'
) → Axes
Hexagonal-grid scatter plot colored by point density or by density in third dimension passed as weights.
Args:
x
(array): x-values or dataframe column name.y
(array): y-values or dataframe column name.df
(pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.ax
(Axes, optional): matplotlib Axes on which to plot. Defaults to None.weights
(array, optional): If given, these values are accumulated in the bins. Otherwise, every point has value 1. Must be of the same length as x and y. Defaults to None.identity_line
(bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.best_fit_line
(bool | dict[str, Any], optional): Whether to add a best-fit line. Defaults to True. Pass a dict to customize line properties.xlabel
(str, optional): x-axis label. Defaults to "Actual".ylabel
(str, optional): y-axis label. Defaults to "Predicted".cbar_label
(str, optional): Color bar label. Defaults to "Density".cbar_coords
(tuple[float, float, float, float], optional): Color bar position and size: [x, y, width, height] anchored at lower left corner. Defaults to (0.95, 0.03, 0.03, 0.7).**kwargs
: Additional keyword arguments to pass to ax.hexbin().
Returns:
plt.Axes
: matplotlib Axes object
function density_scatter_with_hist
density_scatter_with_hist(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
cell: 'GridSpec | None' = None,
bins: 'int' = 100,
ax: 'Axes | None' = None,
**kwargs: 'Any'
) → Axes
Scatter plot colored (and optionally sorted) by density with histograms along each dimension.
function density_hexbin_with_hist
density_hexbin_with_hist(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
cell: 'GridSpec | None' = None,
bins: 'int' = 100,
**kwargs: 'Any'
) → Axes
Hexagonal-grid scatter plot colored by density or by a third dimension passed as color_by, with histograms along each dimension.
function residual_vs_actual
residual_vs_actual(
y_true: 'ArrayLike | str',
y_pred: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
xlabel: 'str' = 'Actual value',
ylabel: 'str' = 'Residual ($y_\\mathrm{true} - y_\\mathrm{pred}$)',
**kwargs: 'Any'
) → Axes
Plot targets on the x-axis vs residuals (y_err = y_true - y_pred) on the y-axis.
Args:
y_true
(array): Ground truth valuesy_pred
(array): Model predictionsdf
(pd.DataFrame, optional): DataFrame with y_true and y_pred columns. Defaults to None.ax
(Axes, optional): matplotlib Axes on which to plot. Defaults to None.xlabel
(str, optional): x-axis label. Defaults to "Actual value".ylabel
(str, optional): y-axis label. Defaults to 'Residual ($y_\mathrm{true} - y_\mathrm{pred}$)'.**kwargs
: Additional keyword arguments passed to plt.plot()
Returns:
plt.Axes
: matplotlib Axes object
module structure_viz.helpers
Helper functions for 2D and 3D plots of pymatgen structures with plotly.
Global Variables
- TYPE_CHECKING
- ELEM_COLORS_JMOL
- ELEM_COLORS_VESTA
- missing_covalent_radius
- NO_SYM_MSG
- UNIT_CELL_EDGES
function get_image_sites
get_image_sites(
site: 'PeriodicSite',
lattice: 'Lattice',
tol: 'float' = 0.02
) → ndarray
Get images for a given site in a lattice.
Images are sites that are integer translations of the given site that are within a tolerance of the unit cell edges.
Args:
site
(PeriodicSite): The site to get images for.lattice
(Lattice): The lattice to get images for.tol
(float): The tolerance for being on the unit cell edge. Defaults to 0.02.
Returns:
np.ndarray
: Coordinates of all image sites.
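The idea of an image site can be illustrated without pymatgen. The sketch below works on plain fractional coordinates and is an assumption about the general approach, not the library's code: the name `image_frac_coords` is hypothetical, and the real function takes a PeriodicSite and a Lattice and returns an array of coordinates.

```python
from itertools import product


def image_frac_coords(frac_coords, tol=0.02):
    """Return fractional coordinates of periodic images near the cell boundary.

    Hypothetical stand-in for the idea behind get_image_sites: try every
    integer translation in {-1, 0, 1}^3 and keep those that land within
    `tol` of the unit cell, i.e. inside [-tol, 1 + tol] along each axis.
    """
    images = []
    for shift in product((-1, 0, 1), repeat=3):
        if shift == (0, 0, 0):
            continue  # the zero shift is the original site, not an image
        moved = tuple(x + dx for x, dx in zip(frac_coords, shift))
        if all(-tol <= x <= 1 + tol for x in moved):
            images.append(moved)
    return images


print(len(image_frac_coords((0.0, 0.0, 0.0))))  # corner site
print(image_frac_coords((0.0, 0.5, 0.5)))       # face site
print(image_frac_coords((0.5, 0.5, 0.5)))       # interior site
```

In this sketch a corner site gets an image at every other corner of the cell, a face site gets one image on the opposite face, and an interior site gets none.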
function unit_cell_to_lines
unit_cell_to_lines(cell: 'ArrayLike') → tuple[ArrayLike, ArrayLike, ArrayLike]
Convert lattice vectors to plot lines.
Args:
cell
(np.array): Lattice vectors.
Returns:
tuple[np.array, np.array, np.array]
: Lines, z-indices that sort plot elements into out-of-plane layers, and lines used to plot the unit cell.
function get_elem_colors
get_elem_colors(
elem_colors: 'ElemColorScheme | dict[str, str]'
) → dict[str, str]
Get element colors based on the provided scheme or custom dictionary.
function get_atomic_radii
get_atomic_radii(
atomic_radii: 'float | dict[str, float] | None'
) → dict[str, float]
Get atomic radii based on the provided input.
function generate_site_label
generate_site_label(
site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]",
site_idx: 'int',
majority_species: 'Species'
) → str
Generate a label for a site based on the provided labeling scheme.
function get_subplot_title
get_subplot_title(
struct_i: 'Structure',
struct_key: 'Any',
idx: 'int',
subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None'
) → dict[str, Any]
Generate a subplot title based on the provided function or default logic.
function get_site_hover_text
get_site_hover_text(
site: 'PeriodicSite',
hover_text: 'SiteCoords | Callable[[PeriodicSite], str]',
majority_species: 'Species'
) → str
Generate hover text for a site based on the hover template.
function draw_site
draw_site(
fig: 'Figure',
site: 'PeriodicSite',
coords: 'ndarray',
site_idx: 'int',
site_labels: 'Any',
_elem_colors: 'dict[str, str]',
_atomic_radii: 'dict[str, float]',
atom_size: 'float',
scale: 'float',
site_kwargs: 'dict[str, Any]',
is_image: 'bool' = False,
is_3d: 'bool' = False,
row: 'int | None' = None,
col: 'int | None' = None,
scene: 'str | None' = None,
hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
**kwargs: 'Any'
) → None
Add a site (regular or image) to the plot.
function get_structures
get_structures(
struct: 'Structure | Sequence[Structure] | Series | dict[Any, Structure]'
) → dict[Any, Structure]
Convert various input types to a dictionary of structures.
function draw_unit_cell
draw_unit_cell(
fig: 'Figure',
structure: 'Structure',
unit_cell_kwargs: 'dict[str, Any]',
is_3d: 'bool' = True,
row: 'int | None' = None,
col: 'int | None' = None,
scene: 'str | None' = None
) → Figure
Draw the unit cell of a structure in a 2D or 3D Plotly figure.
function draw_vector
draw_vector(
fig: 'Figure',
start: 'ndarray',
vector: 'ndarray',
is_3d: 'bool' = False,
arrow_kwargs: 'dict[str, Any] | None' = None,
**kwargs: 'Any'
) → None
Add an arrow to represent a vector quantity on a Plotly figure.
This function adds an arrow to a 2D or 3D Plotly figure to represent a vector quantity. In 3D, it uses a cone for the arrowhead and a line for the shaft. In 2D, it uses a scatter plot with an arrow marker.
Args:
fig
(go.Figure): The Plotly figure to add the arrow to.start
(np.ndarray): The starting point of the arrow (shape: (3,) for 3D, (2,) for 2D).vector
(np.ndarray): The vector to be represented by the arrow (shape: (3,) for 3D, (2,) for 2D).is_3d
(bool, optional): Whether to add a 3D arrow. Defaults to False.arrow_kwargs
(dict[str, Any] | None, optional): Additional keyword arguments for arrow customization. Supported keys: - color (str): Color of the arrow. - width (float): Width of the arrow shaft. - arrow_head_length (float): Length of the arrowhead (3D only). - arrow_head_angle (float): Angle of the arrowhead in degrees (3D only). - scale (float): Scaling factor for the vector length.**kwargs
: Additional keyword arguments passed to the Plotly trace.
Note:
For 3D arrows, this function adds two traces to the figure: a cone for the arrowhead and a line for the shaft. For 2D arrows, it adds a single scatter trace with an arrow marker.
function get_first_matching_site_prop
get_first_matching_site_prop(
structures: 'Sequence[Structure]',
prop_keys: 'Sequence[str]',
warn_if_none: 'bool' = True,
filter_callback: 'Callable[[str, Any], bool] | None' = None
) → str | None
Find the first property key that exists in any of the passed structures' properties or site properties. Will look in site.properties first, then structure.properties.
Args:
structures
(Sequence[Structure]): pymatgen Structures to check.prop_keys
(Sequence[str]): Property keys to look for.warn_if_none
(bool, optional): Whether to warn if no matching property is found.filter_callback
(Callable[[str, Any], bool] | None, optional): A function that takes the property key and value, and returns True if the property should be considered a match. If None, all properties are considered matches.
Returns:
str | None
: The first matching property key found, or None if no match is found.
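The lookup order described above can be sketched with toy stand-ins for pymatgen objects. `ToySite`, `ToyStructure`, and `first_matching_prop` are hypothetical names used only for illustration. The loop order (keys outermost, then structures) is an assumption, since the docstring only states that site properties are checked before structure properties.

```python
from dataclasses import dataclass, field


@dataclass
class ToySite:
    """Minimal stand-in for a site with a properties dict."""
    properties: dict = field(default_factory=dict)


@dataclass
class ToyStructure:
    """Minimal stand-in for a structure with sites and its own properties."""
    sites: list = field(default_factory=list)
    properties: dict = field(default_factory=dict)


def first_matching_prop(structures, prop_keys, filter_callback=None):
    """Return the first key in prop_keys found on any site or structure."""
    def accepted(key, value):
        return filter_callback is None or filter_callback(key, value)

    for key in prop_keys:
        for struct in structures:
            # Site properties are checked before structure properties.
            for site in struct.sites:
                if key in site.properties and accepted(key, site.properties[key]):
                    return key
            if key in struct.properties and accepted(key, struct.properties[key]):
                return key
    return None


structs = [
    ToyStructure(
        sites=[ToySite({"magmom": 1.0})],
        properties={"force": [[0.0, 0.0, 0.0]]},
    )
]
print(first_matching_prop(structs, ("force", "magmom")))
print(first_matching_prop(structs, ("force", "magmom"), lambda key, val: key != "force"))
print(first_matching_prop(structs, ("charge",)))
```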
function draw_bonds
draw_bonds(
fig: 'Figure',
structure: 'Structure',
nn: 'NearNeighbors',
is_3d: 'bool' = True,
bond_kwargs: 'dict[str, Any] | None' = None,
row: 'int | None' = None,
col: 'int | None' = None,
scene: 'str | None' = None,
visible_image_atoms: 'set[tuple[float, float, float]] | None' = None
) → None
Draw bonds between atoms in the structure.
module structure_viz
2D and 3D plots of Structures.
module structure_viz.mpl
2D plots of pymatgen structures with matplotlib.
structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib.
Global Variables
- TYPE_CHECKING
- ELEM_COLORS_JMOL
- ELEM_COLORS_VESTA
- NO_SYM_MSG
- MISSING_COVALENT_RADIUS
function structure_2d
structure_2d(
struct: 'Structure | Sequence[Structure]',
ax: 'Axes | None' = None,
rotation: 'str' = '10x,8y,3z',
atomic_radii: 'float | dict[str, float] | None' = None,
elem_colors: 'ElemColorScheme | dict[str, str | ColorType]' = Jmol,
scale: 'float' = 1,
show_unit_cell: 'bool' = True,
show_bonds: 'bool | NearNeighbors' = False,
site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
label_kwargs: 'dict[str, Any] | None' = None,
bond_kwargs: 'dict[str, Any] | None' = None,
standardize_struct: 'bool | None' = None,
axis: 'bool | str' = 'off',
n_cols: 'int' = 4,
subplot_kwargs: 'dict[str, Any] | None' = None,
subplot_title: 'Callable[[Structure, str | int], str] | None' = None
) → Axes | tuple[Figure, ndarray[Axes]]
Plot pymatgen structures in 2D with matplotlib.
structure_2d is not deprecated but structure_(2d|3d)_plotly() have more features and are recommended replacements.
Inspired by ASE's ase.visualize.plot.plot_atoms() (https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib). pymatviz aims to give similar output to ASE but supports disordered structures and avoids the conversion hassle of AseAtomsAdaptor.get_atoms(pmg_struct).
For example, these two snippets should give very similar output:
from pymatgen.ext.matproj import MPRester
mp_19017 = MPRester().get_structure_by_material_id("mp-19017")
# ASE
from ase.visualize.plot import plot_atoms
from pymatgen.io.ase import AseAtomsAdaptor
plot_atoms(AseAtomsAdaptor().get_atoms(mp_19017), rotation="10x,8y,3z", radii=0.5)
# pymatviz
from pymatviz import structure_2d
structure_2d(mp_19017)
Multiple structures in single figure example:
from pymatgen.ext.matproj import MPRester
from pymatviz import structure_2d
structures = {
(mp_id := f"mp-{idx}"): MPRester().get_structure_by_material_id(mp_id)
for idx in range(1, 5)
}
structure_2d(structures)
Args:
struct
(Structure): Must be pymatgen instance.ax
(plt.Axes, optional): Matplotlib axes on which to plot. Defaults to None.rotation
(str, optional): Euler angles in degrees in the form '10x,20y,30z' describing angle at which to view structure. Defaults to "10x,8y,3z".atomic_radii
(float | dict[str, float], optional): Either a scaling factor for default radii or map from element symbol to atomic radii. Defaults to covalent radii.elem_colors
(dict[str, str | list[float]], optional): Map from element symbols to colors, either a named color (str) or rgb(a) values like (0.2, 0.3, 0.6). Defaults to JMol colors (https://jmol.sourceforge.net/jscolors).scale
(float, optional): Scaling of the plotted atoms and lines. Defaults to 1.show_unit_cell
(bool, optional): Whether to plot unit cell. Defaults to True.show_bonds
(bool | NearNeighbors, optional): Whether to plot bonds. If True, use pymatgen.analysis.local_env.CrystalNN to infer the structure's connectivity. If False, don't plot bonds. If a subclass of pymatgen.analysis.local_env.NearNeighbors, use that to determine connectivity. Options include VoronoiNN, MinimumDistanceNN, OpenBabelNN, CovalentBondNN, etc. Defaults to False.site_labels
("symbol" | "species" | False | dict[str, str] | Sequence): How to annotate lattice sites. If True, labels are element species (symbol + oxidation state). If a dict, should map species strings (or element symbols but looks for species string first) to labels. If a list, must be same length as the number of sites in the crystal. If a string, must be "symbol" or "species". "symbol" hides the oxidation state, "species" shows it (equivalent to True). Defaults to "species".label_kwargs
(dict, optional): Keyword arguments for matplotlib.text.Text like {"fontsize": 14}. Defaults to None.bond_kwargs
(dict, optional): Keyword arguments for the matplotlib.path.Path class used to plot chemical bonds. Allowed are edgecolor, facecolor, color, linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle. Defaults to None.standardize_struct
(bool, optional): Whether to standardize the structure using SpacegroupAnalyzer(struct).get_conventional_standard_structure() before plotting. Defaults to False unless any fractional coordinates are negative, i.e. any crystal sites are outside the unit cell. Set this to False to disable this behavior which speeds up plotting for many structures.axis
(bool | str, optional): Whether/how to show plot axes. Defaults to "off". See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.axis for details.n_cols
(int, optional): Number of columns for subplots. Defaults to 4.subplot_kwargs
(dict, optional): Unused if only a single structure is passed. Keyword arguments for plt.subplots(). Defaults to None. Use this to specify figsize and other figure customization.subplot_title
(Callable[[Structure, str | int], str], optional): Should return subplot titles. Receives the structure and its key or index when passed as a dict or pandas.Series. Defaults to None in which case the title is the structure's material id if available, otherwise its formula and space group.
Raises:
ValueError
: On invalid site_labels.
Returns:
plt.Axes | tuple[plt.Figure, np.ndarray[plt.Axes]]
: Axes instance with plotted structure. If multiple structures are passed, returns both the parent Figure and its Axes.
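The `rotation` argument packs one angle per axis into a string such as '10x,8y,3z'. The sketch below shows one way such a string could be split into (angle, axis) pairs. `parse_rotation` is a hypothetical helper written for illustration; the source does not show how pymatviz parses the string internally.

```python
import re


def parse_rotation(rotation):
    """Split a rotation string like '10x,8y,3z' into (degrees, axis) pairs."""
    pairs = []
    for part in rotation.split(","):
        part = part.strip()
        if not part:
            continue  # an empty string means no rotation
        match = re.fullmatch(r"(-?\d+(?:\.\d+)?)([xyz])", part)
        if match is None:
            raise ValueError(f"invalid rotation component: {part!r}")
        pairs.append((float(match.group(1)), match.group(2)))
    return pairs


print(parse_rotation("10x,8y,3z"))
print(parse_rotation("-45z"))
```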
module structure_viz.plotly
Create interactive hoverable 2D and 3D plots of pymatgen structures with plotly.
Global Variables
- TYPE_CHECKING
- NO_SYM_MSG
function structure_2d_plotly
structure_2d_plotly(
struct: 'Structure | Sequence[Structure]',
rotation: 'str' = '10x,8y,3z',
atomic_radii: 'float | dict[str, float] | None' = None,
atom_size: 'float' = 40,
elem_colors: 'ElemColorScheme | dict[str, str]' = Jmol,
scale: 'float' = 1,
show_unit_cell: 'bool | dict[str, Any]' = True,
show_sites: 'bool | dict[str, Any]' = True,
show_image_sites: 'bool | dict[str, Any]' = True,
show_bonds: 'bool | NearNeighbors' = False,
site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
standardize_struct: 'bool | None' = None,
n_cols: 'int' = 3,
subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None' = None,
show_site_vectors: 'str | Sequence[str]' = ('force', 'magmom'),
vector_kwargs: 'dict[str, dict[str, Any]] | None' = None,
hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
bond_kwargs: 'dict[str, Any] | None' = None
) → Figure
Plot pymatgen structures in 2D with Plotly.
Args:
struct
(Structure | Sequence[Structure]): Pymatgen Structure(s) to plot.rotation
(str, optional): Euler angles in degrees in the form '10x,20y,30z' describing angle at which to view structure. Defaults to "10x,8y,3z".atomic_radii
(float | dict[str, float], optional): Either a scaling factor for default radii or map from element symbol to atomic radii. Defaults to None.atom_size
(float, optional): Scaling factor for atom sizes. Defaults to 40.elem_colors
(ElemColorScheme | dict[str, str], optional): Element color scheme or custom color map. Defaults to ElemColorScheme.jmol.scale
(float, optional): Scaling of the plotted atoms and lines. Defaults to 1.show_unit_cell
(bool | dict[str, Any], optional): Whether to plot unit cell. If a dict, will be used to customize unit cell appearance. The dict should have a "node"/"edge" key to customize node/edge appearance. Defaults to True.show_sites
(bool | dict[str, Any], optional): Whether to plot atomic sites. If a dict, will be used to customize site marker appearance. Defaults to True.show_image_sites
(bool | dict[str, Any], optional): Whether to show image sites on unit cell edges and surfaces. If a dict, will be used to customize how image sites are rendered. Defaults to True.show_bonds
(bool | NearNeighbors, optional): Whether to draw bonds between sites. If True, uses CrystalNN to determine nearest neighbors. If a NearNeighbors object, uses that to determine nearest neighbors. Defaults to False (since still experimental).site_labels
("symbol" | "species" | dict[str, str] | Sequence): How to annotate lattice sites. Defaults to "species".standardize_struct
(bool, optional): Whether to standardize the structure. Defaults to None.n_cols
(int, optional): Number of columns for subplots. Defaults to 3.subplot_title
(Callable[[Structure, str | int], str | dict] | False, optional): Function to generate subplot titles. Defaults to lambda struct_i, idx: f"{idx}. {struct_i.formula} (spg={spg_num})". Set to False to hide all subplot titles.show_site_vectors
(str | Sequence[str], optional): Whether to show vector site quantities such as forces or magnetic moments as arrow heads originating from each site. Pass the key (or sequence of keys) to look for in site properties. Defaults to ("force", "magmom"). If not found as a site property, will look for it in the structure properties as well and assume the key points at a (N, 3) array with N the number of sites. If multiple keys are provided, it plots the first key found in site properties or structure properties in any of the passed structures (if a dict of structures was passed). But it will only plot one vector per site and it will use the same key for all sites and across all structures.vector_kwargs
(dict[str, dict[str, Any]], optional): For customizing vector arrows. Keys are property names (e.g., "force", "magmom"), values are dictionaries of arrow customization options. Use key "scale" to adjust vector length.hover_text
(SiteCoords | Callable, optional): Controls the hover tooltip template. Can be SiteCoords.cartesian, SiteCoords.fractional, SiteCoords.cartesian_fractional, or a callable that takes a site and returns a custom string. Defaults to SiteCoords.cartesian_fractional.bond_kwargs
(dict[str, Any], optional): For customizing bond lines. Keys are line properties (e.g., "color", "width"), values are the corresponding values. Defaults to None.
Returns:
go.Figure
: Plotly figure with the plotted structure(s).
function structure_3d_plotly
structure_3d_plotly(
struct: 'Structure | Sequence[Structure]',
atomic_radii: 'float | dict[str, float] | None' = None,
atom_size: 'float' = 20,
elem_colors: 'ElemColorScheme | dict[str, str]' = Jmol,
scale: 'float' = 1,
show_unit_cell: 'bool | dict[str, Any]' = True,
show_sites: 'bool | dict[str, Any]' = True,
show_image_sites: 'bool' = True,
show_bonds: 'bool | NearNeighbors' = False,
site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
standardize_struct: 'bool | None' = None,
n_cols: 'int' = 3,
subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None | Literal[False]' = None,
show_site_vectors: 'str | Sequence[str]' = ('force', 'magmom'),
vector_kwargs: 'dict[str, dict[str, Any]] | None' = None,
hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
bond_kwargs: 'dict[str, Any] | None' = None
) → Figure
Plot pymatgen structures in 3D with Plotly.
Args:
struct
(Structure | Sequence[Structure]): Pymatgen Structure(s) to plot.atomic_radii
(float | dict[str, float], optional): Either a scaling factor for default radii or map from element symbol to atomic radii. Defaults to None.atom_size
(float, optional): Scaling factor for atom sizes. Defaults to 20.elem_colors
(ElemColorScheme | dict[str, str], optional): Element color scheme or custom color map. Defaults to ElemColorScheme.jmol.scale
(float, optional): Scaling of the plotted atoms and lines. Defaults to 1.show_unit_cell
(bool | dict[str, Any], optional): Whether to plot unit cell. If a dict, will be used to customize unit cell appearance. The dict should have a "node"/"edge" key to customize node/edge appearance. Defaults to True.show_sites
(bool | dict[str, Any], optional): Whether to plot atomic sites. If a dict, will be used to customize site marker appearance. Defaults to True.show_image_sites
(bool | dict[str, Any], optional): Whether to show image sites on unit cell edges and surfaces. If a dict, will be used to customize how image sites are rendered. Defaults to True.show_bonds
(bool | NearNeighbors, optional): Whether to draw bonds between sites. If True, uses CrystalNN to determine nearest neighbors. If a NearNeighbors object, uses that to determine nearest neighbors. Defaults to False (since still experimental).site_labels
("symbol" | "species" | dict[str, str] | Sequence): How to annotate lattice sites. Defaults to "species".standardize_struct
(bool, optional): Whether to standardize the structure. Defaults to None.n_cols
(int, optional): Number of columns for subplots. Defaults to 3.subplot_title
(Callable[[Structure, str | int], str | dict] | False, optional): Function to generate subplot titles. Defaults to lambda struct_i, idx: f"{idx}. {struct_i.formula} (spg={spg_num})". Set to False to hide all subplot titles.show_site_vectors
(str | Sequence[str], optional): Whether to show vector site quantities such as forces or magnetic moments as arrow heads originating from each site. Pass the key (or sequence of keys) to look for in site properties. Defaults to ("force", "magmom"). If not found as a site property, will look for it in the structure properties as well and assume the key points at a (N, 3) array with N the number of sites. If multiple keys are provided, it plots the first key found in site properties or structure properties in any of the passed structures (if a dict of structures was passed). But it will only plot one vector per site and it will use the same key for all sites and across all structures.vector_kwargs
(dict[str, dict[str, Any]], optional): For customizing vector arrows. Keys are property names (e.g., "force", "magmom"), values are dictionaries of arrow customization options. Use key "scale" to adjust vector length.hover_text
(SiteCoords | Callable, optional): Controls the hover tooltip template. Can be SiteCoords.cartesian, SiteCoords.fractional, SiteCoords.cartesian_fractional, or a callable that takes a site and returns a custom string. Defaults to SiteCoords.cartesian_fractional.bond_kwargs
(dict[str, Any], optional): For customizing bond lines. Keys are line properties (e.g., "color", "width"), values are the corresponding values. Defaults to None.
Returns:
go.Figure
: Plotly figure with the plotted 3D structure(s).
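As a rough illustration of the nested shape expected by `vector_kwargs`, the fragment below builds one options dict per site property. Only the "scale" key is named in the parameter description. The "color" key is borrowed from the `arrow_kwargs` keys listed for `draw_vector`, and it is an assumption that it is forwarded unchanged. The call itself is left commented out because it needs pymatgen, pymatviz, and a `struct` object that is not defined here.

```python
# Outer keys are property names that may appear in show_site_vectors.
# Inner dicts hold arrow options for that property.
vector_kwargs = {
    "force": {"color": "red", "scale": 0.5},    # shrink long force arrows
    "magmom": {"color": "blue", "scale": 2.0},  # enlarge short magnetic moments
}

# Hypothetical usage, assuming `struct` is a pymatgen Structure:
# fig = structure_3d_plotly(
#     struct,
#     show_site_vectors=("force", "magmom"),
#     vector_kwargs=vector_kwargs,
# )

for prop, options in vector_kwargs.items():
    print(prop, options)
```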
module sunburst
Hierarchical multi-level pie charts (i.e. sunbursts).
E.g. for crystal symmetry distributions.
Global Variables
- TYPE_CHECKING
function spacegroup_sunburst
spacegroup_sunburst(
data: 'Sequence[int | str] | Series',
show_counts: "Literal['value', 'percent', False]" = False,
**kwargs: 'Any'
) → Figure
Generate a sunburst plot with crystal systems as the inner ring for a list of international space group numbers.
Hint: To hide very small labels, set a uniformtext minsize and mode='hide'. fig.update_layout(uniformtext=dict(minsize=9, mode="hide"))
Args:
data
(list[int] | pd.Series): A sequence (list, tuple, pd.Series) of space group strings or numbers (from 1 - 230) or pymatgen structures.show_counts
("value" | "percent" | False): Whether to display values below each label on the sunburst.color_discrete_sequence
(list[str]): A list of 7 colors, one for each crystal system. Defaults to plotly.express.colors.qualitative.G10.**kwargs
: Additional keyword arguments passed to plotly.express.sunburst.
Returns:
Figure
: The Plotly figure.
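The inner ring of the sunburst groups space groups by crystal system. The grouping follows the standard ranges of international space group numbers, and the sketch below counts a few made-up entries per system. `crystal_system` and the sample list are illustrative and not taken from pymatviz.

```python
from collections import Counter

# Standard ranges of international space group numbers per crystal system.
CRYSTAL_SYSTEM_RANGES = [
    ("triclinic", 1, 2),
    ("monoclinic", 3, 15),
    ("orthorhombic", 16, 74),
    ("tetragonal", 75, 142),
    ("trigonal", 143, 167),
    ("hexagonal", 168, 194),
    ("cubic", 195, 230),
]


def crystal_system(spg_num):
    """Map a space group number (1 to 230) to its crystal system."""
    for name, low, high in CRYSTAL_SYSTEM_RANGES:
        if low <= spg_num <= high:
            return name
    raise ValueError(f"space group number out of range: {spg_num}")


# Made-up sample of space group numbers.
spacegroups = [225, 225, 62, 194, 14, 2, 227]
inner_ring = Counter(crystal_system(num) for num in spacegroups)
print(inner_ring)
```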
module templates
Define custom pymatviz templates (default styles) for plotly and matplotlib.
Global Variables
- TYPE_CHECKING
- PKG_NAME
- axis_template
- white_axis_template
- common_layout
- dark_axis_template
function set_plotly_template
set_plotly_template(
template: "Literal['pymatviz_white', 'pymatviz_dark'] | str | Template"
) → None
Set the default plotly express and graph objects template.
Args:
template
: Usually "pymatviz_white" or "pymatviz_dark" but any plotly.io.template name or the object itself is valid.
Raises:
ValueError
: If the template is not recognized.
module typing
Typing related: TypeAlias, generic types and so on.
Global Variables
- EXCLUDED_ATTRIBUTES
- TYPE_CHECKING
- BACKENDS
- MATPLOTLIB
- PLOTLY
- VALID_COLOR_ELEM_STRATEGIES
- VALID_FIG_TYPES
- VALID_FIG_NAMES
function runtime_checkable
runtime_checkable(cls)
Mark a protocol class as a runtime protocol.
Such protocol can be used with isinstance() and issubclass(). Raise TypeError if applied to a non-protocol class. This allows a simple-minded structural check very similar to one trick ponies in collections.abc such as Iterable.
For example:
@runtime_checkable
class Closable(Protocol):
    def close(self): ...
assert isinstance(open('/some/file'), Closable)
Warning: this will check only the presence of the required methods, not their type signatures!
function cast
cast(typ, val)
Cast a value to a type.
This returns the value unchanged. To the type checker this signals that the return value has the designated type, but at runtime we intentionally don't check anything (we want this to be as fast as possible).
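A short demonstration of the runtime behavior described above: `cast` hands back the very same object and performs no conversion or validation. The variable names are arbitrary.

```python
from typing import cast

raw: object = {"formula": "Fe2O3"}

# To a type checker, `record` is now a dict. At runtime nothing happens.
record = cast(dict, raw)
print(record is raw)

# No check is performed, so a "wrong" cast silently passes through.
not_an_int = cast(int, "42")
print(type(not_an_int).__name__)
```

Because nothing is checked, `cast` should only be used where the caller already knows the type to be correct.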
function assert_type
assert_type(val, typ)
Ask a static type checker to confirm that the value is of the given type.
At runtime this does nothing: it returns the first argument unchanged with no checks or side effects, no matter the actual type of the argument.
When a static type checker encounters a call to assert_type(), it emits an error if the value is not of the specified type:
def greet(name: str) -> None:
    assert_type(name, str)  # OK
    assert_type(name, int)  # type checker error
function get_type_hints
get_type_hints(obj, globalns=None, localns=None, include_extras=False)
Return type hints for an object.
This is often the same as obj.__annotations__, but it handles forward references encoded as string literals and recursively replaces all 'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
The argument may be a module, class, method, or function. The annotations are returned as a dictionary. For classes, annotations include also inherited members.
TypeError is raised if the argument is not of a type that can contain annotations, and an empty dictionary is returned if no annotations are present.
BEWARE -- the behavior of globalns and localns is counterintuitive (unless you are familiar with how eval() and exec() work). The search order is locals first, then globals.
- If no dict arguments are passed, an attempt is made to use the globals from obj (or the respective module's globals for classes), and these are also used as the locals. If the object does not appear to have globals, an empty dictionary is used. For classes, the search order is globals first then locals.
- If one dict argument is passed, it is used for both globals and locals.
- If two dict arguments are passed, they specify globals and locals, respectively.
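The difference between raw annotations and resolved type hints can be seen in a small example. The function `scale` and its annotations are made up for illustration.

```python
from typing import Annotated, get_type_hints


def scale(x: "float", factor: Annotated[float, "unitless"] = 2.0) -> "float":
    """Toy function with string (forward-reference style) annotations."""
    return x * factor


# Raw annotations keep the string literal as written.
print(scale.__annotations__["x"])

# get_type_hints evaluates the strings and strips Annotated by default.
plain = get_type_hints(scale)
print(plain)

# include_extras=True keeps the Annotated metadata.
with_extras = get_type_hints(scale, include_extras=True)
print(with_extras["factor"])
```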
function get_origin
get_origin(tp)
Get the unsubscripted version of a type.
This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar, Annotated, and others. Return None for unsupported types.
Examples:
>>> P = ParamSpec('P')
>>> assert get_origin(Literal[42]) is Literal
>>> assert get_origin(int) is None
>>> assert get_origin(ClassVar[int]) is ClassVar
>>> assert get_origin(Generic) is Generic
>>> assert get_origin(Generic[T]) is Generic
>>> assert get_origin(Union[T, int]) is Union
>>> assert get_origin(List[Tuple[T, T]][int]) is list
>>> assert get_origin(P.args) is P
function get_args
get_args(tp)
Get type arguments with all substitutions performed.
For unions, basic simplifications used by Union constructor are performed.
Examples:
>>> T = TypeVar('T')
>>> assert get_args(Dict[str, int]) == (str, int)
>>> assert get_args(int) == ()
>>> assert get_args(Union[int, Union[T, int], str][int]) == (int, str)
>>> assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
>>> assert get_args(Callable[[], T][int]) == ([], int)
function is_typeddict
is_typeddict(tp)
Check if an annotation is a TypedDict class.
For example:
>>> from typing import TypedDict
>>> class Film(TypedDict):
...     title: str
...     year: int
...
>>> is_typeddict(Film)
True
>>> is_typeddict(dict)
False
function assert_never
assert_never(arg: Never) → Never
Statically assert that a line of code is unreachable.
Example::
def int_or_str(arg: int | str) -> None:
    match arg:
        case int():
            print("It's an int")
        case str():
            print("It's a str")
        case _:
            assert_never(arg)
If a type checker finds that a call to assert_never() is reachable, it will emit an error.
At runtime, this throws an exception when called.
function no_type_check
no_type_check(arg)
Decorator to indicate that annotations are not type hints.
The argument must be a class or function; if it is a class, it applies recursively to all methods and classes defined in that class (but not to methods defined in its superclasses or subclasses).
This mutates the function(s) or class(es) in place.
function no_type_check_decorator
no_type_check_decorator(decorator)
Decorator to give another decorator the @no_type_check effect.
This wraps the decorator with something that wraps the decorated function in @no_type_check.
function overload
overload(func)
Decorator for overloaded functions/methods.
In a stub file, place two or more stub definitions for the same function in a row, each decorated with @overload.
For example:
@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...
In a non-stub file (i.e. a regular .py file), do the same but follow it with an implementation. The implementation should not be decorated with @overload:
@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...
def utf8(value):
    ...  # implementation goes here
The overloads for a function can be retrieved at runtime using the get_overloads() function.
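As a runnable illustration of the non-stub pattern described above, the following sketch fills in one possible implementation of utf8. The body of the implementation is an assumption made for the example; the docstring only specifies the signatures.

```python
from typing import Optional, Union, overload


@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...


def utf8(value: Union[None, bytes, str]) -> Optional[bytes]:
    # Single runtime implementation; the @overload stubs above only
    # inform static type checkers and are never executed as utf8 itself.
    if value is None:
        return None
    if isinstance(value, bytes):
        return value
    return value.encode("utf-8")
```

Only the last, undecorated definition is bound to the name utf8 at runtime, which is why it must come after the overloads.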
function get_overloads
get_overloads(func)
Return all defined overloads for func as a sequence.
function clear_overloads
clear_overloads()
Clear all overloads in the registry.
function final
final(f)
Decorator to indicate final methods and final classes.
Use this decorator to indicate to type checkers that the decorated method cannot be overridden, and decorated class cannot be subclassed.
For example:
class Base:
    @final
    def done(self) -> None:
        ...

class Sub(Base):
    def done(self) -> None:  # Error reported by type checker
        ...

@final
class Leaf:
    ...

class Other(Leaf):  # Error reported by type checker
    ...

There is no runtime checking of these properties. The decorator attempts to set the __final__ attribute to True on the decorated object to allow runtime introspection.
function NamedTuple
NamedTuple(typename, fields=None, **kwargs)
Typed version of namedtuple.
Usage:
class Employee(NamedTuple):
    name: str
    id: int
This is equivalent to:
Employee = collections.namedtuple('Employee', ['name', 'id'])
The resulting class has an extra __annotations__ attribute, giving a dict that maps field names to types. (The field names are also in the _fields attribute, which is part of the namedtuple API.) An alternative equivalent functional syntax is also accepted:
Employee = NamedTuple('Employee', [('name', str), ('id', int)])
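The two syntaxes above can be compared side by side in a short sketch. The name EmployeeAlt and the sample values are made up for the example.

```python
from typing import NamedTuple


class Employee(NamedTuple):
    name: str
    id: int


# Functional syntax; EmployeeAlt is a hypothetical name used only here.
EmployeeAlt = NamedTuple("EmployeeAlt", [("name", str), ("id", int)])

emp = Employee(name="Ada", id=1)
fields = Employee._fields                # part of the namedtuple API
annotations = Employee.__annotations__   # maps field names to types
```

Both classes behave like ordinary tuples, so emp can be unpacked or compared against ("Ada", 1).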
function TypedDict
TypedDict(typename, fields=None, total=True, **kwargs)
A simple typed namespace. At runtime it is equivalent to a plain dict.
TypedDict creates a dictionary type such that a type checker will expect all instances to have a certain set of keys, where each key is associated with a value of a consistent type. This expectation is not checked at runtime.
Usage:
>>> class Point2D(TypedDict):
...     x: int
...     y: int
...     label: str
...
>>> a: Point2D = {'x': 1, 'y': 2, 'label': 'good'}  # OK
>>> b: Point2D = {'z': 3, 'label': 'bad'}           # Fails type check
>>> Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
True
The type info can be accessed via the Point2D.__annotations__ dict, and the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. TypedDict supports an additional equivalent form:
Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})
By default, all keys must be present in a TypedDict. It is possible to override this by specifying totality:
class Point2D(TypedDict, total=False):
    x: int
    y: int
This means that a Point2D TypedDict can have any of the keys omitted. A type checker is only expected to support a literal False or True as the value of the total argument. True is the default, and makes all items defined in the class body be required.
The Required and NotRequired special forms can also be used to mark individual keys as being required or not required:
class Point2D(TypedDict):
    x: int               # the "x" key must always be present (Required is the default)
    y: NotRequired[int]  # the "y" key can be omitted
See PEP 655 for more details on Required and NotRequired.
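A small sketch of how totality shows up at runtime. PartialPoint2D is a hypothetical name chosen so both variants can coexist in one module. The example relies on total=False rather than NotRequired so that it runs on Python 3.9 and later.

```python
from typing import TypedDict


class Point2D(TypedDict):
    x: int
    y: int


class PartialPoint2D(TypedDict, total=False):
    x: int
    y: int


# At runtime a TypedDict instance is a plain dict; nothing is validated.
point: Point2D = {"x": 1, "y": 2}
partial: PartialPoint2D = {"x": 1}  # "y" may be omitted because total=False
```

The required and optional key sets are then available as Point2D.__required_keys__ and PartialPoint2D.__optional_keys__.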
function reveal_type
reveal_type(obj: ~T) → ~T
Ask a static type checker to reveal the inferred type of an expression.
When a static type checker encounters a call to reveal_type(), it will emit the inferred type of the argument:
x: int = 1
reveal_type(x)
Running a static type checker (e.g., mypy) on this example will produce output similar to 'Revealed type is "builtins.int"'.
At runtime, the function prints the runtime type of the argument and returns the argument unchanged.
function dataclass_transform
dataclass_transform(
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_specifiers: tuple[Union[type[Any], Callable[..., Any]], ...] = (),
**kwargs: Any
) → Callable[[~T], ~T]
Decorator to mark an object as providing dataclass-like behaviour.
The decorator can be applied to a function, class, or metaclass.
Example usage with a decorator function:
T = TypeVar("T")

@dataclass_transform()
def create_model(cls: type[T]) -> type[T]:
    ...
    return cls

@create_model
class CustomerModel:
    id: int
    name: str
On a base class:
@dataclass_transform()
class ModelBase: ...

class CustomerModel(ModelBase):
    id: int
    name: str
On a metaclass:
@dataclass_transform()
class ModelMeta(type): ...

class ModelBase(metaclass=ModelMeta): ...

class CustomerModel(ModelBase):
    id: int
    name: str
The CustomerModel classes defined above will be treated by type checkers similarly to classes created with @dataclasses.dataclass. For example, type checkers will assume these classes have __init__ methods that accept id and name.
The arguments to this decorator can be used to customize this behavior:
- eq_default indicates whether the eq parameter is assumed to be True or False if it is omitted by the caller.
- order_default indicates whether the order parameter is assumed to be True or False if it is omitted by the caller.
- kw_only_default indicates whether the kw_only parameter is assumed to be True or False if it is omitted by the caller.
- field_specifiers specifies a static list of supported classes or functions that describe fields, similar to dataclasses.field().
- Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.
At runtime, this decorator records its arguments in the __dataclass_transform__ attribute on the decorated object. It has no other runtime effect.
See PEP 681 for more details.
class Annotated
Add context-specific metadata to a type.
Example: Annotated[int, runtime_check.Unsigned] indicates to the hypothetical runtime_check module that this type is an unsigned int. Every other consumer of this type can ignore this metadata and treat this type as int.
The first argument to Annotated must be a valid type.
Details:
- It's an error to call Annotated with less than two arguments.
- Access the metadata via the __metadata__ attribute:
  assert Annotated[int, '$'].__metadata__ == ('$',)
- Nested Annotated types are flattened:
  assert Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3]
- Instantiating an annotated type is equivalent to instantiating the underlying type:
  assert Annotated[C, Ann1](5) == C(5)
- Annotated can be used as a generic type alias:
  Optimized: TypeAlias = Annotated[T, runtime.Optimize()]
  assert Optimized[int] == Annotated[int, runtime.Optimize()]
  OptimizedList: TypeAlias = Annotated[list[T], runtime.Optimize()]
  assert OptimizedList[int] == Annotated[list[int], runtime.Optimize()]
- Annotated cannot be used with an unpacked TypeVarTuple:
  Variadic: TypeAlias = Annotated[*Ts, Ann1]  # NOT valid
  This would be equivalent to:
  Annotated[T1, T2, T3, ..., Ann1]
  where T1, T2 etc. are TypeVars, which would be invalid, because only one type should be passed to Annotated.
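The metadata behaviour listed above can be checked at runtime. In this sketch, Unsigned is a hypothetical marker class standing in for runtime_check.Unsigned, and scale is an illustrative function.

```python
from typing import Annotated, get_args, get_type_hints


class Unsigned:
    """Hypothetical marker; stands in for runtime_check.Unsigned."""


UInt = Annotated[int, Unsigned]


def scale(value: UInt, factor: int) -> int:
    return value * factor


# Consumers that do not care about metadata see the plain type.
plain_hints = get_type_hints(scale)
# Consumers that do care must opt in with include_extras=True.
full_hints = get_type_hints(scale, include_extras=True)
args = get_args(UInt)
```

By default get_type_hints strips the metadata, which matches the statement that every other consumer can treat the type as int.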
class Any
Special type indicating an unconstrained type.
- Any is compatible with every type.
- Any is assumed to have all methods.
- All values are assumed to be instances of Any.
Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
class BinaryIO
Typed version of the return of open() in binary mode.
property closed
property mode
property name
method close
close() → None
method fileno
fileno() → int
method flush
flush() → None
method isatty
isatty() → bool
method read
read(n: int = -1) → ~AnyStr
method readable
readable() → bool
method readline
readline(limit: int = -1) → ~AnyStr
method readlines
readlines(hint: int = -1) → List[~AnyStr]
method seek
seek(offset: int, whence: int = 0) → int
method seekable
seekable() → bool
method tell
tell() → int
method truncate
truncate(size: int = None) → int
method writable
writable() → bool
method write
write(s: Union[bytes, bytearray]) → int
method writelines
writelines(lines: List[~AnyStr]) → None
class ForwardRef
Internal wrapper to hold a forward reference.
method __init__
__init__(arg, is_argument=True, module=None, is_class=False)
class Generic
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as:
class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.
This class can then be used as follows:
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default
class IO
Generic base class for TextIO and BinaryIO.
This is an abstract, generic version of the return of open().
NOTE: This does not distinguish between the different possible classes (text vs. binary, read vs. write vs. read/write, append-only, unbuffered). The TextIO and BinaryIO subclasses below capture the distinctions between text vs. binary, which is pervasive in the interface; however we currently do not offer a way to track the other distinctions in the type system.
property closed
property mode
property name
method close
close() → None
method fileno
fileno() → int
method flush
flush() → None
method isatty
isatty() → bool
method read
read(n: int = -1) → ~AnyStr
method readable
readable() → bool
method readline
readline(limit: int = -1) → ~AnyStr
method readlines
readlines(hint: int = -1) → List[~AnyStr]
method seek
seek(offset: int, whence: int = 0) → int
method seekable
seekable() → bool
method tell
tell() → int
method truncate
truncate(size: int = None) → int
method writable
writable() → bool
method write
write(s: ~AnyStr) → int
method writelines
writelines(lines: List[~AnyStr]) → None
class NamedTupleMeta
class NewType
NewType creates simple unique types with almost zero runtime overhead.
NewType(name, tp) is considered a subtype of tp by static type checkers. At runtime, NewType(name, tp) returns a dummy callable that simply returns its argument.
Usage:
UserId = NewType('UserId', int)

def name_by_id(user_id: UserId) -> str:
    ...

UserId('user')  # Fails type check

name_by_id(42)  # Fails type check
name_by_id(UserId(42))  # OK

num = UserId(5) + 1  # type: int
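A short runtime check of the "returns its argument" behaviour described above. The body of name_by_id is made up for the example.

```python
from typing import NewType

UserId = NewType("UserId", int)


def name_by_id(user_id: UserId) -> str:
    # Illustrative body; the docstring leaves it unspecified.
    return f"user-{user_id}"


uid = UserId(42)   # at runtime this is just the int 42
num = UserId(5) + 1  # arithmetic falls back to plain int
```

Because no wrapper object is created, type(uid) is int, which is where the near-zero runtime overhead comes from.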
method __init__
__init__(name, tp)
class ParamSpec
Parameter specification variable.
Usage:
P = ParamSpec('P')
Parameter specification variables exist primarily for the benefit of static type checkers. They are used to forward the parameter types of one callable to another callable, a pattern commonly found in higher order functions and decorators. They are only valid when used in Concatenate, or as the first argument to Callable, or as parameters for user-defined Generics. See class Generic for more information on generic types. An example for annotating a decorator:
T = TypeVar('T')
P = ParamSpec('P')

def add_logging(f: Callable[P, T]) -> Callable[P, T]:
    '''A type-safe decorator to add logging to a function.'''
    def inner(*args: P.args, **kwargs: P.kwargs) -> T:
        logging.info(f'{f.__name__} was called')
        return f(*args, **kwargs)
    return inner

@add_logging
def add_two(x: float, y: float) -> float:
    '''Add two numbers together.'''
    return x + y
Parameter specification variables can be introspected. e.g.:
P.__name__ == 'P'
Note that only parameter specification variables defined in global scope can be pickled.
method __init__
__init__(name, bound=None, covariant=False, contravariant=False)
property args
property kwargs
class ParamSpecArgs
The args for a ParamSpec object.
Given a ParamSpec object P, P.args is an instance of ParamSpecArgs.
ParamSpecArgs objects have a reference back to their ParamSpec:
P.args.__origin__ is P
This type is meant for runtime introspection and has no special meaning to static type checkers.
method __init__
__init__(origin)
class ParamSpecKwargs
The kwargs for a ParamSpec object.
Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs.
ParamSpecKwargs objects have a reference back to their ParamSpec:
P.kwargs.__origin__ is P
This type is meant for runtime introspection and has no special meaning to static type checkers.
method __init__
__init__(origin)
class Protocol
Base class for protocol classes.
Protocol classes are defined as:
class Proto(Protocol):
    def meth(self) -> int:
        ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example:
class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:
class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
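The "simple-minded" runtime check mentioned above can be seen in a small sketch. Classes D and E are hypothetical additions to the docstring's example: D has a method of the right name but the wrong return type, and E has no such method. Only the presence of meth is checked, so D still passes isinstance while E does not.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Proto(Protocol):
    def meth(self) -> int: ...


class C:
    def meth(self) -> int:
        return 0


class D:
    def meth(self) -> str:  # wrong return type, but the attribute exists
        return "zero"


class E:
    pass


c_matches = isinstance(C(), Proto)
d_matches = isinstance(D(), Proto)  # signature is not inspected at runtime
e_matches = isinstance(E(), Proto)
```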
class SupportsAbs
An ABC with one abstract method __abs__ that is covariant in its return type.
function _no_init_or_replace_init
_no_init_or_replace_init(*args, **kwargs)
class SupportsBytes
An ABC with one abstract method __bytes__.
function _no_init_or_replace_init
_no_init_or_replace_init(*args, **kwargs)
class SupportsComplex
An ABC with one abstract method __complex__.
function _no_init_or_replace_init
_no_init_or_replace_init(*args, **kwargs)
class SupportsFloat
An ABC with one abstract method __float__.
function _no_init_or_replace_init
_no_init_or_replace_init(*args, **kwargs)
class SupportsIndex
An ABC with one abstract method __index__.
function _no_init_or_replace_init
_no_init_or_replace_init(*args, **kwargs)
class SupportsInt
An ABC with one abstract method __int__.
function _no_init_or_replace_init
_no_init_or_replace_init(*args, **kwargs)
class SupportsRound
An ABC with one abstract method __round__ that is covariant in its return type.
function _no_init_or_replace_init
_no_init_or_replace_init(*args, **kwargs)
class TextIO
Typed version of the return of open() in text mode.
property buffer
property closed
property encoding
property errors
property line_buffering
property mode
property name
property newlines
method close
close() → None
method fileno
fileno() → int
method flush
flush() → None
method isatty
isatty() → bool
method read
read(n: int = -1) → ~AnyStr
method readable
readable() → bool
method readline
readline(limit: int = -1) → ~AnyStr
method readlines
readlines(hint: int = -1) → List[~AnyStr]
method seek
seek(offset: int, whence: int = 0) → int
method seekable
seekable() → bool
method tell
tell() → int
method truncate
truncate(size: int = None) → int
method writable
writable() → bool
method write
write(s: ~AnyStr) → int
method writelines
writelines(lines: List[~AnyStr]) → None
class TypeVar
Type variable.
Usage:
T = TypeVar('T')  # Can be anything
A = TypeVar('A', str, bytes)  # Must be str or bytes
Type variables exist primarily for the benefit of static type checkers. They serve as the parameters for generic types as well as for generic function definitions. See class Generic for more information on generic types. Generic functions work as follows:
def repeat(x: T, n: int) -> List[T]:
    '''Return a list containing n references to x.'''
    return [x]*n

def longest(x: A, y: A) -> A:
    '''Return the longest of two strings.'''
    return x if len(x) >= len(y) else y
The latter example's signature is essentially the overloading of (str, str) -> str and (bytes, bytes) -> bytes. Also note that if the arguments are instances of some subclass of str, the return type is still plain str.
At runtime, isinstance(x, T) and issubclass(C, T) will raise TypeError.
Type variables defined with covariant=True or contravariant=True can be used to declare covariant or contravariant generic types. See PEP 484 for more details. By default generic types are invariant in all type variables.
Type variables can be introspected. e.g.:
T.__name__ == 'T'
T.__constraints__ == ()
T.__covariant__ == False
T.__contravariant__ == False
A.__constraints__ == (str, bytes)
Note that only type variables defined in global scope can be pickled.
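The generic functions and the runtime restrictions described above can be exercised together. The helper isinstance_raises is a hypothetical name introduced only to capture the TypeError.

```python
from typing import List, TypeVar

T = TypeVar("T")              # can be anything
A = TypeVar("A", str, bytes)  # must be str or bytes


def repeat(x: T, n: int) -> List[T]:
    """Return a list containing n references to x."""
    return [x] * n


def longest(x: A, y: A) -> A:
    """Return the longer of two strings (or two bytes objects)."""
    return x if len(x) >= len(y) else y


def isinstance_raises() -> bool:
    # Type variables are not classes, so instance checks are rejected.
    try:
        isinstance(1, T)
    except TypeError:
        return True
    return False
```

The constraints on A matter only to a type checker; at runtime longest accepts anything that supports len().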
method __init__
__init__(name, *constraints, bound=None, covariant=False, contravariant=False)
class TypeVarTuple
Type variable tuple.
Usage:
Ts = TypeVarTuple('Ts') # Can be given any name
Just as a TypeVar (type variable) is a placeholder for a single type, a TypeVarTuple is a placeholder for an arbitrary number of types. For example, if we define a generic class using a TypeVarTuple:
class C(Generic[*Ts]): ...
Then we can parameterize that class with an arbitrary number of type arguments:
C[int]       # Fine
C[int, str]  # Also fine
C[()]        # Even this is fine
For more details, see PEP 646.
Note that only TypeVarTuples defined in global scope can be pickled.
method __init__
__init__(name)
class typing.io
Wrapper namespace for IO generic classes.
class typing.re
Wrapper namespace for re type aliases.
module uncertainty
Visualizations for assessing the quality of model uncertainty estimates.
Global Variables
- TYPE_CHECKING
function qq_gaussian
qq_gaussian(
y_true: 'ArrayLike | str',
y_pred: 'ArrayLike | str',
y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
identity_line: 'bool | dict[str, Any]' = True
) → Axes
Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array) or multiple (passed as dict) sets of uncertainty estimates for a single pair of ground truth targets y_true and model predictions y_pred.
Overconfidence relative to a Gaussian distribution is visualized as shaded areas below the parity line, underconfidence (oversized uncertainties) as shaded areas above the parity line.
The measure of calibration is how well the uncertainty percentiles conform to those of a normal distribution.
Inspired by https://git.io/JufOz. Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot
Args:
- y_true (array | str): Ground truth targets.
- y_pred (array | str): Model predictions.
- y_std (array | dict[str, array] | str | list[str]): Model uncertainties either as array(s) (single or dict with labels if you have multiple sources of uncertainty) or column names in df.
- df (pd.DataFrame, optional): DataFrame with y_true, y_pred and y_std columns.
- ax (Axes): matplotlib Axes on which to plot. Defaults to None.
- identity_line (bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.
Returns:
- plt.Axes: matplotlib Axes object
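To make the idea of comparing uncertainty percentiles against a normal distribution concrete, here is a standard-library sketch of a generic calibration curve. It is not pymatviz's implementation: the function name, the n_quantiles parameter and the sign convention of the z-scores are assumptions made for the example.

```python
from statistics import NormalDist


def calibration_curve_sketch(y_true, y_pred, y_std, n_quantiles=11):
    """Return (expected, observed) cumulative fractions of z-scored errors."""
    std_normal = NormalDist()
    # z-score each residual by its predicted standard deviation
    z_scores = [(t - p) / s for t, p, s in zip(y_true, y_pred, y_std)]
    cdf_vals = [std_normal.cdf(z) for z in z_scores]
    expected = [i / (n_quantiles - 1) for i in range(n_quantiles)]
    # fraction of samples whose Gaussian CDF value falls at or below each quantile
    observed = [sum(c <= q for c in cdf_vals) / len(cdf_vals) for q in expected]
    return expected, observed


# Synthetic targets placed exactly at Gaussian quantiles around a prediction of 0.
quantiles = [(2 * k + 1) / 20 for k in range(10)]
y_true = [NormalDist().inv_cdf(q) for q in quantiles]
y_pred = [0.0] * 10

expected, observed = calibration_curve_sketch(y_true, y_pred, [1.0] * 10)
# Halving the uncertainties makes the same errors look overconfident.
_, observed_over = calibration_curve_sketch(y_true, y_pred, [0.5] * 10)
```

With well-sized uncertainties the observed fractions follow the parity line; with undersized ones they depart from it.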
function get_err_decay
get_err_decay(
y_true: 'ArrayLike',
y_pred: 'ArrayLike',
n_rand: 'int' = 100
) → tuple[ArrayLike, ArrayLike]
Calculate the model's error curve as samples are excluded from the calculation based on their absolute error.
Use in combination with get_std_decay to see what the error drop curve would look like if model error and uncertainty were perfectly rank-correlated.
Args:
- y_true (array): ground truth targets
- y_pred (array): model predictions
- n_rand (int, optional): Number of randomly ordered sample exclusions over which to average to estimate dummy performance. Defaults to 100.
Returns:
- Tuple[array, array]: Drop-off in error as data points are excluded by absolute error and in random order, respectively.
function get_std_decay
get_std_decay(
y_true: 'ArrayLike',
y_pred: 'ArrayLike',
y_std: 'ArrayLike'
) → ArrayLike
Calculate the drop in model error as samples are excluded from the calculation based on the model's uncertainty.
For models able to estimate their own uncertainty well, meaning predictions of larger error are associated with larger uncertainty, the error curve should fall off sharply at first as the highest-error points are discarded and slowly towards the end where only small-error samples with little uncertainty remain.
Note that even perfect model uncertainties would not mean this error drop curve coincides exactly with the one returned by get_err_decay as in some cases the model may have made an accurate prediction purely by chance in which case the error is small yet a good uncertainty estimate would still be large, leading the same sample to be excluded at different x-axis locations and thus the get_std_decay curve to lie higher.
Args:
- y_true (array): ground truth targets
- y_pred (array): model predictions
- y_std (array): model's predicted uncertainties
Returns:
- array: Error decay as data points are excluded by order of largest to smallest model uncertainties.
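The relationship between the two decay curves can be illustrated with a pure-Python sketch. The function names and toy data are hypothetical, and the sketch omits the random-exclusion baseline that get_err_decay also returns. It is meant to show the idea, not to reproduce pymatviz's code.

```python
from itertools import accumulate


def err_decay_sketch(y_true, y_pred):
    """Mean absolute error of the k smallest-error samples, for k = 1..n."""
    abs_err = sorted(abs(t - p) for t, p in zip(y_true, y_pred))
    return [total / k for k, total in enumerate(accumulate(abs_err), start=1)]


def std_decay_sketch(y_true, y_pred, y_std):
    """Mean absolute error of the k least-uncertain samples, for k = 1..n."""
    abs_err = [abs(t - p) for t, p in zip(y_true, y_pred)]
    # order errors by increasing predicted uncertainty
    by_std = [err for _, err in sorted(zip(y_std, abs_err), key=lambda pair: pair[0])]
    return [total / k for k, total in enumerate(accumulate(by_std), start=1)]


y_true = [0, 0, 0, 0]
y_pred = [1, 4, 2, 3]          # absolute errors 1, 4, 2, 3
y_std = [0.1, 0.4, 0.3, 0.2]   # imperfectly rank-correlated with the errors

by_err = err_decay_sketch(y_true, y_pred)
by_std = std_decay_sketch(y_true, y_pred, y_std)
```

Reading each list from right to left corresponds to excluding more samples. The uncertainty-ordered curve can never fall below the error-ordered one, and both end at the same overall MAE.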
function error_decay_with_uncert
error_decay_with_uncert(
y_true: 'ArrayLike | str',
y_pred: 'ArrayLike | str',
y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
df: 'DataFrame | None' = None,
n_rand: 'int' = 100,
percentiles: 'bool' = True,
ax: 'Axes | None' = None
) → Axes
Plot for assessing the quality of uncertainty estimates. If a model's uncertainty is well calibrated, i.e. strongly correlated with its error, removing the most uncertain predictions should make the mean error decay similarly to how it decays when removing the predictions of largest error.
Args:
- y_true (array | str): Ground truth regression targets.
- y_pred (array | str): Model predictions.
- y_std (array | dict[str, ArrayLike] | str | list[str]): Model uncertainties. Can be single or multiple uncertainties (e.g. aleatoric/epistemic/total uncertainty) as dict.
- n_rand (int, optional): Number of shuffles from which to compute std.dev. of error decay by random ordering. Defaults to 100.
- df (pd.DataFrame, optional): DataFrame with y_true, y_pred and y_std columns.
- percentiles (bool, optional): Whether the x-axis shows percentiles or number of remaining samples in the MAE calculation. Defaults to True.
- ax (Axes): matplotlib Axes on which to plot. Defaults to None.
Note: If you're not happy with the default y_max of 1.1 * rand_mean, where rand_mean is mean of random sample exclusion, use ax.set(ylim=[None, some_value * ax.get_ylim()[1]]).
Returns:
- plt.Axes: matplotlib Axes object with plotted model error drop curve based on excluding data points by order of large to small model uncertainties.
module utils.data
Data processing utils:
- df_ptable (DataFrame): Periodic table.
- atomic_numbers (dict[str, int]): Map elements to atomic numbers.
- element_symbols (dict[int, str]): Map atomic numbers to elements.
- bin_df_cols: Bin columns of a DataFrame.
- crystal_sys_from_spg_num: Get the crystal system for an international space group number.
- df_to_arrays: Convert DataFrame to arrays.
- html_tag: Wrap text in a span with custom style.
- normalize_to_dict: Normalize object or dict/list/tuple of them into a dict.
- patch_dict: Context manager to temporarily patch the specified keys in a dictionary and restore it to its original state on context exit.
- si_fmt: Convert large numbers into human readable format using SI prefixes.
Global Variables
- TYPE_CHECKING
- ROOT
- atomic_numbers
- element_symbols
- Z
- symbol
function bin_df_cols
bin_df_cols(
df_in: 'DataFrame',
bin_by_cols: 'Sequence[str]',
group_by_cols: 'Sequence[str]' = (),
n_bins: 'int | Sequence[int]' = 100,
bin_counts_col: 'str' = 'bin_counts',
density_col: 'str' = '',
verbose: 'bool' = True
) → DataFrame
Bin columns of a DataFrame.
Args:
- df_in (pd.DataFrame): Input dataframe to bin.
- bin_by_cols (Sequence[str]): Columns to bin.
- group_by_cols (Sequence[str]): Additional columns to group by. Defaults to ().
- n_bins (int | Sequence[int]): Number of bins to use. Defaults to 100.
- bin_counts_col (str): Column name for bin counts. Defaults to "bin_counts".
- density_col (str): Column name for density values. Defaults to "".
- verbose (bool): If True, report df length reduction. Defaults to True.
Returns:
- pd.DataFrame: Binned DataFrame with original index name and values.
function crystal_sys_from_spg_num
crystal_sys_from_spg_num(spg: 'float') → CrystalSystem
Get the crystal system for an international space group number.
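The mapping from international space group number to crystal system follows fixed number ranges. A minimal sketch of that lookup is shown below. The function name, the lowercase string return values and the int-only input are assumptions of the example; the documented function accepts a float and returns a CrystalSystem.

```python
def crystal_system_sketch(spg: int) -> str:
    """Map an international space group number (1-230) to its crystal system."""
    if not 1 <= spg <= 230:
        raise ValueError(f"Invalid space group number {spg}, must be 1 <= num <= 230")
    # (largest space group number in the system, system name)
    upper_bounds = [
        (2, "triclinic"),
        (15, "monoclinic"),
        (74, "orthorhombic"),
        (142, "tetragonal"),
        (167, "trigonal"),
        (194, "hexagonal"),
        (230, "cubic"),
    ]
    for upper, system in upper_bounds:
        if spg <= upper:
            return system
    raise AssertionError("unreachable: range was validated above")


def rejects_out_of_range() -> bool:
    try:
        crystal_system_sketch(231)
    except ValueError:
        return True
    return False
```

For example, space group 225 (Fm-3m) falls in the last range and maps to cubic.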
function df_to_arrays
df_to_arrays(
df: 'DataFrame | None',
*args: 'str | Sequence[str] | Sequence[ArrayLike]',
strict: 'bool' = True
) → list[ArrayLike | dict[str, ArrayLike]]
If df is None, this is a no-op: args are returned as-is. If df is a dataframe, all following args are used as column names and the column data returned as arrays (after dropping rows with NaNs in any column).
Args:
- df (pd.DataFrame | None): Optional pandas DataFrame.
- *args (list[ArrayLike | str]): Arbitrary number of arrays or column names in df.
- strict (bool, optional): If True, raise TypeError if df is not pd.DataFrame or None. If False, return args as-is. Defaults to True.
Raises:
- ValueError: If df is not None and any of the args is not a df column name.
- TypeError: If df is not pd.DataFrame and not None.
Returns:
- list[ArrayLike | dict[str, ArrayLike]]: Array data for each column name or dictionary of column names and array data.
function html_tag
html_tag(
text: 'str',
tag: 'str' = 'span',
style: 'str' = '',
title: 'str' = ''
) → str
Wrap text in a span with custom style.
Style defaults to decreased font size and weight e.g. to display units in plotly labels and annotations.
Args:
- text (str): Text to wrap in span.
- tag (str, optional): HTML tag name. Defaults to "span".
- style (str, optional): CSS style string. Defaults to "". Special keys:
  - "small": font-size: 0.8em; font-weight: lighter;
  - "bold": font-weight: bold;
  - "italic": font-style: italic;
  - "underline": text-decoration: underline;
- title (str | None, optional): Title attribute which displays additional information in a tooltip. Defaults to "".
Returns:
- str: HTML string with tag-wrapped text.
function normalize_to_dict
normalize_to_dict(
inputs: 'T | Sequence[T] | dict[str, T]',
cls: 'type[T]' = <class 'pymatgen.core.structure.Structure'>,
key_gen: 'Callable[[T], str]' = <function <lambda>>
) → dict[str, T]
Normalize any kind of object or dict/list/tuple of them into a dictionary.
Args:
- inputs: A single object, a sequence of objects, or a dictionary of objects.
- cls (type[T], optional): The class of the objects to normalize. Defaults to pymatgen.core.Structure.
- key_gen (Callable[[T], str], optional): A function that generates a key for each object. Defaults to using the object's formula, assuming the objects are pymatgen.core.(Structure|Molecule).
Returns: A dictionary of objects with keys as object formulas or given keys.
Raises:
- TypeError: If the input format is invalid.
function patch_dict
patch_dict(
dct: 'dict[Any, Any]',
*args: 'Any',
**kwargs: 'Any'
) → Generator[dict[Any, Any], None, None]
Context manager to temporarily patch the specified keys in a dictionary and restore it to its original state on context exit.
Useful e.g. for temporary plotly fig.layout mutations:
with patch_dict(fig.layout, showlegend=False):
    fig.write_image("plot.pdf")
Args:
- dct (dict): The dictionary to be patched.
- *args: Only first element is read if present. A single dictionary containing the key-value pairs to patch.
- **kwargs: The key-value pairs to patch, provided as keyword arguments.
Yields:
- dict: The patched dictionary incl. temporary updates.
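One plausible way to build such a context manager with the standard library is sketched below. This is an illustration under assumptions, not pymatviz's code: patch_dict_sketch and the settings dict are hypothetical, and the sketch mutates the dictionary in place and restores it on exit, which may differ from what the library does internally.

```python
from contextlib import contextmanager
from typing import Any, Iterator

_MISSING = object()  # sentinel for keys that did not exist before patching


@contextmanager
def patch_dict_sketch(dct: dict, *args: Any, **kwargs: Any) -> Iterator[dict]:
    # Only the first positional argument is read; keyword arguments win on conflict.
    updates = dict(args[0]) if args else {}
    updates.update(kwargs)
    originals = {key: dct.get(key, _MISSING) for key in updates}
    dct.update(updates)
    try:
        yield dct
    finally:
        # Restore previous values and drop keys that were newly added.
        for key, value in originals.items():
            if value is _MISSING:
                dct.pop(key, None)
            else:
                dct[key] = value


settings = {"showlegend": True}
with patch_dict_sketch(settings, showlegend=False, title="tmp") as patched:
    inside = dict(patched)  # snapshot taken while the patch is active
```

The try/finally ensures the original state is restored even if the body of the with block raises.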
function si_fmt
si_fmt(
val: 'float',
fmt: 'str' = '.1f',
sep: 'str' = '',
binary: 'bool' = False,
decimal_threshold: 'float' = 0.01
) → str
Convert large numbers into human readable format using SI prefixes.
Supports binary (1024) and metric (1000) mode.
https://nist.gov/pml/weights-and-measures/metric-si-prefixes
Args:
- val (int | float): Some numerical value to format.
- binary (bool, optional): If True, scaling factor is 2^10 = 1024 else 1000. Defaults to False.
- fmt (str): f-string format specifier. Configure precision and left/right padding in returned string. Defaults to ".1f". Can be used to ensure leading or trailing whitespace for shorter numbers. See https://docs.python.org/3/library/string.html#format-specification-mini-language.
- sep (str): Separator between number and postfix. Defaults to "".
- decimal_threshold (float): abs(value) below 1 but above this threshold will be left as decimals. Only below this threshold is an SI suffix added (milli, micro, etc.). Defaults to 0.01, i.e. 0.01 -> "0.01" while 0.0099 -> "9.9m". Setting decimal_threshold=0.1 would format 0.01 as "10m" and leave 0.1 as is.
Returns:
- str: Formatted number.
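The scaling logic described by these parameters can be sketched in a few lines. This is an illustrative re-implementation under assumptions, not pymatviz's code: the name si_fmt_sketch and the exact prefix letters (in particular "K" for kilo) are choices made for the example, and rounding edge cases such as 999,999 are not handled.

```python
def si_fmt_sketch(val, fmt=".1f", sep="", binary=False, decimal_threshold=0.01):
    """Format a number with an SI-style suffix (hypothetical sketch)."""
    factor = 1024 if binary else 1000
    scaled = float(val)
    large = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
    small = ["", "m", "μ", "n", "p", "f", "a", "z", "y"]
    idx = 0
    if abs(scaled) >= 1:
        # divide until the number drops below one scaling factor
        while abs(scaled) >= factor and idx < len(large) - 1:
            scaled /= factor
            idx += 1
        suffix = large[idx]
    elif scaled != 0 and abs(scaled) < decimal_threshold:
        # multiply until the number reaches at least 1
        while abs(scaled) < 1 and idx < len(small) - 1:
            scaled *= factor
            idx += 1
        suffix = small[idx]
    else:
        suffix = ""  # zero, or a decimal at or above the threshold: leave as is
    return f"{scaled:{fmt}}{sep}{suffix}"


formatted = [
    si_fmt_sketch(12345),
    si_fmt_sketch(0.0099),
    si_fmt_sketch(0.01, fmt=".2f"),
    si_fmt_sketch(2048, binary=True),
    si_fmt_sketch(1_500_000, sep=" "),
]
```

Values between decimal_threshold and 1 pass through unscaled, which mirrors the 0.01 versus 0.0099 distinction in the parameter description.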
module utils
pymatviz utility functions.
Global Variables
- PKG_DIR
- ROOT
- atomic_numbers
- element_symbols
- TEST_FILES
class ExperimentalWarning
Warning for experimental features.
module utils.plotting
Plotting-related utility functions.
Available functions:
- annotate: Annotate a matplotlib or plotly figure with text.
- apply_matplotlib_template: Set default matplotlib configurations for consistency.
- get_cbar_label_formatter: Generate colorbar tick label formatter.
- get_font_color: Get the font color used in a Matplotlib or Plotly figure.
- get_fig_xy_range: Get the x and y range of a plotly or matplotlib figure.
- luminance: Compute the luminance of a color.
- pick_bw_for_contrast: Choose black or white text color for contrast.
- pretty_label: Map metric keys to their pretty labels.
- validate_fig: Decorator to validate the type of fig keyword argument.
Global Variables
- TYPE_CHECKING
- BACKENDS
- MATPLOTLIB
- PLOTLY
- VALID_FIG_NAMES
function annotate
annotate(text: 'str | Sequence[str]', fig: 'AxOrFig', **kwargs: 'Any') → AxOrFig
Annotate a matplotlib or plotly figure. Supports faceted plotly figures, where empty strings in text are skipped.
Args:
- text (str): The text to use for annotation. If fig is plotly faceted, text can be a list of strings to annotate each subplot.
- fig (plt.Axes | plt.Figure | go.Figure | None, optional): The matplotlib Axes, Figure or plotly Figure to annotate.
- **kwargs: Additional arguments to pass to matplotlib's AnchoredText or plotly's fig.add_annotation().
Returns:
- plt.Axes | plt.Figure | go.Figure: The annotated figure.
Raises:
- TypeError: If fig is not a Matplotlib or Plotly figure.
function apply_matplotlib_template
apply_matplotlib_template() → None
Set default matplotlib configurations for consistency.
- Font size: 14 for readability.
- Savefig: Tight bounding box and 200 DPI for high-quality saved plots.
- Axes: Title size 16, bold weight for emphasis.
- Figure: DPI 200, title size 20, bold weight for better visibility.
- Layout: Enables constrained layout to reduce element overlap.
function get_cbar_label_formatter
get_cbar_label_formatter(
cbar_label_fmt: 'str',
values_fmt: 'str',
values_show_mode: "Literal['value', 'fraction', 'percent', 'off']",
sci_notation: 'bool',
default_decimal_places: 'int' = 1
) → Formatter
Generate colorbar tick label formatter.
Works differently depending on values_show_mode:
- "value"/"fraction" mode: Use cbar_label_fmt (or values_fmt) as is.
- "percent" mode: Get number of decimal places to keep from fmt string, for example 1 from ".1%".
Args:
- cbar_label_fmt (str): f-string option for colorbar tick labels.
- values_fmt (str): f-string option for tile values, would be used if cbar_label_fmt is "auto".
- values_show_mode (str): The values display mode:
  - "off": Hide values.
  - "value": Display values as is.
  - "fraction": As a fraction of the total (0.10).
  - "percent": As a percentage of the total (10%).
- sci_notation (bool): Whether to use scientific notation for values and colorbar tick labels.
- default_decimal_places (int): Default number of decimal places to use if above fmt is invalid.
Returns: PercentFormatter or FormatStrFormatter.
function get_font_color
get_font_color(fig: 'AxOrFig') → str
Get the font color used in a Matplotlib figure/axes or a Plotly figure.
Args:
fig
(plt.Figure | plt.Axes | go.Figure): A Matplotlib or Plotly figure object.
Returns:
str
: The font color as a string (e.g., 'black', '#000000').
Raises:
TypeError
: If fig is not a Matplotlib or Plotly figure.
function luminance
luminance(color: 'str | tuple[float, float, float]') → float
Compute the luminance of a color as in https://stackoverflow.com/a/596243.
Args:
color
(str | tuple[float, float, float]): Color string or RGB color tuple with values in [0, 1].
Returns:
float
: Luminance of the color.
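As a rough illustration, the linked answer lists several luminance formulas. The sketch below assumes the perceived-luminance weights (0.299, 0.587, 0.114) and handles RGB tuples only; which variant pymatviz uses is not stated here, and luminance_sketch is a hypothetical name.

```python
def luminance_sketch(color: tuple[float, float, float]) -> float:
    """Weighted sum of RGB channels, each expected in [0, 1].

    The weights are an assumption (perceived luminance, ITU-R BT.601).
    """
    red, green, blue = color
    return 0.299 * red + 0.587 * green + 0.114 * blue


print(luminance_sketch((0.0, 0.0, 0.0)))  # 0.0 for black
print(luminance_sketch((0.0, 1.0, 0.0)))  # 0.587 for pure green
```

Green dominates the result because the human eye is most sensitive to it, which is why a pure green background reads as much brighter than a pure blue one.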
function pick_bw_for_contrast
pick_bw_for_contrast(
color: 'tuple[float, float, float] | str',
text_color_threshold: 'float' = 0.7
) → Literal['black', 'white']
Choose black or white text color for a given background color based on luminance.
Args:
color
(tuple[float, float, float] | str): RGB color tuple with values in [0, 1].
text_color_threshold
(float, optional): Luminance threshold for choosing black or white text color. Defaults to 0.7.
Returns:
"black" | "white"
: depending on the luminance of the background color.
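A minimal sketch of the thresholding idea follows. It assumes that backgrounds with luminance above the threshold get black text and all others get white text, and that luminance is the weighted RGB sum with weights (0.299, 0.587, 0.114). Both are assumptions for illustration, and pick_bw_sketch is a hypothetical name that handles RGB tuples only.

```python
from typing import Literal


def pick_bw_sketch(
    color: tuple[float, float, float],
    text_color_threshold: float = 0.7,
) -> Literal["black", "white"]:
    """Pick a text color that contrasts with the given background color."""
    red, green, blue = color
    # Assumed luminance formula (perceived luminance weights).
    lum = 0.299 * red + 0.587 * green + 0.114 * blue
    # Light backgrounds get black text, dark backgrounds get white text.
    return "black" if lum > text_color_threshold else "white"


print(pick_bw_sketch((1.0, 1.0, 0.0)))  # black (yellow is a light background)
print(pick_bw_sketch((0.0, 0.0, 1.0)))  # white (blue is a dark background)
```

Raising text_color_threshold makes the function choose white text for more backgrounds; lowering it favors black text.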
function pretty_label
pretty_label(key: 'str', backend: 'Backend') → str
Map metric keys to their pretty labels.
function validate_fig
validate_fig(func: 'Callable[P, R]') → Callable[P, R]
Decorator to validate the type of fig keyword argument in a function. fig MUST be a keyword argument, not a positional argument.
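The keyword-only requirement follows from how such a decorator typically inspects its arguments. The sketch below is a simplified stand-in, not the pymatviz implementation: FakeFigure replaces the real matplotlib/plotly types, and allowing fig=None is an assumption.

```python
import functools
from typing import Any, Callable, TypeVar

R = TypeVar("R")


class FakeFigure:
    """Stand-in for go.Figure / plt.Figure / plt.Axes in this sketch."""


VALID_FIG_TYPES = (FakeFigure,)


def validate_fig_sketch(func: Callable[..., R]) -> Callable[..., R]:
    """Reject calls whose fig keyword argument has an unexpected type."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> R:
        # Only kwargs are inspected, so a positionally passed fig is not checked.
        fig = kwargs.get("fig")
        if fig is not None and not isinstance(fig, VALID_FIG_TYPES):
            raise TypeError(f"Unexpected type for fig: {type(fig).__name__}")
        return func(*args, **kwargs)

    return wrapper


@validate_fig_sketch
def describe(*, fig: Any = None) -> str:
    """Toy function that reports the type name of fig."""
    return type(fig).__name__


print(describe(fig=FakeFigure()))  # FakeFigure
```

Declaring fig after a bare * in the decorated function, as in describe, is one way to make sure callers cannot bypass the check by passing it positionally.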
function get_fig_xy_range
get_fig_xy_range(
fig: 'Figure | Figure | Axes',
trace_idx: 'int' = 0
) → tuple[tuple[float, float], tuple[float, float]]
Get the x and y range of a plotly or matplotlib figure.
Args:
fig
(go.Figure | plt.Figure | plt.Axes): plotly/matplotlib figure or axes.
trace_idx
(int, optional): Index of the trace to use for measuring x/y limits. Defaults to 0. Unused if kaleido package is installed and the figure's actual x/y-range can be obtained from fig.full_figure_for_development().
Returns:
tuple[tuple[float, float], tuple[float, float]]
: The x and y range of the figure in the format ((x_min, x_max), (y_min, y_max)).
module utils.testing
Testing related utils.
Global Variables
- ROOT
- TEST_FILES
module xrd
Module for plotting XRD patterns using plotly.
Global Variables
- TYPE_CHECKING
- ValidHklFormats
- HklCompact
- HklFull
- HklNone
function format_hkl
format_hkl(hkl: 'tuple[int, int, int]', format_type: 'HklFormat') → str
Format hkl indices as a string.
Args:
hkl
(tuple[int, int, int]): The hkl indices to format.
format_type
('compact' | 'full' | None): How to display the hkl indices.
Raises:
ValueError
: If format_type is not one of 'compact', 'full', or None.
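The formats can be illustrated with a small sketch. It assumes 'compact' concatenates the indices, 'full' renders them as a parenthesized, comma-separated tuple, and None yields an empty string. The empty-string behavior and the name format_hkl_sketch are assumptions.

```python
from typing import Optional


def format_hkl_sketch(hkl: tuple[int, int, int], format_type: Optional[str]) -> str:
    """Format Miller indices according to format_type."""
    if format_type == "compact":
        return "".join(str(idx) for idx in hkl)
    if format_type == "full":
        return "(" + ", ".join(str(idx) for idx in hkl) + ")"
    if format_type is None:
        # Assumed behavior: no hkl label at all.
        return ""
    raise ValueError(f"Invalid format_type: {format_type!r}")


print(format_hkl_sketch((1, 0, 0), "compact"))  # 100
print(format_hkl_sketch((1, 0, 0), "full"))  # (1, 0, 0)
```

Note that the compact form becomes ambiguous for negative or multi-digit indices, which is one reason to prefer 'full' for such reflections.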
function xrd_pattern
xrd_pattern(
patterns: 'PatternOrStruct | dict[str, PatternOrStruct | tuple[PatternOrStruct, dict[str, Any]]]',
peak_width: 'float' = 0.5,
annotate_peaks: 'float' = 5,
hkl_format: 'HklFormat' = 'compact',
show_angles: 'bool | None' = None,
wavelength: 'float' = 1.54184,
stack: "Literal['horizontal', 'vertical'] | None" = None,
subplot_kwargs: 'dict[str, Any] | None' = None,
subtitle_kwargs: 'dict[str, Any] | None' = None
) → Figure
Create a plotly figure of XRD patterns from a pymatgen DiffractionPattern, from a pymatgen Structure, or a dictionary of either of them.
Args:
patterns
(PatternOrStruct | dict[str, PatternOrStruct | tuple[PatternOrStruct, dict[str, Any]]]): Either a single DiffractionPattern or Structure object, or a dictionary where keys are legend labels and values are either DiffractionPattern/Structure objects or tuples of (DiffractionPattern/Structure, kwargs) for customizing individual patterns.
peak_width
(float): Width of the diffraction peaks in degrees. Default is 0.5.
annotate_peaks
(float): Controls peak annotation. If int, annotates that many highest peaks. If float, should be in (0, 1) which will annotate peaks higher than that fraction of the highest peak. Default is 5.
hkl_format
(HklFormat): Format for hkl indices. One of 'compact' (ex: '100'), 'full' (ex: '(1, 0, 0)'), or None for no hkl indices. Default is 'compact'.
show_angles
(bool | None): Whether to show angles in peak annotations. If None, it will default to True if plotting 1 or 2 patterns, False for 3 or more patterns.
wavelength
(float): X-ray wavelength for the XRD calculation (in Angstroms). Default is 1.54184 (Cu K-alpha). Only used if patterns argument contains Structures.
stack
(Literal["horizontal", "vertical"] | None): If set to "horizontal" or "vertical", creates separate subplots for each pattern. Default is None (all patterns in one plot).
subplot_kwargs
(dict[str, Any] | None): Passed to make_subplots. Can be used to control spacing between subplots, e.g. {'vertical_spacing': 0.02}.
subtitle_kwargs
(dict[str, Any] | None): Override default subplot title settings. E.g. dict(font_size=14). Default is None.
Raises:
ValueError
: If annotate_peaks is not a positive int or a float in (0, 1).
TypeError
: If patterns is not a DiffractionPattern, Structure or a dict of them.
Returns:
go.Figure
: A plotly figure of the XRD pattern(s).
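The two meanings of annotate_peaks (a count when it is an int, a relative-height cutoff when it is a float in (0, 1)) can be sketched with plain lists. The helper peaks_to_annotate is hypothetical and only illustrates the selection rule described above; it is not taken from pymatviz.

```python
def peaks_to_annotate(intensities: list[float], annotate_peaks: float) -> list[int]:
    """Return the indices of the peaks that would receive an annotation."""
    is_int = isinstance(annotate_peaks, int) and not isinstance(annotate_peaks, bool)
    if is_int and annotate_peaks > 0:
        # Integer: keep the N highest peaks, reported in their original order.
        ranked = sorted(
            range(len(intensities)), key=lambda idx: intensities[idx], reverse=True
        )
        return sorted(ranked[:annotate_peaks])
    if isinstance(annotate_peaks, float) and 0 < annotate_peaks < 1:
        # Float: keep peaks higher than this fraction of the highest peak.
        cutoff = annotate_peaks * max(intensities)
        return [idx for idx, val in enumerate(intensities) if val > cutoff]
    raise ValueError(f"Invalid annotate_peaks: {annotate_peaks!r}")


intensities = [10.0, 100.0, 40.0, 5.0]
print(peaks_to_annotate(intensities, 2))  # [1, 2]
print(peaks_to_annotate(intensities, 0.3))  # [1, 2]
```

With this rule, 1 (an int) means "annotate the single highest peak", whereas 1.0 (a float) is rejected because it lies outside (0, 1).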