
API

module bar

Bar plots.

Global Variables


function spacegroup_bar

spacegroup_bar(
    data: 'Sequence[int | str | Structure] | Series',
    show_counts: 'bool' = True,
    xticks: "Literal['all', 'crys_sys_edges'] | int" = 20,
    show_empty_bins: 'bool' = False,
    ax: 'Axes | None' = None,
    backend: 'Backend' = 'plotly',
    text_kwargs: 'dict[str, Any] | None' = None,
    log: 'bool' = False,
    **kwargs: 'Any'
) → Axes | Figure

Plot a histogram of spacegroups shaded by crystal system.
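
Shading by crystal system relies on the standard partition of the 230 international space group numbers into seven crystal systems. Below is a minimal sketch of that mapping and of the per-system counts behind the shading. The helper names are hypothetical and this is not pymatviz's own code.

```python
from __future__ import annotations

from collections import Counter

# Inclusive upper bound of the space group number range of each crystal system
CRYS_SYS_UPPER_BOUNDS = (
    (2, "triclinic"),
    (15, "monoclinic"),
    (74, "orthorhombic"),
    (142, "tetragonal"),
    (167, "trigonal"),
    (194, "hexagonal"),
    (230, "cubic"),
)


def crystal_system(spg_num: int) -> str:
    """Map an international space group number (1-230) to its crystal system."""
    if not 1 <= spg_num <= 230:
        raise ValueError(f"invalid space group number: {spg_num}")
    for upper_bound, name in CRYS_SYS_UPPER_BOUNDS:
        if spg_num <= upper_bound:
            return name
    raise AssertionError("unreachable")


def count_by_crystal_system(spg_nums: list[int]) -> Counter[str]:
    """Count space group numbers per crystal system (one shade per system)."""
    return Counter(crystal_system(num) for num in spg_nums)
```

For example, count_by_crystal_system([225, 225, 194, 62, 14, 1]) yields 2 cubic entries and 1 each of hexagonal, orthorhombic, monoclinic and triclinic.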

Args:

Returns:

module colors

Colors used in pymatviz.

Global Variables

module coordination

Visualizations of coordination number distributions.

Global Variables


function create_hover_text

create_hover_text(
    struct_key: str,
    elem_symbol: str,
    cn: int,
    count: int,
    hover_data: dict[str, str],
    data: dict[str, Any],
    is_single_structure: bool
) → str

Create hover text for a single bar in the histogram.


function normalize_get_neighbors

normalize_get_neighbors(
    strategy: float | NearNeighbors | type[NearNeighbors]
) → Callable[[PeriodicSite, Structure], list[dict[str, Any]]]

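Normalize a neighbor-finding strategy (a cutoff distance, a NearNeighbors instance or a NearNeighbors subclass) into a get_neighbors function that takes a site and a structure and returns a list of neighbor dicts.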


function coordination_hist

coordination_hist(
    structures: Structure | dict[str, Structure] | Sequence[Structure],
    strategy: float | NearNeighbors | type[NearNeighbors] = 3.0,
    split_mode: SplitMode | str = SplitMode.by_element,
    bar_mode: Literal['group', 'stack'] = 'stack',
    hover_data: Sequence[str] | dict[str, str] | None = None,
    element_color_scheme: ElemColorScheme | dict[str, str] = Jmol,
    annotate_bars: bool | dict[str, Any] = False,
    bar_kwargs: dict[str, Any] | None = None
) → Figure

Create a plotly histogram of coordination numbers for given structure(s).
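
When strategy is a float (default 3.0), it presumably acts as a neighbor cutoff distance. As a rough sketch of the counts such a histogram shows, the code below assumes a plain cutoff on Cartesian distances and ignores periodic images (a real pymatgen Structure needs them). The function names are hypothetical.

```python
from __future__ import annotations

from collections import Counter

import numpy as np


def coordination_numbers(coords: np.ndarray, cutoff: float = 3.0) -> np.ndarray:
    """Number of other sites within `cutoff` of each site (no periodic images)."""
    coords = np.asarray(coords, dtype=float)
    # Pairwise distance matrix of shape (n_sites, n_sites)
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # a site is not its own neighbor
    return (dists <= cutoff).sum(axis=1)


def cn_counts_by_element(
    symbols: list[str], coords: np.ndarray, cutoff: float = 3.0
) -> dict[str, Counter[int]]:
    """Histogram data split by element: {element: {coordination number: count}}."""
    hist: dict[str, Counter[int]] = {}
    for symbol, cn in zip(symbols, coordination_numbers(coords, cutoff)):
        hist.setdefault(symbol, Counter())[int(cn)] += 1
    return hist
```

For four atoms on the corners of a 2.5 Å square, each site has 2 neighbors within 3 Å (the diagonals are about 3.54 Å), so every element gets a single bar at CN = 2.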

Args:

Returns: A plotly Figure object containing the histogram.


function coordination_vs_cutoff_line

coordination_vs_cutoff_line(
    structures: Structure | dict[str, Structure] | Sequence[Structure],
    strategy: tuple[float, float] | NearNeighbors | type[NearNeighbors] = (1.0, 5.0),
    num_points: int = 50,
    element_color_scheme: ElemColorScheme | dict[str, str] = Jmol,
    subplot_kwargs: dict[str, Any] | None = None
) → Figure

Create a plotly line plot of cumulative coordination numbers vs cutoff distance.
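
A cumulative coordination number curve can be computed in one pass by sorting all pair distances once and binary-searching each cutoff. A minimal sketch, assuming Cartesian coordinates without periodic images; the function name is hypothetical.

```python
import numpy as np


def average_cn_vs_cutoff(coords, cutoffs) -> np.ndarray:
    """Average coordination number at each cutoff (no periodic images)."""
    coords = np.asarray(coords, dtype=float)
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    # Each unordered pair once: upper triangle without the zero self-distances
    pair_dists = np.sort(dists[np.triu_indices(len(coords), k=1)])
    # Number of pairs no farther apart than each cutoff, via binary search
    n_pairs = np.searchsorted(pair_dists, cutoffs, side="right")
    # Every pair adds one neighbor to each of its two sites
    return 2 * n_pairs / len(coords)
```

Passing np.linspace(1.0, 5.0, 50) as cutoffs mirrors the default strategy=(1.0, 5.0) and num_points=50 in the signature above.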

Args:

Returns: A plotly Figure object containing the line plot.


function calculate_average_cn

calculate_average_cn(
    structure: Structure,
    element: str,
    get_neighbors: Callable[[PeriodicSite, Structure], list[dict[str, Any]]]
) → float

Calculate the average coordination number for a given element in a structure.


class SplitMode

How to split the coordination number histogram into subplots.

module cumulative

Plot the cumulative distribution of residuals and absolute errors.

Global Variables


function cumulative_residual

cumulative_residual(
    res: 'ArrayLike',
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Plot the empirical cumulative distribution for the residuals (y - mu).

Args:

Returns:


function cumulative_error

cumulative_error(
    abs_err: 'ArrayLike',
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Plot the empirical cumulative distribution of abs(y_true - y_pred).

Args:

Returns:

module enums

Enums used as keys/accessors for dicts and dataframes across Matbench Discovery.

Global Variables


class LabelEnum

StrEnum with optional label and description attributes plus dict() methods.

Simply add label and description as a tuple starting with the key's value.
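
The tuple pattern described above can be implemented with a custom __new__ on a str-mixin enum. The sketch below is hypothetical: it is not pymatviz's LabelEnum, and the value_to_label helper stands in for whatever dict() methods the real class provides.

```python
from __future__ import annotations

from enum import Enum


class LabelEnum(str, Enum):
    """Hypothetical sketch: str enum whose members also carry a label and description."""

    def __new__(
        cls, value: str, label: str | None = None, description: str | None = None
    ) -> LabelEnum:
        member = str.__new__(cls, value)
        member._value_ = value  # lookups like MyEnum("value") use the plain string
        member.label = label
        member.description = description
        return member

    @classmethod
    def value_to_label(cls) -> dict[str, str | None]:
        """Map each member's value to its label (hypothetical helper)."""
        return {member.value: member.label for member in cls}


class Color(LabelEnum):
    """Example subclass: each member is assigned (value, label[, description])."""

    red = "red", "Red", "The color red"
    blue = "blue", "Blue"
```

Because the members subclass str, Color.red == "red" holds, so members work directly as dict keys or dataframe column names.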


class Key

Keys used to access dataframes columns, organized by semantic groups.


class Task

What kind of task is being performed.


class Model

Model names.


class ElemCountMode

Mode of counting elements in a chemical formula.


class ElemColorMode

Mode of coloring elements in structure visualizations or periodic table plots.


class ElemColorScheme

Names of element color palettes.

Used e.g. in structure visualizations and periodic table plots.


class SiteCoords

Site coordinate representations.

module histogram

Histograms.

Global Variables


function spacegroup_hist

spacegroup_hist(*args: 'Any', **kwargs: 'Any') → Axes | Figure

Alias for spacegroup_bar.


function elements_hist

elements_hist(
    formulas: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    log: 'bool' = False,
    keep_top: 'int | None' = None,
    ax: 'Axes | None' = None,
    bar_values: "Literal['percent', 'count'] | None" = 'percent',
    h_offset: 'int' = 0,
    v_offset: 'int' = 10,
    rotation: 'int' = 45,
    fontsize: 'int' = 12,
    **kwargs: 'Any'
) → Axes

Plot a histogram of elements (e.g. to show occurrence in a dataset).

Adapted from https://github.com/kaaiian/ML_figures (https://git.io/JmbaI).

Args:

Returns:


function histogram

histogram(
    values: 'Sequence[float] | dict[str, Sequence[float]]',
    bins: 'int | Sequence[float] | str' = 200,
    x_range: 'tuple[float | None, float | None] | None' = None,
    density: 'bool' = False,
    bin_width: 'float' = 1.2,
    log_y: 'bool' = False,
    backend: 'Backend' = 'plotly',
    fig_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → Figure | Figure

Get a histogram with the plotly (default) or matplotlib backend, using fast numpy pre-processing before handing the data off to the plotting function.

This is a very common use case when dealing with large datasets, so it is worth having a dedicated function for it. Its two advantages over the native matplotlib/plotly histograms are much faster rendering and much smaller file sizes (when saving plotly figures as HTML, since plotly saves a complete copy of the data to disk from which it recomputes the histogram on the fly to render the figure). Speedup example:

    gaussian = np.random.normal(0, 1, 1_000_000_000)
    histogram(gaussian)  # takes 17s
    px.histogram(gaussian)  # ran for 3m45s before crashing the Jupyter kernel
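
The numpy pre-processing step amounts to binning the data first and passing only bin positions and counts to the plotting library. A minimal sketch of that idea (hypothetical helper, not the function's actual internals):

```python
import numpy as np


def binned_counts(values, bins: int = 200, x_range=None):
    """Bin values with numpy; only bin centers and counts need to reach the plot."""
    counts, edges = np.histogram(values, bins=bins, range=x_range)
    centers = (edges[:-1] + edges[1:]) / 2  # midpoint of each bin
    return centers, counts
```

The two short arrays can then be drawn as e.g. a plotly go.Bar(x=centers, y=counts) trace, so a saved HTML figure stores 200 counts instead of every sample. This sketch ignores the density, bin_width and log_y options, and its x_range must be None or a pair of floats, since np.histogram does not accept None inside the pair.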

Args:

Returns:

module io

I/O utilities for saving figures and dataframes to various image formats.

Global Variables


function save_fig

save_fig(
    fig: 'Figure | Figure | Axes',
    path: 'str',
    plotly_config: 'dict[str, Any] | None' = None,
    env_disable: 'Sequence[str]' = ('CI',),
    pdf_sleep: 'float' = 0.6,
    style: 'str' = '',
    prec: 'int | None' = None,
    template: 'str | None' = None,
    transparent_bg: 'bool' = True,
    **kwargs: 'Any'
) → None

Write a plotly or matplotlib figure to disk (as HTML/PDF/SVG/...).

If the file has a .svelte extension, insert {...$$props} into the figure's top-level div so it can later be styled and customized from Svelte code.

Args:


function save_and_compress_svg

save_and_compress_svg(
    fig: 'Figure | Figure | Axes',
    filename: 'str',
    transparent_bg: 'bool' = True
) → None

Save a Plotly figure as SVG and HTML to the assets/ folder. Compress the SVG file with the svgo CLI if it is available in PATH.

If filename does not include the .svg extension and is not absolute, it is treated as relative to the assets/ folder. This function is mostly meant for internal use in pymatviz.

Args:

Raises:


function df_to_pdf

df_to_pdf(
    styler: 'Styler',
    file_path: 'str | Path',
    crop: 'bool' = True,
    size: 'str | None' = None,
    style: 'str' = '',
    styler_css: 'bool | dict[str, str]' = True,
    **kwargs: 'Any'
) → None

Export a pandas Styler to PDF with WeasyPrint.

Args:


function normalize_and_crop_pdf

normalize_and_crop_pdf(
    file_path: 'str | Path',
    on_gs_not_found: "Literal['ignore', 'warn', 'error']" = 'warn'
) → None

Normalize a PDF using Ghostscript and then crop it. Without gs normalization, pdfCropMargins sometimes corrupts the PDF.

Args:


function df_to_html_table

df_to_html_table(*args: 'Any', **kwargs: 'Any') → str

function df_to_html

df_to_html(
    styler: 'Styler',
    file_path: 'str | Path | None' = None,
    inline_props: 'str | None' = '{{...$$props}} ',
    pre_table: 'str | None' = '',
    styles: 'str | None' = 'table { overflow: scroll; max-width: 100%; display: block; }\ntable {\n    scrollbar-width: none;  /* Firefox */\n}\ntable::-webkit-scrollbar {\n    display: none;  /* Safari and Chrome */\n}',
    styler_css: 'bool | dict[str, str]' = True,
    use_sortable: 'bool' = True,
    use_tooltips: 'bool' = True,
    post_process: 'Callable[[str], str] | None' = None,
    **kwargs: 'Any'
) → str

Convert a pandas Styler to a svelte table.

Args:

Returns:


function df_to_svg

df_to_svg(
    obj: 'DataFrame | Styler',
    file_path: 'str | Path',
    font_size: 'int' = 14,
    compress: 'bool' = True,
    **kwargs: 'Any'
) → Figure

Export a pandas DataFrame or Styler to an SVG file and optionally compress it.

TODO The SVG output has annoying margins that proved hard to remove. The goal is for this function to auto-crop the SVG viewBox to the content in the future.

Args:

Returns:

Raises:


class BufferedIOBase

Base class for buffered IO objects.

The main difference with RawIOBase is that the read() method supports omitting the size argument, and does not have a default implementation that defers to readinto().

In addition, read(), readinto() and write() may raise BlockingIOError if the underlying raw stream is in non-blocking mode and not ready; unlike their raw counterparts, they will never return None.

A typical implementation should not inherit from a RawIOBase implementation, but wrap one.


class IOBase

The abstract base class for all I/O classes.

This class provides dummy implementations for many methods that derived classes can override selectively; the default implementations represent a file that cannot be read, written or seeked.

Even though IOBase does not declare read, readinto, or write because their signatures will vary, implementations and clients should consider those methods part of the interface. Also, implementations may raise UnsupportedOperation when operations they do not support are called.

The basic type used for binary data read from or written to a file is bytes. Other bytes-like objects are accepted as method arguments too. In some cases (such as readinto), a writable object is required. Text I/O classes work with str data.

Note that calling any method (except additional calls to close(), which are ignored) on a closed stream should raise a ValueError.

IOBase (and its subclasses) support the iterator protocol, meaning that an IOBase object can be iterated over yielding the lines in a stream.

IOBase also supports the with statement. In this example, fp is closed after the suite of the with statement is complete:

    with open('spam.txt', 'w') as fp:
        fp.write('Spam and eggs!')


class RawIOBase

Base class for raw binary I/O.


class TextIOBase

Base class for text I/O.

This class provides a character and line based interface to stream I/O. There is no readinto method because Python's character strings are immutable.


class UnsupportedOperation


class TqdmDownload

Progress bar for urllib.request.urlretrieve file download.

Adapted from the official TqdmUpTo example. See https://github.com/tqdm/tqdm/blob/4c956c20b83be4312460fc0c4812eeb3fef5e7df/README.rst#hooks-and-callbacks

method __init__

__init__(*args: 'Any', **kwargs: 'Any') → None

Sets default values appropriate for file downloads for unit, unit_scale, unit_divisor, miniters, desc.


property format_dict

Public API for read-only member access.


method update_to

update_to(
    n_blocks: 'int' = 1,
    block_size: 'int' = 1,
    total_size: 'int | None' = None
) → bool | None

Update hook for urlretrieve.

Args:

Returns:

module phonons

Plotting functions for pymatgen phonon band structures and density of states.

Global Variables


function pretty_sym_point

pretty_sym_point(symbol: 'str') → str

Convert a symbol to a pretty-printed version.


function get_band_xaxis_ticks

get_band_xaxis_ticks(
    band_struct: 'PhononBands',
    branches: 'Sequence[str] | set[str]' = ()
) → tuple[list[float], list[str]]

Get all ticks and labels for a band structure plot.

Returns:


function phonon_bands

phonon_bands(
    band_structs: 'PhononBands | dict[str, PhononBands]',
    line_kwargs: 'dict[str, Any] | None' = None,
    branches: 'Sequence[str]' = (),
    branch_mode: 'BranchMode' = 'union',
    shaded_ys: 'dict[tuple[YMin | YMax, YMin | YMax], dict[str, Any]] | bool | None' = None,
    **kwargs: 'Any'
) → Figure

Plot single or multiple pymatgen band structures using Plotly, focusing on the minimum set of overlapping branches.

Warning: Only tested with phonon band structures so far, but the plan is to extend this to electronic band structures.

Args: band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]): Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict with labels mapped to multiple such objects.

Returns:


function phonon_dos

phonon_dos(
    doses: 'PhononDos | dict[str, PhononDos]',
    stack: 'bool' = False,
    sigma: 'float' = 0,
    units: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz',
    normalize: "Literal['max', 'sum', 'integral'] | None" = None,
    last_peak_anno: 'str | None' = None,
    **kwargs: 'Any'
) → Figure

Plot phonon DOS using Plotly.
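
The normalize option accepts 'max', 'sum' or 'integral'. Assuming these mean peak height 1, densities summing to 1, and unit area under the curve, respectively (an interpretation not confirmed on this page), a minimal sketch with a hypothetical helper name:

```python
import numpy as np


def normalize_dos(frequencies, densities, mode=None) -> np.ndarray:
    """Rescale DOS densities: peak 1 ('max'), sum 1 ('sum') or area 1 ('integral')."""
    densities = np.asarray(densities, dtype=float)
    if mode is None:
        return densities
    if mode == "max":
        norm = densities.max()
    elif mode == "sum":
        norm = densities.sum()
    elif mode == "integral":
        freqs = np.asarray(frequencies, dtype=float)
        # Trapezoidal rule written out to avoid the np.trapz/np.trapezoid rename
        norm = np.sum(np.diff(freqs) * (densities[1:] + densities[:-1]) / 2)
    else:
        raise ValueError(f"unknown normalization mode: {mode!r}")
    return densities / norm
```

Normalizing makes DOS curves from different sources, or with different numbers of atoms, comparable on one set of axes.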

Args:

Returns:


function convert_frequencies

convert_frequencies(
    frequencies: 'ndarray',
    unit: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz'
) → ndarray

Convert frequencies from THz to specified units.
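
Converting from THz only requires multiplying by one constant per unit (E = h f for energies, f / c for wavenumbers). Below is a minimal sketch with CODATA 2018 constants. The names are hypothetical, and pymatviz may obtain its conversion factors differently.

```python
import numpy as np

# CODATA 2018 values
PLANCK_EV_S = 4.135667696e-15  # Planck constant h in eV*s
HARTREE_EV = 27.211386245988  # 1 Hartree in eV
SPEED_OF_LIGHT_CM_S = 2.99792458e10  # c in cm/s

# Multiplicative factor converting 1 THz (1e12 Hz) into each unit
THZ_TO_UNIT = {
    "THz": 1.0,
    "eV": PLANCK_EV_S * 1e12,  # E = h * f
    "meV": PLANCK_EV_S * 1e15,
    "Ha": PLANCK_EV_S * 1e12 / HARTREE_EV,
    "cm-1": 1e12 / SPEED_OF_LIGHT_CM_S,  # wavenumber = f / c
}


def convert_from_thz(frequencies, unit: str = "THz") -> np.ndarray:
    """Convert frequencies given in THz into the requested unit."""
    return np.asarray(frequencies, dtype=float) * THZ_TO_UNIT[unit]
```

For instance, 1 THz corresponds to about 4.136 meV or 33.36 cm-1.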

Args:

Returns:


function phonon_bands_and_dos

phonon_bands_and_dos(
    band_structs: 'PhononBands | dict[str, PhononBands]',
    doses: 'PhononDos | dict[str, PhononDos]',
    bands_kwargs: 'dict[str, Any] | None' = None,
    dos_kwargs: 'dict[str, Any] | None' = None,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    all_line_kwargs: 'dict[str, Any] | None' = None,
    per_line_kwargs: 'dict[str, dict[str, Any]] | None' = None,
    **kwargs: 'Any'
) → Figure

Plot phonon DOS and band structure using Plotly.

Args:

Returns:


class PhononDBDoc

Dataclass for phonon DB docs.

method __init__

__init__(
    structure: 'Structure',
    phonon_bandstructure: 'PhononBands',
    phonon_dos: 'PhononDos',
    free_energies: 'list[float]',
    internal_energies: 'list[float]',
    heat_capacities: 'list[float]',
    entropies: 'list[float]',
    temps: 'list[float] | None' = None,
    has_imaginary_modes: 'bool | None' = None,
    primitive: 'Structure | None' = None,
    supercell: 'list[list[int]] | None' = None,
    nac_params: 'dict[str, Any] | None' = None,
    thermal_displacement_data: 'dict[str, Any] | None' = None,
    mp_id: 'str | None' = None,
    formula: 'str | None' = None
) → None

module powerups.both

Powerups that can be applied to both matplotlib and plotly figures.

Global Variables


function annotate_metrics

annotate_metrics(
    xs: 'ArrayLike',
    ys: 'ArrayLike',
    fig: 'AxOrFig | None' = None,
    metrics: 'dict[str, float] | Sequence[str]' = ('MAE', 'R2'),
    prefix: 'str' = '',
    suffix: 'str' = '',
    fmt: 'str' = '.3',
    **kwargs: 'Any'
) → AnchoredText

Provide a set of x and y values of equal length and an optional Axes or Figure on which to print the values' metrics (by default the mean absolute error and the R^2 coefficient of determination).
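
For reference, a sketch of how the two default metrics are conventionally computed. It assumes xs holds the reference values and ys the predictions; pymatviz's exact convention and formatting are not shown on this page, and the helper name is hypothetical.

```python
from __future__ import annotations

import numpy as np


def mae_and_r2(xs, ys) -> tuple[float, float]:
    """Mean absolute error and R^2, treating xs as reference values and ys as predictions."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    mae = float(np.mean(np.abs(xs - ys)))
    ss_res = np.sum((xs - ys) ** 2)  # residual sum of squares
    ss_tot = np.sum((xs - xs.mean()) ** 2)  # total sum of squares of the reference
    return mae, float(1 - ss_res / ss_tot)
```

With xs = [1, 2, 3, 4] and ys = [1, 2, 3, 5], the MAE is 0.25 and R^2 = 1 - 1/5 = 0.8.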

Args:

Returns:


function add_identity_line

add_identity_line(
    fig: 'Figure | Figure | Axes',
    line_kwargs: 'dict[str, Any] | None' = None,
    trace_idx: 'int' = 0,
    retain_xy_limits: 'bool' = False,
    **kwargs: 'Any'
) → Figure

Add a line shape to the background layer of a plotly figure spanning from smallest to largest x/y values in the trace specified by trace_idx.

Args:

Raises:

Returns:


function add_best_fit_line

add_best_fit_line(
    fig: 'Figure | Figure | Axes',
    xs: 'ArrayLike' = (),
    ys: 'ArrayLike' = (),
    trace_idx: 'int | None' = None,
    line_kwargs: 'dict[str, Any] | None' = None,
    annotate_params: 'bool | dict[str, Any]' = True,
    warn: 'bool' = True,
    **kwargs: 'Any'
) → Figure

Add a least-squares line of best fit to a plotly or matplotlib figure.
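
The fitted line is an ordinary least-squares fit of degree 1. A minimal numpy sketch of the parameters such a line would be drawn from (hypothetical helper, not the function's internals):

```python
from __future__ import annotations

import numpy as np


def least_squares_line(xs, ys) -> tuple[float, float]:
    """Slope and intercept of the ordinary least-squares line y = slope * x + intercept."""
    # np.polyfit returns coefficients from highest to lowest degree
    slope, intercept = np.polyfit(
        np.asarray(xs, dtype=float), np.asarray(ys, dtype=float), deg=1
    )
    return float(slope), float(intercept)
```

To draw the line, evaluate slope * x + intercept at the smallest and largest x values of the data.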

Args:

Raises:

Returns:

module powerups.matplotlib

Powerups for matplotlib figures.

Global Variables


function with_marginal_hist

with_marginal_hist(
    xs: 'ArrayLike',
    ys: 'ArrayLike',
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    fig: 'Figure | Axes | None' = None
) → Axes

Call this before creating a matplotlib figure and use the returned ax_main for all subsequent plotting operations. This creates a grid of plots with the main plot in the lower left and narrow histograms along its x- and/or y-axes displayed above it and near its right edge.

Args:

Returns:


function annotate_bars

annotate_bars(
    ax: 'Axes | None' = None,
    v_offset: 'float' = 10,
    h_offset: 'float' = 0,
    labels: 'Sequence[str | int | float] | None' = None,
    fontsize: 'int' = 14,
    y_max_headroom: 'float' = 1.2,
    adjust_test_pos: 'bool' = False,
    **kwargs: 'Any'
) → None

Annotate each bar in a bar plot with a label.

Args:

module powerups

Powerups such as parity lines, annotations, marginals, menu buttons, etc. for matplotlib and plotly figures.

Global Variables

module powerups.plotly

Powerups for plotly figures.

Global Variables


function add_ecdf_line

add_ecdf_line(
    fig: 'Figure',
    values: 'ArrayLike' = (),
    trace_idx: 'int' = 0,
    trace_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → Figure

Add an empirical cumulative distribution function (ECDF) line to a plotly figure.

Support for matplotlib is planned but not yet implemented. PRs welcome.
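
An ECDF line is the sorted data plotted against the cumulative fraction of samples. A minimal numpy sketch of the points such a line is drawn from (hypothetical helper name):

```python
from __future__ import annotations

import numpy as np


def ecdf(values) -> tuple[np.ndarray, np.ndarray]:
    """Step-plot points: sorted values and the cumulative fraction of samples at each."""
    xs = np.sort(np.asarray(values, dtype=float))
    # Fraction of samples up to and including position i in the sorted order.
    # For tied values, the last point of the tie carries the true ECDF value.
    ys = np.arange(1, len(xs) + 1) / len(xs)
    return xs, ys
```

These arrays can be drawn as a step line, e.g. a plotly go.Scatter(x=xs, y=ys, line_shape="hv") trace.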

Args:

Returns:

module process_data

pymatviz utility functions.

Global Variables


function count_elements

count_elements(
    values: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    exclude_elements: 'Sequence[str]' = (),
    fill_value: 'float | None' = None
) → Series

Count element occurrences in a list of formula strings or dict-like compositions. If the passed values are already a map from element symbols to counts, ensure the data is a pd.Series filled with zeros for missing element symbols.

Provided as a standalone function for external use or to cache long computations. Caching long element counts is done by refactoring

    ptable_heatmap(long_list_of_formulas)  # slow

to

    elem_counts = count_elements(long_list_of_formulas)  # slow
    ptable_heatmap(elem_counts)  # fast, only rerun this line to update the plot
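
The difference between summing element amounts and counting occurrences can be illustrated without pymatgen. The sketch below is hypothetical: it uses a naive regex parser that only handles flat formulas like Fe2O3 (no parentheses or hydrates), the mode names "composition" and "occurrence" are assumed to mirror ElemCountMode members, and unlike count_elements it returns a Counter rather than a pd.Series.

```python
from __future__ import annotations

import re
from collections import Counter

# Element symbol followed by an optional (possibly fractional) amount, e.g. "Fe2", "O", "Li0.5"
ELEM_AMOUNT_RE = re.compile(r"([A-Z][a-z]?)(\d*\.?\d*)")


def parse_formula(formula: str) -> Counter[str]:
    """Parse a flat formula like 'Fe2O3' into element amounts."""
    amounts: Counter[str] = Counter()
    for symbol, amount in ELEM_AMOUNT_RE.findall(formula):
        amounts[symbol] += float(amount) if amount else 1.0
    return amounts


def count_elements_sketch(formulas: list[str], mode: str = "composition") -> Counter[str]:
    """Sum element amounts ('composition') or count formulas containing each element ('occurrence')."""
    totals: Counter[str] = Counter()
    for formula in formulas:
        amounts = parse_formula(formula)
        if mode == "occurrence":
            totals.update(dict.fromkeys(amounts, 1))  # each element counts once per formula
        else:
            totals.update(amounts)
    return totals
```

For ["Fe2O3", "FeO", "Li2O"], the composition mode gives Fe: 3, O: 5, Li: 2, while the occurrence mode gives Fe: 2, O: 3, Li: 1.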

Args:

Returns:

module ptable

matplotlib and plotly periodic table figures.

module ptable.ptable_matplotlib

Various periodic table heatmaps with matplotlib and plotly.

Global Variables


function ptable_heatmap

ptable_heatmap(
    data: 'DataFrame | Series | dict[str, list[list[float]]] | PTableData',
    colormap: 'str' = 'viridis',
    exclude_elements: 'Sequence[str]' = (),
    overwrite_tiles: 'dict[ElemStr, OverwriteTileValueColor] | None' = None,
    infty_color: 'ColorType' = 'lightskyblue',
    nan_color: 'ColorType' = 'lightgrey',
    log: 'bool' = False,
    sci_notation: 'bool' = False,
    tile_size: 'tuple[float, float]' = (0.75, 0.75),
    on_empty: "Literal['hide', 'show']" = 'show',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    f_block_voffset: 'float' = 0,
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    text_colors: "Literal['auto'] | ColorType | dict[ElemStr, ColorType]" = 'auto',
    symbol_pos: 'tuple[float, float] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
    anno_kwargs: 'dict[str, Any] | None' = None,
    value_show_mode: "Literal['value', 'fraction', 'percent', 'off']" = 'value',
    value_pos: 'tuple[float, float] | None' = None,
    value_fmt: 'str' = 'auto',
    value_kwargs: 'dict[str, Any] | None' = None,
    show_cbar: 'bool' = True,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.05),
    cbar_range: 'tuple[float | None, float | None]' = (None, None),
    cbar_label_fmt: 'str' = 'auto',
    cbar_title: 'str' = 'Element Count',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_kwargs: 'dict[str, Any] | None' = None,
    return_type: "Literal['figure', 'axes']" = 'axes',
    colorscale: 'str | None' = None,
    heat_mode: "Literal['value', 'fraction', 'percent'] | None" = None,
    show_values: 'bool | None' = None,
    fmt: 'str | None' = None,
    cbar_fmt: 'str | None' = None,
    show_scale: 'bool | None' = None
) → Axes

Plot a heatmap across the periodic table.

Args: data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

Returns:


function ptable_heatmap_ratio

ptable_heatmap_ratio(
    values_num: 'ElemValues',
    values_denom: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    normalize: 'bool' = False,
    infty_color: 'ColorType' = 'lightskyblue',
    zero_color: 'ColorType' = 'lightgrey',
    zero_tol: 'float' = 1e-06,
    zero_symbol: 'str' = 'ZERO',
    not_in_numerator: 'tuple[str, str] | None' = ('lightgray', 'gray: not in 1st list'),
    not_in_denominator: 'tuple[str, str] | None' = ('lightskyblue', 'blue: not in 2nd list'),
    not_in_either: 'tuple[str, str] | None' = ('white', 'white: not in either'),
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
    anno_kwargs: 'dict[str, Any] | None' = None,
    cbar_title: 'str' = 'Element Ratio',
    **kwargs: 'Any'
) → figure

Display the ratio of two maps from element symbols to heat values or of two sets of compositions.

Args:

--- Ratio heatmap specific --- 

Returns:


function ptable_heatmap_splits

ptable_heatmap_splits(
    data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
    start_angle: 'float' = 135,
    colormap: 'str | Colormap' = 'viridis',
    on_empty: "Literal['hide', 'show']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.5),
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
    anno_kwargs: 'dict[str, Any] | None' = None,
    cbar_title: 'str' = 'Values',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot evenly-split heatmaps, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

Notes:

Default figsize is set to (0.75 * n_groups, 0.75 * n_periods).

Returns:


function ptable_hists

ptable_hists(
    data: 'DataFrame | Series | dict[ElemStr, list[float]]',
    bins: 'int' = 20,
    x_range: 'tuple[float | None, float | None] | None' = None,
    log: 'bool' = False,
    colormap: 'str | Colormap | None' = 'viridis',
    on_empty: "Literal['show', 'hide']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    child_kwargs: 'dict[str, Any] | None' = None,
    cbar_axis: "Literal['x', 'y']" = 'x',
    cbar_title: 'str' = 'Values',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_kwargs: 'dict[str, Any] | None' = None,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_kwargs: 'dict[str, Any] | None' = None,
    color_elem_strategy: "Literal['symbol', 'background', 'both', 'off']" = 'background',
    elem_type_colors: 'dict[str, str] | None' = None,
    add_elem_type_legend: 'bool' = False,
    elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot histograms for each element laid out in a periodic table.

Args:

Returns:


function ptable_scatters

ptable_scatters(
    data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
    colormap: 'str | Colormap | None' = None,
    on_empty: "Literal['hide', 'show']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    child_kwargs: 'dict[str, Any] | None' = None,
    cbar_title: 'str' = 'Values',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_kwargs: 'dict[str, Any] | None' = None,
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_kwargs: 'dict[str, Any] | None' = None,
    color_elem_strategy: "Literal['symbol', 'background', 'both', 'off']" = 'background',
    elem_type_colors: 'dict[str, str] | None' = None,
    add_elem_type_legend: 'bool' = False,
    elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure

Make scatter plots for each element, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,


function ptable_lines

ptable_lines(
    data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
    on_empty: "Literal['hide', 'show']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    child_kwargs: 'dict[str, Any] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_kwargs: 'dict[str, Any] | None' = None,
    color_elem_strategy: "Literal['symbol', 'background', 'both', 'off']" = 'background',
    elem_type_colors: 'dict[str, str] | None' = None,
    add_elem_type_legend: 'bool' = False,
    elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure

Make line plots for each element, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

module ptable.ptable_plotly

Periodic table plots powered by plotly.

Global Variables


function ptable_heatmap_plotly

ptable_heatmap_plotly(
    values: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    colorscale: 'str | Sequence[str] | Sequence[tuple[float, str]]' = 'viridis',
    show_scale: 'bool' = True,
    show_values: 'bool' = True,
    heat_mode: "Literal['value', 'fraction', 'percent']" = 'value',
    fmt: 'str | None' = None,
    hover_props: 'Sequence[str] | dict[str, str] | None' = None,
    hover_data: 'dict[str, str | int | float] | Series | None' = None,
    font_colors: 'Sequence[str]' = (),
    gap: 'float' = 5,
    font_size: 'int | None' = None,
    bg_color: 'str | None' = None,
    nan_color: 'str' = '#eff',
    color_bar: 'dict[str, Any] | None' = None,
    cscale_range: 'tuple[float | None, float | None]' = (None, None),
    exclude_elements: 'Sequence[str]' = (),
    log: 'bool' = False,
    fill_value: 'float | None' = None,
    element_symbol_map: 'dict[str, str] | None' = None,
    label_map: 'dict[str, str] | Callable[[str], str] | Literal[False] | None' = None,
    border: 'dict[str, Any] | None | Literal[False]' = None,
    scale: 'float' = 1.0,
    **kwargs: 'Any'
) → Figure

Create a Plotly figure with an interactive heatmap of the periodic table. Supports hover tooltips with custom data (via the hover_data kwarg) or atomic reference data like electronegativity, atomic_radius, etc. (via the hover_props kwarg).

Args:

Returns:

module rainclouds

Raincloud plots.

Global Variables


function rainclouds

rainclouds(
    data: 'dict[str, Sequence[float] | tuple[DataFrame, str]]',
    orientation: "Literal['h', 'v']" = 'h',
    alpha: 'float' = 0.7,
    width_viol: 'float' = 0.3,
    width_box: 'float' = 0.05,
    jitter: 'float' = 0.01,
    point_size: 'float' = 3,
    bw: 'float' = 0.2,
    cut: 'float' = 0.0,
    scale: "Literal['area', 'count', 'width']" = 'area',
    rain_offset: 'float' = -0.25,
    offset: 'float | None' = None,
    hover_data: 'Sequence[str] | dict[str, Sequence[str]] | None' = None,
    show_violin: 'bool' = True,
    show_box: 'bool' = True,
    show_points: 'bool' = True
) → Figure

Create a raincloud plot for multiple datasets using Plotly.

This plot type was proposed in https://wellcomeopenresearch.org/articles/4-63/v2. It is a vertical stack of:

  1. violin plot (the cloud)
  2. box plot (the umbrella)
  3. strip plot (the rain)

Args:

Returns:

module rdf.helpers

Helper functions for radial distribution functions (RDFs) of pymatgen structures.

Global Variables


function calculate_rdf

calculate_rdf(
    structure: 'Structure',
    center_species: 'str | None' = None,
    neighbor_species: 'str | None' = None,
    cutoff: 'float' = 15,
    n_bins: 'int' = 75,
    pbc: 'PbcLike' = (True, True, True)
) → tuple[ndarray, ndarray]

Calculate the radial distribution function (RDF) for a given structure.

If center_species and neighbor_species are provided, calculates the partial RDF for the specified element pair. Otherwise, calculates the full RDF.
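
To illustrate what the full RDF measures, here is a minimal sketch for a single species in a periodic cubic box, using the minimum image convention. This is a simplification under stated assumptions (pymatgen structures can have arbitrary lattices), and the function name and arguments are hypothetical. A partial RDF would histogram only distances between sites of the two chosen species and normalize by that pair count instead.

```python
from __future__ import annotations

import numpy as np


def rdf_cubic_box(
    frac_coords, box_length: float, cutoff: float, n_bins: int = 75
) -> tuple[np.ndarray, np.ndarray]:
    """Full RDF g(r) of one species in a periodic cubic box (minimum image convention)."""
    if cutoff > box_length / 2:
        raise ValueError("minimum image convention requires cutoff <= box_length / 2")
    cart = np.asarray(frac_coords, dtype=float) * box_length
    n_atoms = len(cart)
    diffs = cart[:, None, :] - cart[None, :, :]
    diffs -= box_length * np.round(diffs / box_length)  # wrap to the nearest periodic image
    dists = np.linalg.norm(diffs, axis=-1)[np.triu_indices(n_atoms, k=1)]  # unique pairs
    counts, edges = np.histogram(dists, bins=n_bins, range=(0, cutoff))
    radii = (edges[:-1] + edges[1:]) / 2
    shell_volumes = 4 / 3 * np.pi * (edges[1:] ** 3 - edges[:-1] ** 3)
    # Pair counts expected per shell for uniformly random positions
    ideal_counts = n_atoms * (n_atoms - 1) / 2 * shell_volumes / box_length**3
    return radii, counts / ideal_counts
```

For uniformly random positions g(r) fluctuates around 1, while a crystal shows sharp peaks at its neighbor shell distances.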

Args:

Returns:

module rdf

Radial distribution function plots (RDFs).

module rdf.plotly

Radial distribution functions (RDFs) of pymatgen structures using plotly.

The main function, element_pair_rdfs, generates a plotly figure with facets for each pair of elements in the given structure. It supports customizing the cutoff distance, bin size, element pairs to plot and the reference line.

Example usage:

    structure = Structure(...)  # Create or load a pymatgen Structure
    fig = element_pair_rdfs(structure, bin_size=0.1)
    fig.show()

Global Variables


function element_pair_rdfs

element_pair_rdfs(
    structures: 'Structure | Sequence[Structure] | dict[str, Structure]',
    cutoff: 'float' = 15,
    n_bins: 'int' = 75,
    bin_size: 'float | None' = None,
    element_pairs: 'list[tuple[str, str]] | None' = None,
    reference_line: 'dict[str, Any] | None' = None,
    colors: 'Sequence[str] | None' = None,
    line_styles: 'Sequence[str] | None' = None,
    subplot_kwargs: 'dict[str, Any] | None' = None
) → Figure

Generate a plotly figure of pairwise radial distribution functions (RDFs) for all (or a subset of) element pairs in one or multiple structures.

Args:

Returns:

Raises:


function full_rdf

full_rdf(
    structures: 'Structure | Sequence[Structure] | dict[str, Structure]',
    cutoff: 'float' = 15,
    n_bins: 'int' = 75,
    bin_size: 'float | None' = None,
    reference_line: 'dict[str, Any] | None' = None,
    colors: 'Sequence[str] | None' = None,
    line_styles: 'Sequence[str] | None' = None
) → Figure

Generate a plotly figure of full radial distribution functions (RDFs) for one or multiple structures.

Args:

Returns:

Raises:

module relevance

Plots for evaluating classifier performance.

Global Variables


function roc_curve

roc_curve(
    targets: 'ArrayLike | str',
    proba_pos: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None
) → tuple[float, Axes]

Plot the receiver operating characteristic curve of a binary classifier given target labels and predicted probabilities for the positive class.
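This sketch shows the quantities such a plot is built from: one (false positive rate, true positive rate) point per distinct score threshold, with the area under the curve computed by the trapezoidal rule. The returned float is likely this ROC AUC, but treat that as an assumption. roc_points and trapezoid_auc are hypothetical helpers, not pymatviz functions.

```python
from collections.abc import Sequence


def roc_points(targets: Sequence[int], scores: Sequence[float]) -> tuple[list[float], list[float]]:
    """Return (fpr, tpr) lists, one point per distinct score threshold."""
    n_pos = sum(1 for t in targets if t)
    n_neg = len(targets) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both positive and negative targets")
    ranked = sorted(zip(scores, targets), key=lambda pair: pair[0], reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for idx, (score, target) in enumerate(ranked):
        if target:
            tp += 1
        else:
            fp += 1
        # emit a point only after all samples tied at this score are counted
        if idx == len(ranked) - 1 or ranked[idx + 1][0] != score:
            fpr.append(fp / n_neg)
            tpr.append(tp / n_pos)
    return fpr, tpr


def trapezoid_auc(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Area under a piecewise-linear curve via the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2 for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]))


fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
roc_auc = trapezoid_auc(fpr, tpr)
```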

Args:

Returns:


function precision_recall_curve

precision_recall_curve(
    targets: 'ArrayLike | str',
    proba_pos: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None
) → tuple[float, Axes]

Plot the precision recall curve of a binary classifier.
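A precision-recall curve traces precision = TP / (TP + FP) against recall = TP / (number of positives) as the decision threshold sweeps from the highest predicted probability down. The sketch below computes those points. What the float returned by precision_recall_curve represents is not specified here, so the sketch does not try to reproduce it. pr_points is a hypothetical helper.

```python
from collections.abc import Sequence


def pr_points(targets: Sequence[int], scores: Sequence[float]) -> list[tuple[float, float, float]]:
    """Return (threshold, precision, recall) for each distinct score, high to low."""
    n_pos = sum(1 for t in targets if t)
    if n_pos == 0:
        raise ValueError("need at least one positive target")
    ranked = sorted(zip(scores, targets), key=lambda pair: pair[0], reverse=True)
    points = []
    tp = fp = 0
    for idx, (score, target) in enumerate(ranked):
        tp += bool(target)
        fp += not target
        if idx == len(ranked) - 1 or ranked[idx + 1][0] != score:
            # classify every sample with score >= this threshold as positive
            points.append((score, tp / (tp + fp), tp / n_pos))
    return points


points = pr_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```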

Args:

Returns:

module sankey

Sankey diagram for comparing distributions in two dataframe columns.

Global Variables


function sankey_from_2_df_cols

sankey_from_2_df_cols(
    df: 'DataFrame',
    cols: 'list[str]',
    labels_with_counts: "bool | Literal['percent']" = True,
    **kwargs: 'Any'
) → Figure

Plot two columns of a dataframe as a Plotly Sankey diagram.
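At its core, a two-column Sankey diagram is a tally of how often each (left value, right value) combination occurs. The sketch below produces node labels plus source/target/value lists, which matches the shape of the node and link arrays plotly's go.Sankey accepts. It is a plain-Python illustration of the idea, not the actual sankey_from_2_df_cols implementation, and sankey_links is a hypothetical name.

```python
from collections import Counter
from collections.abc import Sequence


def sankey_links(
    source_col: Sequence[str], target_col: Sequence[str]
) -> tuple[list[str], list[int], list[int], list[int]]:
    """Tally flows between two categorical columns.

    Returns (node_labels, sources, targets, values). sources and targets index
    into node_labels; left and right nodes stay distinct even if names repeat.
    """
    left = sorted(set(source_col))
    right = sorted(set(target_col))
    labels = left + right
    flows = Counter(zip(source_col, target_col))
    sources, targets, values = [], [], []
    for (src, tgt), count in sorted(flows.items()):
        sources.append(left.index(src))
        targets.append(len(left) + right.index(tgt))  # right nodes follow left ones
        values.append(count)
    return labels, sources, targets, values


labels, sources, targets, values = sankey_links(["a", "a", "b"], ["x", "y", "y"])
```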

Args:

Raises:

Returns:

module scatter

Parity, residual and density plots.

Global Variables


function density_scatter

density_scatter(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    log_density: 'bool' = True,
    hist_density_kwargs: 'dict[str, Any] | None' = None,
    color_bar: 'bool | dict[str, Any]' = True,
    xlabel: 'str | None' = None,
    ylabel: 'str | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    stats: 'bool | dict[str, Any]' = True,
    **kwargs: 'Any'
) → Axes

Scatter plot colored by density using matplotlib backend.

Args:

Returns:


function density_scatter_plotly

density_scatter_plotly(
    df: 'DataFrame',
    x: 'str',
    y: 'str',
    density: "Literal['kde', 'empirical'] | None" = None,
    log_density: 'bool | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any] | None' = None,
    stats: 'bool | dict[str, Any]' = True,
    n_bins: 'int | None | Literal[False]' = None,
    bin_counts_col: 'str | None' = None,
    facet_col: 'str | None' = None,
    **kwargs: 'Any'
) → Figure

Scatter plot colored by density using plotly backend.

This function uses the binning implemented in bin_df_cols() to reduce the number of points plotted, which makes it possible to plot millions of data points and reduces the file size of interactive plots. Outlier points are plotted as-is, but overlapping points (the tolerance for overlap is set by n_bins) are merged into a single point, with a new column bin_counts_col counting the number of points in that bin.
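The sketch below illustrates the merging idea. Each point is assigned to a cell of an n_bins x n_bins grid, and each occupied cell is replaced by the mean position of its points plus a count. An isolated point therefore keeps its original coordinates with a count of 1. This is a simplified stand-in, not bin_df_cols() itself, and bin_points_2d is a hypothetical name.

```python
from collections import defaultdict
from collections.abc import Sequence


def bin_points_2d(
    xs: Sequence[float], ys: Sequence[float], n_bins: int = 100
) -> list[tuple[float, float, int]]:
    """Merge points that fall into the same cell of an n_bins x n_bins grid.

    Returns one (x_mean, y_mean, count) tuple per occupied cell.
    """
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    x_span = (x_max - x_min) or 1.0  # avoid division by zero for constant data
    y_span = (y_max - y_min) or 1.0
    cells: dict[tuple[int, int], list[tuple[float, float]]] = defaultdict(list)
    for x, y in zip(xs, ys):
        # clamp so the maximum value lands in the last cell, not one past it
        ix = min(int((x - x_min) / x_span * n_bins), n_bins - 1)
        iy = min(int((y - y_min) / y_span * n_bins), n_bins - 1)
        cells[ix, iy].append((x, y))
    return [
        (sum(p[0] for p in pts) / len(pts), sum(p[1] for p in pts) / len(pts), len(pts))
        for pts in cells.values()
    ]


binned = bin_points_2d([0.0, 0.001, 0.002, 10.0], [0.0, 0.001, 0.002, 10.0], n_bins=10)
```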

Args:

Returns:


function scatter_with_err_bar

scatter_with_err_bar(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    xerr: 'ArrayLike | None' = None,
    yerr: 'ArrayLike | None' = None,
    ax: 'Axes | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    xlabel: 'str' = 'Actual',
    ylabel: 'str' = 'Predicted',
    title: 'str | None' = None,
    **kwargs: 'Any'
) → Axes

Scatter plot with optional x- and/or y-error bars. Useful when passing model uncertainties as yerr=y_std to check whether uncertainty correlates with error, i.e. whether points farther from the parity line have larger uncertainty.

Args:

Returns:


function density_hexbin

density_hexbin(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    weights: 'ArrayLike | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    xlabel: 'str' = 'Actual',
    ylabel: 'str' = 'Predicted',
    cbar_label: 'str | None' = 'Density',
    cbar_coords: 'tuple[float, float, float, float]' = (0.95, 0.03, 0.03, 0.7),
    **kwargs: 'Any'
) → Axes

Hexagonal-grid scatter plot colored by point density, or by a third dimension passed as weights.

Args:

Returns:


function density_scatter_with_hist

density_scatter_with_hist(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Scatter plot colored (and optionally sorted) by density with histograms along each dimension.


function density_hexbin_with_hist

density_hexbin_with_hist(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    **kwargs: 'Any'
) → Axes

Hexagonal-grid scatter plot colored by density or by a third dimension passed as color_by, with histograms along each dimension.


function residual_vs_actual

residual_vs_actual(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    xlabel: 'str' = 'Actual value',
    ylabel: 'str' = 'Residual ($y_\\mathrm{true} - y_\\mathrm{pred}$)',
    **kwargs: 'Any'
) → Axes

Plot targets on the x-axis vs residuals (y_err = y_true - y_pred) on the y-axis.

Args:

Returns:

module structure_viz.helpers

Helper functions for 2D and 3D plots of pymatgen structures with plotly.

Global Variables


function get_image_sites

get_image_sites(
    site: 'PeriodicSite',
    lattice: 'Lattice',
    tol: 'float' = 0.02
) → ndarray

Get images for a given site in a lattice.

Images are sites that are integer translations of the given site that are within a tolerance of the unit cell edges.
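The sketch below illustrates the idea in fractional coordinates, where it is simplest. A coordinate within tol of 0 reappears on the opposite face after a +1 shift, and one within tol of 1 after a -1 shift. A corner site therefore gets 7 images and an edge site gets 3. The real function takes a PeriodicSite and a Lattice and returns an ndarray, presumably Cartesian, which would require multiplying by the lattice matrix. image_frac_coords is hypothetical and assumes coordinates already wrapped into [0, 1).

```python
from itertools import product


def image_frac_coords(
    frac_coords: tuple[float, float, float], tol: float = 0.02
) -> list[tuple[float, ...]]:
    """Periodic images of a site that lie within tol of the unit cell boundary."""
    shifts_per_axis = []
    for coord in frac_coords:
        shifts = [0]
        if abs(coord) < tol:
            shifts.append(1)  # a site on the 0 face reappears on the 1 face
        if abs(coord - 1) < tol:
            shifts.append(-1)  # and vice versa
        shifts_per_axis.append(shifts)
    images = []
    for shift in product(*shifts_per_axis):
        if shift != (0, 0, 0):  # skip the original site itself
            images.append(tuple(c + s for c, s in zip(frac_coords, shift)))
    return images


corner_images = image_frac_coords((0.0, 0.0, 0.0))
```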

Args:

Returns:


function unit_cell_to_lines

unit_cell_to_lines(cell: 'ArrayLike') → tuple[ArrayLike, ArrayLike, ArrayLike]

Convert lattice vectors to plot lines.

Args:

Returns: tuple[np.array, np.array, np.array]:
- Lines
- z-indices that sort plot elements into out-of-plane layers
- lines used to plot the unit cell


function get_elem_colors

get_elem_colors(
    elem_colors: 'ElemColorScheme | dict[str, str]'
) → dict[str, str]

Get element colors based on the provided scheme or custom dictionary.


function get_atomic_radii

get_atomic_radii(
    atomic_radii: 'float | dict[str, float] | None'
) → dict[str, float]

Get atomic radii based on the provided input.


function generate_site_label

generate_site_label(
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]",
    site_idx: 'int',
    majority_species: 'Species'
) → str

Generate a label for a site based on the provided labeling scheme.


function get_subplot_title

get_subplot_title(
    struct_i: 'Structure',
    struct_key: 'Any',
    idx: 'int',
    subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None'
) → dict[str, Any]

Generate a subplot title based on the provided function or default logic.


function get_site_hover_text

get_site_hover_text(
    site: 'PeriodicSite',
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]',
    majority_species: 'Species'
) → str

Generate hover text for a site based on the hover template.


function draw_site

draw_site(
    fig: 'Figure',
    site: 'PeriodicSite',
    coords: 'ndarray',
    site_idx: 'int',
    site_labels: 'Any',
    _elem_colors: 'dict[str, str]',
    _atomic_radii: 'dict[str, float]',
    atom_size: 'float',
    scale: 'float',
    site_kwargs: 'dict[str, Any]',
    is_image: 'bool' = False,
    is_3d: 'bool' = False,
    row: 'int | None' = None,
    col: 'int | None' = None,
    scene: 'str | None' = None,
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
    **kwargs: 'Any'
) → None

Add a site (regular or image) to the plot.


function get_structures

get_structures(
    struct: 'Structure | Sequence[Structure] | Series | dict[Any, Structure]'
) → dict[Any, Structure]

Convert various input types to a dictionary of structures.


function draw_unit_cell

draw_unit_cell(
    fig: 'Figure',
    structure: 'Structure',
    unit_cell_kwargs: 'dict[str, Any]',
    is_3d: 'bool' = True,
    row: 'int | None' = None,
    col: 'int | None' = None,
    scene: 'str | None' = None
) → Figure

Draw the unit cell of a structure in a 2D or 3D Plotly figure.


function draw_vector

draw_vector(
    fig: 'Figure',
    start: 'ndarray',
    vector: 'ndarray',
    is_3d: 'bool' = False,
    arrow_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → None

Add an arrow to represent a vector quantity on a Plotly figure.

This function adds an arrow to a 2D or 3D Plotly figure to represent a vector quantity. In 3D, it uses a cone for the arrowhead and a line for the shaft. In 2D, it uses a scatter plot with an arrow marker.

Args:

Note:

For 3D arrows, this function adds two traces to the figure: a cone for the arrowhead and a line for the shaft. For 2D arrows, it adds a single scatter trace with an arrow marker.


function get_first_matching_site_prop

get_first_matching_site_prop(
    structures: 'Sequence[Structure]',
    prop_keys: 'Sequence[str]',
    warn_if_none: 'bool' = True,
    filter_callback: 'Callable[[str, Any], bool] | None' = None
) → str | None

Find the first property key that exists in any of the passed structures' properties or site properties. Will look in site.properties first, then structure.properties.
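The search order can be sketched with plain dicts standing in for structures. Each structure is modeled as {"properties": {...}, "sites": [{"properties": {...}}]}, which is a hypothetical layout, not pymatgen's API. The sketch also assumes that keys are tried in the order given and that site properties across all structures are checked before falling back to structure-level properties. first_matching_prop is a hypothetical name.

```python
from __future__ import annotations

import warnings
from collections.abc import Callable, Sequence
from typing import Any


def first_matching_prop(
    structures: Sequence[dict[str, Any]],
    prop_keys: Sequence[str],
    warn_if_none: bool = True,
    filter_callback: Callable[[str, Any], bool] | None = None,
) -> str | None:
    """Return the first key in prop_keys found in site or structure properties."""

    def accepted(key: str, value: Any) -> bool:
        return filter_callback is None or filter_callback(key, value)

    # pass 1: site-level properties
    for key in prop_keys:
        for struct in structures:
            for site in struct.get("sites", []):
                site_props = site.get("properties", {})
                if key in site_props and accepted(key, site_props[key]):
                    return key
    # pass 2: structure-level properties
    for key in prop_keys:
        for struct in structures:
            struct_props = struct.get("properties", {})
            if key in struct_props and accepted(key, struct_props[key]):
                return key
    if warn_if_none:
        warnings.warn(f"none of {list(prop_keys)} found in any properties", stacklevel=2)
    return None


structs = [{"properties": {"energy": -1.0, "magmom": 9.0}, "sites": [{"properties": {"magmom": 0.5}}]}]
```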

Args:

Returns:


function draw_bonds

draw_bonds(
    fig: 'Figure',
    structure: 'Structure',
    nn: 'NearNeighbors',
    is_3d: 'bool' = True,
    bond_kwargs: 'dict[str, Any] | None' = None,
    row: 'int | None' = None,
    col: 'int | None' = None,
    scene: 'str | None' = None,
    visible_image_atoms: 'set[tuple[float, float, float]] | None' = None
) → None

Draw bonds between atoms in the structure.

module structure_viz

2D and 3D plots of Structures.

module structure_viz.mpl

2D plots of pymatgen structures with matplotlib.

structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib.
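The rotation strings used by structure_2d (e.g. "10x,8y,3z") encode a sequence of per-axis rotations in degrees, applied left to right, so order matters. The sketch below composes such a string into a 3x3 matrix. It assumes row vectors are rotated as v @ R, and the sign convention is an assumption modeled loosely on ASE's approach. It is not pymatviz's get_rot_matrix, and all names are hypothetical.

```python
import math

Matrix = list[list[float]]


def matmul(m1: Matrix, m2: Matrix) -> Matrix:
    return [[sum(m1[r][k] * m2[k][c] for k in range(3)) for c in range(3)] for r in range(3)]


def rot_matrix_from_str(rotation: str) -> Matrix:
    """Compose per-axis rotations like "10x,8y,3z" (degrees) into one matrix."""
    rot = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    for token in filter(None, rotation.replace(" ", "").split(",")):
        axis, angle = token[-1], math.radians(float(token[:-1]))
        c, s = math.cos(angle), math.sin(angle)
        step = {
            "x": [[1, 0, 0], [0, c, s], [0, -s, c]],
            "y": [[c, 0, -s], [0, 1, 0], [s, 0, c]],
            "z": [[c, s, 0], [-s, c, 0], [0, 0, 1]],
        }[axis]  # unknown axis letters raise KeyError
        rot = matmul(rot, step)
    return rot


def rotate_point(point: list[float], rot: Matrix) -> list[float]:
    """Rotate a row vector: point @ rot."""
    return [sum(point[k] * rot[k][c] for k in range(3)) for c in range(3)]


x_rotated = rotate_point([1.0, 0.0, 0.0], rot_matrix_from_str("90z"))
```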

Global Variables


function structure_2d

structure_2d(
    struct: 'Structure | Sequence[Structure]',
    ax: 'Axes | None' = None,
    rotation: 'str' = '10x,8y,3z',
    atomic_radii: 'float | dict[str, float] | None' = None,
    elem_colors: 'ElemColorScheme | dict[str, str | ColorType]' = Jmol,
    scale: 'float' = 1,
    show_unit_cell: 'bool' = True,
    show_bonds: 'bool | NearNeighbors' = False,
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
    label_kwargs: 'dict[str, Any] | None' = None,
    bond_kwargs: 'dict[str, Any] | None' = None,
    standardize_struct: 'bool | None' = None,
    axis: 'bool | str' = 'off',
    n_cols: 'int' = 4,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    subplot_title: 'Callable[[Structure, str | int], str] | None' = None
) → Axes | tuple[Figure, ndarray[Axes]]

Plot pymatgen structures in 2D with matplotlib.

structure_2d is not deprecated, but structure_(2d|3d)_plotly() have more features and are recommended replacements.

Inspired by ASE's ase.visualize.plot.plot_atoms() (https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib). pymatviz aims to give similar output to ASE but supports disordered structures and avoids the conversion hassle of AseAtomsAdaptor.get_atoms(pmg_struct).

For example, these two snippets should give very similar output:

from pymatgen.ext.matproj import MPRester

mp_19017 = MPRester().get_structure_by_material_id("mp-19017")

# ASE
from ase.visualize.plot import plot_atoms
from pymatgen.io.ase import AseAtomsAdaptor

plot_atoms(AseAtomsAdaptor().get_atoms(mp_19017), rotation="10x,8y,3z", radii=0.5)

# pymatviz
from pymatviz import structure_2d

structure_2d(mp_19017)

Multiple structures in single figure example:

from pymatgen.ext.matproj import MPRester
from pymatviz import structure_2d

structures = {
     (mp_id := f"mp-{idx}"): MPRester().get_structure_by_material_id(mp_id)
     for idx in range(1, 5)
}
structure_2d(structures)

Args:

Raises:

Returns:

module structure_viz.plotly

Create interactive hoverable 2D and 3D plots of pymatgen structures with plotly.

Global Variables


function structure_2d_plotly

structure_2d_plotly(
    struct: 'Structure | Sequence[Structure]',
    rotation: 'str' = '10x,8y,3z',
    atomic_radii: 'float | dict[str, float] | None' = None,
    atom_size: 'float' = 40,
    elem_colors: 'ElemColorScheme | dict[str, str]' = Jmol,
    scale: 'float' = 1,
    show_unit_cell: 'bool | dict[str, Any]' = True,
    show_sites: 'bool | dict[str, Any]' = True,
    show_image_sites: 'bool | dict[str, Any]' = True,
    show_bonds: 'bool | NearNeighbors' = False,
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
    standardize_struct: 'bool | None' = None,
    n_cols: 'int' = 3,
    subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None' = None,
    show_site_vectors: 'str | Sequence[str]' = ('force', 'magmom'),
    vector_kwargs: 'dict[str, dict[str, Any]] | None' = None,
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
    bond_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot pymatgen structures in 2D with Plotly.

Args:

Returns:


function structure_3d_plotly

structure_3d_plotly(
    struct: 'Structure | Sequence[Structure]',
    atomic_radii: 'float | dict[str, float] | None' = None,
    atom_size: 'float' = 20,
    elem_colors: 'ElemColorScheme | dict[str, str]' = Jmol,
    scale: 'float' = 1,
    show_unit_cell: 'bool | dict[str, Any]' = True,
    show_sites: 'bool | dict[str, Any]' = True,
    show_image_sites: 'bool' = True,
    show_bonds: 'bool | NearNeighbors' = False,
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
    standardize_struct: 'bool | None' = None,
    n_cols: 'int' = 3,
    subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None | Literal[False]' = None,
    show_site_vectors: 'str | Sequence[str]' = ('force', 'magmom'),
    vector_kwargs: 'dict[str, dict[str, Any]] | None' = None,
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
    bond_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot pymatgen structures in 3D with Plotly.

Args:

Returns:

module sunburst

Hierarchical multi-level pie charts (i.e. sunbursts).

E.g. for crystal symmetry distributions.

Global Variables


function spacegroup_sunburst

spacegroup_sunburst(
    data: 'Sequence[int | str] | Series',
    show_counts: "Literal['value', 'percent', False]" = False,
    **kwargs: 'Any'
) → Figure

Generate a sunburst plot with crystal systems as the inner ring for a list of international space group numbers.

Hint: To hide very small labels, set a uniformtext minsize and mode='hide':

fig.update_layout(uniformtext=dict(minsize=9, mode="hide"))

Args:

Returns:

module templates

Define custom pymatviz templates (default styles) for plotly and matplotlib.

Global Variables


function set_plotly_template

set_plotly_template(
    template: "Literal['pymatviz_white', 'pymatviz_dark'] | str | Template"
) → None

Set the default plotly express and graph objects template.

Args:

Raises:

module uncertainty

Visualizations for assessing the quality of model uncertainty estimates.

Global Variables


function qq_gaussian

qq_gaussian(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    identity_line: 'bool | dict[str, Any]' = True
) → Axes

Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array) or multiple (passed as dict) sets of uncertainty estimates for a single pair of ground truth targets y_true and model predictions y_pred.

Overconfidence relative to a Gaussian distribution is visualized as shaded areas below the parity line, underconfidence (oversized uncertainties) as shaded areas above the parity line.

The measure of calibration is how well the uncertainty percentiles conform to those of a normal distribution.
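One common way to build such a calibration curve is shown below. First standardize each error by its predicted uncertainty, z = |y_true - y_pred| / y_std. Then, for each expected proportion p, count how many z fall inside the central interval of a standard normal that holds p of its mass. With expected on the x-axis and observed on the y-axis, points below the parity line mean fewer errors than expected fall within their uncertainty, i.e. overconfidence. This is a hedged sketch of the concept, not the exact computation in qq_gaussian, and gaussian_calibration_curve is a hypothetical name.

```python
from collections.abc import Sequence
from statistics import NormalDist


def gaussian_calibration_curve(
    y_true: Sequence[float],
    y_pred: Sequence[float],
    y_std: Sequence[float],
    n_points: int = 11,
) -> tuple[list[float], list[float]]:
    """Expected vs observed fraction of errors inside central Gaussian intervals."""
    z_abs = [abs(t - p) / s for t, p, s in zip(y_true, y_pred, y_std)]
    expected, observed = [], []
    for idx in range(n_points):
        prop = idx / (n_points - 1)
        # half-width (in std units) of the central N(0, 1) interval holding `prop`
        half_width = NormalDist().inv_cdf(0.5 + prop / 2) if prop < 1 else float("inf")
        expected.append(prop)
        observed.append(sum(z <= half_width for z in z_abs) / len(z_abs))
    return expected, observed


# errors far larger than the predicted std: a strongly overconfident model
expected, observed = gaussian_calibration_curve([0.0, 0.0], [100.0, -100.0], [1.0, 1.0])
```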

Inspired by https://git.io/JufOz. Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot

Args:

Returns:


function get_err_decay

get_err_decay(
    y_true: 'ArrayLike',
    y_pred: 'ArrayLike',
    n_rand: 'int' = 100
) → tuple[ArrayLike, ArrayLike]

Calculate the model's error curve as samples are excluded from the calculation based on their absolute error.

Use in combination with get_std_decay to see what the error drop curve would look like if model error and uncertainty were perfectly rank-correlated.

Args:

Returns:


function get_std_decay

get_std_decay(
    y_true: 'ArrayLike',
    y_pred: 'ArrayLike',
    y_std: 'ArrayLike'
) → ArrayLike

Calculate the drop in model error as samples are excluded from the calculation based on the model's uncertainty.

For models that estimate their own uncertainty well, meaning predictions with larger error are associated with larger uncertainty, the error curve should fall off sharply at first as the highest-error points are discarded, and slowly towards the end where only small-error samples with little uncertainty remain.

Note that even perfect model uncertainties would not make this error drop curve coincide exactly with the one returned by get_err_decay. In some cases the model makes an accurate prediction purely by chance: the error is small, yet a good uncertainty estimate would still be large. Such a sample is excluded at different x-axis locations in the two curves, which makes the get_std_decay curve lie higher.
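Both curves can be sketched with one helper. Sort the samples by a key, repeatedly drop the sample with the largest key, and record the mean absolute error of what remains. Using the absolute error itself as the key gives the oracle curve (the idea behind get_err_decay), and using y_std gives the uncertainty-based curve (the idea behind get_std_decay). The random-exclusion baseline controlled by n_rand is omitted, and using mean absolute error is an assumption. decay_curve is a hypothetical name.

```python
from collections.abc import Sequence


def decay_curve(abs_errors: Sequence[float], sort_key: Sequence[float]) -> list[float]:
    """Mean absolute error of the samples left after dropping the k samples with
    the largest sort_key, for k = 0, 1, ..., n - 1."""
    order = sorted(range(len(abs_errors)), key=lambda i: sort_key[i])  # ascending key
    errors_sorted = [abs_errors[i] for i in order]
    curve = []
    for n_kept in range(len(errors_sorted), 0, -1):
        curve.append(sum(errors_sorted[:n_kept]) / n_kept)
    return curve


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.0, 4.25]
y_std = [0.4, 0.1, 0.8, 0.3]
abs_err = [abs(t - p) for t, p in zip(y_true, y_pred)]
err_decay = decay_curve(abs_err, abs_err)  # oracle: drop the largest errors first
std_decay = decay_curve(abs_err, y_std)  # drop the most uncertain predictions first
```

Here y_std happens to rank the samples exactly like their errors, so the two curves coincide. Any mismatch in ranking pushes std_decay above err_decay.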

Args:

Returns:


function error_decay_with_uncert

error_decay_with_uncert(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
    df: 'DataFrame | None' = None,
    n_rand: 'int' = 100,
    percentiles: 'bool' = True,
    ax: 'Axes | None' = None
) → Axes

Plot for assessing the quality of uncertainty estimates. If a model's uncertainty is well calibrated, i.e. strongly correlated with its error, removing the most uncertain predictions should make the mean error decay similarly to how it decays when removing the predictions of largest error.

Args:

Note: If you're not happy with the default y_max of 1.1 * rand_mean, where rand_mean is the mean error under random sample exclusion, use ax.set(ylim=[None, some_value * ax.get_ylim()[1]]).

Returns:

module utils

pymatviz utility functions.

Global Variables


function pretty_label

pretty_label(key: 'str', backend: 'Backend') → str

Map metric keys to their pretty labels.


function crystal_sys_from_spg_num

crystal_sys_from_spg_num(spg: 'float') → CrystalSystem

Get the crystal system for an international space group number.
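The 230 international space group numbers are ordered by crystal system, so the mapping reduces to a range lookup. The sketch returns plain lowercase strings, whereas the real function returns a CrystalSystem value. crystal_system_from_spg is a hypothetical name.

```python
def crystal_system_from_spg(spg: int) -> str:
    """Map an international space group number (1-230) to its crystal system."""
    if not 1 <= spg <= 230:
        raise ValueError(f"invalid space group number {spg}")
    # inclusive upper bound of each crystal system's space group range
    for upper, system in (
        (2, "triclinic"),
        (15, "monoclinic"),
        (74, "orthorhombic"),
        (142, "tetragonal"),
        (167, "trigonal"),
        (194, "hexagonal"),
        (230, "cubic"),
    ):
        if spg <= upper:
            return system
    raise AssertionError("unreachable")
```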


function df_to_arrays

df_to_arrays(
    df: 'DataFrame | None',
    *args: 'str | Sequence[str] | Sequence[ArrayLike]',
    strict: 'bool' = True
) → list[ArrayLike | dict[str, ArrayLike]]

If df is None, this is a no-op: args are returned as-is. If df is a dataframe, all following args are used as column names and the column data returned as arrays (after dropping rows with NaNs in any column).

Args:

Raises:

Returns:


function bin_df_cols

bin_df_cols(
    df_in: 'DataFrame',
    bin_by_cols: 'Sequence[str]',
    group_by_cols: 'Sequence[str]' = (),
    n_bins: 'int | Sequence[int]' = 100,
    bin_counts_col: 'str' = 'bin_counts',
    density_col: 'str' = '',
    verbose: 'bool' = True
) → DataFrame

Bin columns of a DataFrame.

Args:

Returns:


function patch_dict

patch_dict(
    dct: 'dict[Any, Any]',
    *args: 'Any',
    **kwargs: 'Any'
) → Generator[dict[Any, Any], None, None]

Context manager to temporarily patch the specified keys in a dictionary and restore it to its original state on context exit.

Useful e.g. for temporary plotly fig.layout mutations:

with patch_dict(fig.layout, showlegend=False):
    fig.write_image("plot.pdf")
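The core mechanism can be sketched for a plain dict. Remember the original values (with a sentinel for keys that did not exist), apply the patch, and restore everything in a finally block so the dict is restored even if the body raises. The sketch accepts keyword patches only, while the real patch_dict also takes *args. Plotly layout objects may behave differently from plain dicts. patch_dict_sketch is a hypothetical name.

```python
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

_MISSING = object()  # sentinel for keys absent before patching


@contextmanager
def patch_dict_sketch(dct: dict[Any, Any], **patches: Any) -> Generator[dict[Any, Any], None, None]:
    """Temporarily set keys in dct; restore old values or remove new keys on exit."""
    originals = {key: dct.get(key, _MISSING) for key in patches}
    dct.update(patches)
    try:
        yield dct
    finally:
        for key, value in originals.items():
            if value is _MISSING:
                dct.pop(key, None)  # key was added by the patch
            else:
                dct[key] = value


layout = {"showlegend": True}
with patch_dict_sketch(layout, showlegend=False, title="tmp") as patched:
    patched_snapshot = dict(patched)
```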

Args:

Yields:


function luminance

luminance(color: 'str | tuple[float, float, float]') → float

Compute the luminance of a color as in https://stackoverflow.com/a/596243.

Args:

Returns:


function pick_bw_for_contrast

pick_bw_for_contrast(
    color: 'tuple[float, float, float] | str',
    text_color_threshold: 'float' = 0.7
) → Literal['black', 'white']

Choose black or white text color for a given background color based on luminance.
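The linked Stack Overflow answer lists several luminance formulas. The sketch below uses the 0.299 / 0.587 / 0.114 weighting, which may or may not be the one pymatviz applies. It then picks black text above the threshold and white text otherwise. Inputs are RGB tuples in [0, 1], and converting color strings such as hex codes or names is left out. Both function names are hypothetical.

```python
def luminance_sketch(rgb: tuple[float, float, float]) -> float:
    """Perceived luminance of an RGB color with channels in [0, 1]."""
    red, green, blue = rgb
    return 0.299 * red + 0.587 * green + 0.114 * blue


def pick_bw_sketch(rgb: tuple[float, float, float], threshold: float = 0.7) -> str:
    """Black text on light backgrounds, white text on dark ones."""
    return "black" if luminance_sketch(rgb) > threshold else "white"
```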

Args:

Returns:


function si_fmt

si_fmt(
    val: 'float',
    fmt: 'str' = '.1f',
    sep: 'str' = '',
    binary: 'bool' = False,
    decimal_threshold: 'float' = 0.01
) → str

Convert large numbers into a human-readable format using SI prefixes.

Supports binary (1024) and metric (1000) mode.

https://nist.gov/pml/weights-and-measures/metric-si-prefixes
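The basic idea for large numbers is to divide by the base (1000, or 1024 in binary mode) until the magnitude drops below it, then append the matching prefix. The sketch below reuses the same prefix letters in binary mode and skips small numbers (the decimal_threshold behavior). Both are assumptions that may differ from si_fmt. si_fmt_sketch is a hypothetical name.

```python
def si_fmt_sketch(val: float, fmt: str = ".1f", sep: str = "", binary: bool = False) -> str:
    """Format a large number with an SI prefix, e.g. 1234 -> "1.2k"."""
    base = 1024 if binary else 1000
    prefixes = ["", "k", "M", "G", "T", "P", "E"]
    idx = 0
    while abs(val) >= base and idx < len(prefixes) - 1:
        val /= base
        idx += 1
    return f"{val:{fmt}}{sep}{prefixes[idx]}"
```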

Args:

Returns:


function get_cbar_label_formatter

get_cbar_label_formatter(
    cbar_label_fmt: 'str',
    values_fmt: 'str',
    values_show_mode: "Literal['value', 'fraction', 'percent', 'off']",
    sci_notation: 'bool',
    default_decimal_places: 'int' = 1
) → Formatter

Generate colorbar tick label formatter.

Works differently for different values_show_mode:
- "value"/"fraction" mode: use cbar_label_fmt (or values_fmt) as is.
- "percent" mode: get the number of decimal places to keep from the fmt string, for example 1 from ".1%".
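In "percent" mode, the decimal count can be read from the digits after the dot in a format spec, with default_decimal_places as the fallback. A value extracted this way could be passed as the decimals argument of matplotlib's ticker.PercentFormatter. The regex approach below is an assumption, not necessarily how pymatviz parses the string, and decimals_from_fmt is a hypothetical name.

```python
import re


def decimals_from_fmt(fmt: str, default: int = 1) -> int:
    """Number of decimal places in a format spec like ".1%" or ",.3f"."""
    match = re.search(r"\.(\d+)", fmt)
    return int(match.group(1)) if match else default
```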

Args:

Returns: PercentFormatter or FormatStrFormatter.


function html_tag

html_tag(
    text: 'str',
    tag: 'str' = 'span',
    style: 'str' = '',
    title: 'str' = ''
) → str

Wrap text in an HTML tag (a span by default) with custom style.

Style defaults to decreased font size and weight e.g. to display units in plotly labels and annotations.
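The output is an ordinary inline HTML element, which plotly renders in titles, axis labels and annotations. The sketch below builds one and omits empty attributes. Unlike html_tag, it has no default style and does no escaping, and html_tag_sketch is a hypothetical name.

```python
def html_tag_sketch(text: str, tag: str = "span", style: str = "", title: str = "") -> str:
    """Wrap text as <tag style="..." title="...">text</tag>, omitting empty attributes."""
    attrs = ""
    if style:
        attrs += f' style="{style}"'
    if title:
        attrs += f' title="{title}"'  # shown as a tooltip in browsers
    return f"<{tag}{attrs}>{text}</{tag}>"


# e.g. a de-emphasized unit next to an axis label
label = "Energy " + html_tag_sketch("(eV)", style="font-size: 0.8em; font-weight: lighter;")
```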

Args:

Returns:


function validate_fig

validate_fig(func: 'Callable[P, R]') → Callable[P, R]

Decorator to validate the type of the fig keyword argument of a function. fig MUST be passed as a keyword argument, not as a positional argument.


function annotate

annotate(text: 'str | Sequence[str]', fig: 'AxOrFig', **kwargs: 'Any') → AxOrFig

Annotate a matplotlib or plotly figure. Supports faceted plotly figures, where text can be a sequence with one entry per subplot; empty strings are skipped.

Args:

Returns:

Raises:


function get_fig_xy_range

get_fig_xy_range(
    fig: 'Figure | Figure | Axes',
    trace_idx: 'int' = 0
) → tuple[tuple[float, float], tuple[float, float]]

Get the x and y range of a plotly or matplotlib figure.

Args:

Returns:


function get_font_color

get_font_color(fig: 'AxOrFig') → str

Get the font color used in a Matplotlib figure/axes or a Plotly figure.

Args:

Returns:

Raises:


function normalize_to_dict

normalize_to_dict(
    inputs: 'T | Sequence[T] | dict[str, T]',
    cls: 'type[T]' = <class 'pymatgen.core.structure.Structure'>,
    key_gen: 'Callable[[T], str]' = <lambda>
) → dict[str, T]

Normalize input to a dictionary of objects.

Args:

Returns: A dictionary of objects with keys as object formulas or given keys.

Raises:


class ExperimentalWarning

Warning for experimental features.

module xrd

Module for plotting XRD patterns using plotly.

Global Variables


function format_hkl

format_hkl(hkl: 'tuple[int, int, int]', format_type: 'HklFormat') → str

Format hkl indices as a string.

Args:

Raises:


function xrd_pattern

xrd_pattern(
    patterns: 'PatternOrStruct | dict[str, PatternOrStruct | tuple[PatternOrStruct, dict[str, Any]]]',
    peak_width: 'float' = 0.5,
    annotate_peaks: 'float' = 5,
    hkl_format: 'HklFormat' = 'compact',
    show_angles: 'bool | None' = None,
    wavelength: 'float' = 1.54184,
    stack: "Literal['horizontal', 'vertical'] | None" = None,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    subtitle_kwargs: 'dict[str, Any] | None' = None
) → Figure

Create a plotly figure of XRD patterns from a pymatgen DiffractionPattern, from a pymatgen Structure, or a dictionary of either of them.
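The peak positions in such a pattern follow Bragg's law, λ = 2 d sin θ. The default wavelength of 1.54184 is presumably Cu Kα in Å. The sketch converts a d-spacing into a first-order 2θ angle and adds the cubic-lattice d-spacing formula d = a / sqrt(h² + k² + l²) for illustration. Both helpers are hypothetical and not part of pymatviz.

```python
import math


def two_theta_deg(d_spacing: float, wavelength: float = 1.54184) -> float:
    """First-order diffraction angle 2θ in degrees from Bragg's law."""
    sin_theta = wavelength / (2 * d_spacing)
    if sin_theta > 1:
        raise ValueError("reflection not accessible at this wavelength")
    return 2 * math.degrees(math.asin(sin_theta))


def cubic_d_spacing(a: float, hkl: tuple[int, int, int]) -> float:
    """Interplanar spacing of the (hkl) planes in a cubic lattice with constant a."""
    h, k, l_idx = hkl
    return a / math.sqrt(h * h + k * k + l_idx * l_idx)


# e.g. the (200) reflection of a rock-salt cell with a = 5.64 Å
two_theta_200 = two_theta_deg(cubic_d_spacing(5.64, (2, 0, 0)))
```

Higher-index planes have smaller d-spacings and therefore appear at larger 2θ.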

Args: patterns (PatternOrStruct | dict[str, PatternOrStruct | tuple[PatternOrStruct,

Raises:

Returns: