
API

module bar

Bar plots.

Global Variables


function spacegroup_bar

spacegroup_bar(
    data: 'Sequence[int | str | Structure] | Series',
    show_counts: 'bool' = True,
    xticks: "Literal['all', 'crys_sys_edges'] | int" = 20,
    show_empty_bins: 'bool' = False,
    ax: 'Axes | None' = None,
    backend: 'Backend' = 'plotly',
    text_kwargs: 'dict[str, Any] | None' = None,
    log: 'bool' = False,
    **kwargs: 'Any'
) → Axes | Figure

Plot a histogram of spacegroups shaded by crystal system.

Args:

Returns:
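The crystal-system shading can be derived from the space group number alone, because each crystal system covers a fixed range of the 230 space groups in the International Tables. The sketch below shows only that mapping, not pymatviz's plotting code. The helper name `crystal_system_of` is hypothetical.

```python
from collections import Counter

# Upper space group number of each crystal system (International Tables numbering)
CRYS_SYS_UPPER_BOUNDS = (
    (2, "triclinic"),
    (15, "monoclinic"),
    (74, "orthorhombic"),
    (142, "tetragonal"),
    (167, "trigonal"),
    (194, "hexagonal"),
    (230, "cubic"),
)


def crystal_system_of(spg_num: int) -> str:
    """Map an international space group number (1-230) to its crystal system (hypothetical helper)."""
    if not 1 <= spg_num <= 230:
        raise ValueError(f"space group number must be in 1-230, got {spg_num}")
    # the first range whose upper bound is >= spg_num contains it
    return next(system for upper, system in CRYS_SYS_UPPER_BOUNDS if spg_num <= upper)


# tally how many structures fall into each crystal system (the groups used for shading)
spg_nums = [225, 225, 194, 62, 14, 1]
crys_sys_counts = Counter(crystal_system_of(num) for num in spg_nums)
```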

module colors

Colors used in pymatviz.

Global Variables

module coordination.helpers

Helper functions for calculating coordination numbers.

Global Variables


function create_hover_text

create_hover_text(
    struct_key: 'str',
    elem_symbol: 'str',
    cn: 'int',
    count: 'int',
    hover_data: 'dict[str, str]',
    data: 'dict[str, Any]',
    is_single_structure: 'bool'
) → str

Create hover text for a single bar in the histogram.


function normalize_get_neighbors

normalize_get_neighbors(
    strategy: 'float | NearNeighbors | type[NearNeighbors]'
) → Callable[[PeriodicSite, Structure], list[dict[str, Any]]]

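Normalize the strategy argument (a cutoff radius, or a NearNeighbors instance or class) into a get_neighbors(site, structure) function.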


function calculate_average_cn

calculate_average_cn(
    structure: 'Structure',
    element: 'str',
    get_neighbors: 'Callable[[PeriodicSite, Structure], list[dict[str, Any]]]'
) → float

Calculate the average coordination number for a given element in a structure.


function coordination_nums_in_structure

coordination_nums_in_structure(
    structure: 'Structure',
    strategy: 'float | NearNeighbors | type[NearNeighbors]' = 3.0,
    group_by: "Literal['element', 'specie', 'site']" = 'element'
) → dict[str, list[int]]

Get coordination numbers (CN) for each element in a structure.

Args:

Returns:

Example:

    >>> from pymatgen.core import Structure
    >>> structure = Structure.from_file("SiO2.cif")
    >>> cns = coordination_nums_in_structure(structure)
    >>> print(cns)
    {"Si": [4, 4, 4], "O": [2, 2, 2, 2, 2, 2]}
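When strategy is a float, it presumably acts as a neighbor cutoff radius (in Å, inferred from the float default of 3.0). The counting idea can be sketched in plain Python for an orthorhombic cell using the minimum image convention. This is not pymatviz's implementation, which works on pymatgen structures; the helper `coordination_numbers` is hypothetical.

```python
import itertools
import math


def coordination_numbers(
    frac_coords: list[tuple[float, float, float]],
    lattice_abc: tuple[float, float, float],
    cutoff: float,
) -> list[int]:
    """Count neighbors within `cutoff` of each site in an orthorhombic periodic cell.

    Uses the minimum image convention, which is only valid when
    cutoff < min(lattice_abc) / 2.
    """
    if cutoff >= min(lattice_abc) / 2:
        raise ValueError("cutoff must be smaller than half the shortest lattice length")
    cns = []
    for idx_i, site_i in enumerate(frac_coords):
        cn = 0
        for idx_j, site_j in enumerate(frac_coords):
            if idx_i == idx_j:
                continue
            sq_dist = 0.0
            for frac_i, frac_j, length in zip(site_i, site_j, lattice_abc):
                # wrap the fractional difference into [-0.5, 0.5) (minimum image)
                delta = (frac_j - frac_i + 0.5) % 1.0 - 0.5
                sq_dist += (delta * length) ** 2
            if math.sqrt(sq_dist) <= cutoff:
                cn += 1
        cns.append(cn)
    return cns


# 3x3x3 supercell of a simple cubic lattice (a = 1): every site has 6 nearest neighbors
sc_333 = [(i / 3, j / 3, k / 3) for i, j, k in itertools.product(range(3), repeat=3)]
cns = coordination_numbers(sc_333, (3.0, 3.0, 3.0), cutoff=1.1)
```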


class CnSplitMode

How to split the coordination number histogram into subplots.

module coordination

Coordination number analysis and visualization.

module coordination.plotly

Plotly plots of coordination numbers distributions.

Global Variables


function coordination_hist

coordination_hist(
    structures: 'Structure | dict[str, Structure] | Sequence[Structure]',
    strategy: 'float | NearNeighbors | type[NearNeighbors]' = 3.0,
    split_mode: 'CnSplitMode | str' = CnSplitMode.by_element,
    bar_mode: "Literal['group', 'stack']" = 'stack',
    hover_data: 'Sequence[str] | dict[str, str] | None' = None,
    element_color_scheme: 'ElemColorScheme | dict[str, str]' = Jmol,
    annotate_bars: 'bool | dict[str, Any]' = False,
    bar_kwargs: 'dict[str, Any] | None' = None
) → Figure

Create a plotly histogram of coordination numbers for given structure(s).

Args:

Returns: A plotly Figure object containing the histogram.


function coordination_vs_cutoff_line

coordination_vs_cutoff_line(
    structures: 'Structure | dict[str, Structure] | Sequence[Structure]',
    strategy: 'tuple[float, float] | NearNeighbors | type[NearNeighbors]' = (1, 5),
    num_points: 'int' = 50,
    element_color_scheme: 'ElemColorScheme | dict[str, str]' = Jmol,
    subplot_kwargs: 'dict[str, Any] | None' = None
) → Figure

Create a plotly line plot of cumulative coordination numbers vs cutoff distance.

Args:

Returns: A plotly Figure object containing the line plot.
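A cumulative coordination number curve is a step function: for each cutoff r, it counts the neighbors whose distance is at most r. A minimal sketch, assuming the neighbor distances of a site are already known (pymatviz obtains them from pymatgen, which this sketch does not use); `cumulative_cn` and `linspace` are hypothetical helpers.

```python
from bisect import bisect_right


def cumulative_cn(neighbor_dists: list[float], cutoffs: list[float]) -> list[int]:
    """Number of neighbors within each cutoff distance."""
    sorted_dists = sorted(neighbor_dists)
    # bisect_right returns how many sorted distances are <= the cutoff
    return [bisect_right(sorted_dists, cutoff) for cutoff in cutoffs]


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Evenly spaced values from start to stop (inclusive), like numpy.linspace."""
    if num < 2:
        raise ValueError("num must be at least 2")
    step = (stop - start) / (num - 1)
    return [start + idx * step for idx in range(num)]


# neighbor shells of simple cubic (a = 1): 6 at 1, 12 at sqrt(2), 8 at sqrt(3)
sc_dists = [1.0] * 6 + [2**0.5] * 12 + [3**0.5] * 8
cutoffs = linspace(1, 5, 50)  # mirrors the defaults strategy=(1, 5), num_points=50
cn_curve = cumulative_cn(sc_dists, cutoffs)
```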

module cumulative

Plot the cumulative distribution of residuals and absolute errors.

Global Variables


function cumulative_residual

cumulative_residual(
    res: 'ArrayLike',
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Plot the empirical cumulative distribution for the residuals (y - mu).

Args:

Returns:


function cumulative_error

cumulative_error(
    abs_err: 'ArrayLike',
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Plot the empirical cumulative distribution of abs(y_true - y_pred).

Args:

Returns:
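Both functions plot an empirical cumulative distribution function (ECDF): sort the values and pair the i-th smallest with the fraction i/n of samples at or below it. A plain-Python sketch of that computation (the helper `ecdf` is hypothetical; pymatviz's own code may differ):

```python
def ecdf(values: list[float]) -> tuple[list[float], list[float]]:
    """Empirical CDF: sorted values and the fraction of samples <= each one."""
    if not values:
        raise ValueError("need at least one value")
    xs = sorted(values)
    n_vals = len(xs)
    # the i-th smallest value (0-based) has (i + 1) / n of the samples at or below it
    ys = [(idx + 1) / n_vals for idx in range(n_vals)]
    return xs, ys


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 1.0, 3.0, 5.0]
residuals = [t - p for t, p in zip(y_true, y_pred)]  # input for cumulative_residual
abs_errors = [abs(res) for res in residuals]  # input for cumulative_error
err_xs, err_ys = ecdf(abs_errors)
```

With tied values (here two errors of 1.0), the CDF value at that x is the last y in the tied run.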

module data

Generate dummy data for testing and prototyping.


function regression

regression(
    n_samples: int = 500,
    true_mean: float = 5,
    true_std: float = 4,
    pred_slope: float = 1.2,
    pred_intercept: float = -2,
    seed: int = 0
) → RegressionData

Generate dummy regression data for testing and prototyping.

This function creates synthetic data to simulate a regression task:

Parameters:

Returns:


class RegressionData

Regression data containing: y_true, y_pred and y_std.
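The exact generative model is not spelled out here, so the following is only one plausible reading of the parameters, stated as an assumption: targets drawn from a normal distribution with true_mean and true_std, and predictions that are a noisy linear transform (pred_slope, pred_intercept) of the targets. `ToyRegressionData` and `toy_regression` are hypothetical stand-ins, not the pymatviz functions.

```python
import random
from typing import NamedTuple


class ToyRegressionData(NamedTuple):
    y_true: list[float]
    y_pred: list[float]
    y_std: list[float]


def toy_regression(
    n_samples: int = 500,
    true_mean: float = 5,
    true_std: float = 4,
    pred_slope: float = 1.2,
    pred_intercept: float = -2,
    seed: int = 0,
) -> ToyRegressionData:
    """Synthetic regression data (assumed model, see lead-in)."""
    rng = random.Random(seed)  # seeded so repeated calls give identical data
    y_true = [rng.gauss(true_mean, true_std) for _ in range(n_samples)]
    # assumed: per-sample uncertainty is the magnitude of a standard normal draw
    y_std = [abs(rng.gauss(0, 1)) for _ in range(n_samples)]
    # assumed: predictions = linear transform of targets + noise scaled by y_std
    y_pred = [
        pred_slope * y_t + pred_intercept + rng.gauss(0, std)
        for y_t, std in zip(y_true, y_std)
    ]
    return ToyRegressionData(y_true, y_pred, y_std)
```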

module enums

Enums used as keys/accessors for dicts and dataframes across Matbench Discovery.

Global Variables


class LabelEnum

StrEnum with optional label and description attributes plus dict() methods.

Simply add label and description as a tuple starting with the key's value.
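The tuple-valued member pattern described above can be sketched with a custom `__new__` that stores the first tuple entry as the value and the rest as attributes. This uses a `(str, Enum)` mixin so it also runs on Python versions before 3.11 (where `enum.StrEnum` is unavailable); the attribute name `desc`, the method `val_label_dict`, and the `Quantity` enum are illustrative choices, not necessarily pymatviz's.

```python
from __future__ import annotations

from enum import Enum


class LabelEnum(str, Enum):
    """str-valued Enum whose members can carry an optional label and description."""

    def __new__(cls, value: str, label: str | None = None, desc: str | None = None) -> LabelEnum:
        member = str.__new__(cls, value)
        member._value_ = value  # lookups like Quantity("energy") use the plain value
        member.label = label
        member.desc = desc
        return member

    @classmethod
    def val_label_dict(cls) -> dict[str, str | None]:
        """Map each member's value to its label."""
        return {member.value: member.label for member in cls}


class Quantity(LabelEnum):
    """Example enum: the value comes first, then the optional label and description."""

    energy = "energy", "Energy (eV)", "Total energy per atom"
    volume = "volume", "Volume (Å³)"
```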


class Key

Keys used to access dataframes columns, organized by semantic groups.


class Task

What kind of task is being performed.


class Model

Model names.


class ElemCountMode

Mode of counting elements in a chemical formula.


class ElemColorMode

Mode of coloring elements in structure visualizations or periodic table plots.


class ElemColorScheme

Names of element color palettes.

Used e.g. in structure visualizations and periodic table plots.


class SiteCoords

Site coordinate representations.

module histogram

Histograms.

Global Variables


function spacegroup_hist

spacegroup_hist(*args: 'Any', **kwargs: 'Any') → Axes | Figure

Alias for spacegroup_bar.


function elements_hist

elements_hist(
    formulas: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    log: 'bool' = False,
    keep_top: 'int | None' = None,
    ax: 'Axes | None' = None,
    bar_values: "Literal['percent', 'count'] | None" = 'percent',
    h_offset: 'int' = 0,
    v_offset: 'int' = 10,
    rotation: 'int' = 45,
    fontsize: 'int' = 12,
    **kwargs: 'Any'
) → Axes

Plot a histogram of elements (e.g. to show occurrence in a dataset).

Adapted from https://github.com/kaaiian/ML_figures (https://git.io/JmbaI).

Args:

Returns:


function histogram

histogram(
    values: 'Sequence[float] | dict[str, Sequence[float]]',
    bins: 'int | Sequence[float] | str' = 200,
    x_range: 'tuple[float | None, float | None] | None' = None,
    density: 'bool' = False,
    bin_width: 'float' = 1.2,
    log_y: 'bool' = False,
    backend: 'Backend' = 'plotly',
    fig_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → Figure | Figure

Plot a histogram with the plotly (default) or matplotlib backend, using fast NumPy pre-processing (binning) before handing the data off to the plotting function.

This is a very common use case with large datasets, so it is worth having a dedicated function for it. It has two advantages over the native matplotlib/plotly histograms: it is much faster, and it produces much smaller files when saving plotly figures as HTML (plotly otherwise saves a complete copy of the raw data to disk, from which it recomputes the histogram on the fly to render the figure). Speedup example:

    import numpy as np
    import plotly.express as px
    from pymatviz.histogram import histogram

    gaussian = np.random.normal(0, 1, 1_000_000_000)
    histogram(gaussian)  # takes 17s
    px.histogram(gaussian)  # ran for 3m45s before crashing the Jupyter kernel

Args:

Returns:
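The pre-processing idea is to reduce N raw values to one count per bin, so the figure only has to store bins + 1 edges and `bins` counts instead of every sample. The plain-Python sketch below shows equal-width binning with the same edge convention as `numpy.histogram` (the last bin includes its right edge); in practice `numpy.histogram` does this far faster. The helper `bin_counts` is hypothetical.

```python
from __future__ import annotations


def bin_counts(
    values: list[float], bins: int, x_range: tuple[float, float] | None = None
) -> tuple[list[int], list[float]]:
    """Count values into `bins` equal-width bins; values outside x_range are dropped."""
    lo, hi = x_range if x_range is not None else (min(values), max(values))
    if hi <= lo:
        raise ValueError("x_range must have hi > lo")
    width = (hi - lo) / bins
    counts = [0] * bins
    for val in values:
        if lo <= val <= hi:
            # clamp so the maximum value lands in the last (right-closed) bin
            counts[min(int((val - lo) / width), bins - 1)] += 1
    edges = [lo + idx * width for idx in range(bins + 1)]
    return counts, edges


counts, edges = bin_counts([0, 0.5, 1, 1.5, 2, 2, 10], bins=2, x_range=(0, 2))
```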

module io

I/O utilities for saving figures and dataframes to various image formats.

Global Variables


function save_fig

save_fig(
    fig: 'Figure | Figure | Axes',
    path: 'str',
    plotly_config: 'dict[str, Any] | None' = None,
    env_disable: 'Sequence[str]' = ('CI',),
    pdf_sleep: 'float' = 0.6,
    style: 'str' = '',
    prec: 'int | None' = None,
    template: 'str | None' = None,
    transparent_bg: 'bool' = True,
    **kwargs: 'Any'
) → None

Write a plotly or matplotlib figure to disk (as HTML/PDF/SVG/...).

If the file has a .svelte extension, insert {...$$props} into the figure's top-level div so it can later be styled and customized from Svelte code.

Args:
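The .svelte behavior described above amounts to a string substitution on the first opening div tag. A minimal sketch using a regular expression; the helper `add_svelte_props` is hypothetical and pymatviz's actual implementation may differ.

```python
import re


def add_svelte_props(html: str, props: str = "{...$$props}") -> str:
    """Insert a Svelte props spread into the first (top-level) <div> tag."""
    # \b stops the pattern from matching tags like <divider>; count=1 leaves nested divs alone
    return re.sub(r"<div\b", f"<div {props}", html, count=1)


svelte_html = add_svelte_props('<div id="plot"><div class="inner"></div></div>')
```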


function save_and_compress_svg

save_and_compress_svg(
    fig: 'Figure | Figure | Axes',
    filename: 'str',
    transparent_bg: 'bool' = True
) → None

Save a Plotly figure as SVG and HTML to the assets/ folder. Compresses the SVG file with the svgo CLI if it is available in PATH.

If filename does not include a .svg extension and is not absolute, it is treated as relative to the assets/ folder. This function is mostly meant for pymatviz internal use.

Args:

Raises:


function df_to_pdf

df_to_pdf(
    styler: 'Styler',
    file_path: 'str | Path',
    crop: 'bool' = True,
    size: 'str | None' = None,
    style: 'str' = '',
    styler_css: 'bool | dict[str, str]' = True,
    **kwargs: 'Any'
) → None

Export a pandas Styler to PDF with WeasyPrint.

Args:


function normalize_and_crop_pdf

normalize_and_crop_pdf(
    file_path: 'str | Path',
    on_gs_not_found: "Literal['ignore', 'warn', 'error']" = 'warn'
) → None

Normalize a PDF using Ghostscript and then crop it. Without gs normalization, pdfCropMargins sometimes corrupts the PDF.

Args:


function df_to_html_table

df_to_html_table(*args: 'Any', **kwargs: 'Any') → str

function df_to_html

df_to_html(
    styler: 'Styler',
    file_path: 'str | Path | None' = None,
    inline_props: 'str | None' = '{{...$$props}} ',
    pre_table: 'str | None' = '',
    styles: 'str | None' = 'table { overflow: scroll; max-width: 100%; display: block; }\ntable {\n    scrollbar-width: none;  /* Firefox */\n}\ntable::-webkit-scrollbar {\n    display: none;  /* Safari and Chrome */\n}',
    styler_css: 'bool | dict[str, str]' = True,
    use_sortable: 'bool' = True,
    use_tooltips: 'bool' = True,
    post_process: 'Callable[[str], str] | None' = None,
    **kwargs: 'Any'
) → str

Convert a pandas Styler to a Svelte table.

Args:

Returns:


function df_to_svg

df_to_svg(
    obj: 'DataFrame | Styler',
    file_path: 'str | Path',
    font_size: 'int' = 14,
    compress: 'bool' = True,
    **kwargs: 'Any'
) → Figure

Export a pandas DataFrame or Styler to an SVG file and optionally compress it.

TODO: The SVG output has annoying margins that have proved hard to remove. The goal is for this function to auto-crop the SVG viewBox to its content in the future.

Args:

Returns:

Raises:


class BufferedIOBase

Base class for buffered IO objects.

The main difference with RawIOBase is that the read() method supports omitting the size argument, and does not have a default implementation that defers to readinto().

In addition, read(), readinto() and write() may raise BlockingIOError if the underlying raw stream is in non-blocking mode and not ready; unlike their raw counterparts, they will never return None.

A typical implementation should not inherit from a RawIOBase implementation, but wrap one.


class IOBase

The abstract base class for all I/O classes.

This class provides dummy implementations for many methods that derived classes can override selectively; the default implementations represent a file that cannot be read, written or seeked.

Even though IOBase does not declare read, readinto, or write because their signatures will vary, implementations and clients should consider those methods part of the interface. Also, implementations may raise UnsupportedOperation when operations they do not support are called.

The basic type used for binary data read from or written to a file is bytes. Other bytes-like objects are accepted as method arguments too. In some cases (such as readinto), a writable object is required. Text I/O classes work with str data.

Note that calling any method (except additional calls to close(), which are ignored) on a closed stream should raise a ValueError.

IOBase (and its subclasses) support the iterator protocol, meaning that an IOBase object can be iterated over yielding the lines in a stream.

IOBase also supports the with statement. In this example, fp is closed after the suite of the with statement is complete:

    with open('spam.txt', 'w') as fp:
        fp.write('Spam and eggs!')


class RawIOBase

Base class for raw binary I/O.


class TextIOBase

Base class for text I/O.

This class provides a character and line based interface to stream I/O. There is no readinto method because Python's character strings are immutable.


class UnsupportedOperation


class TqdmDownload

Progress bar for urllib.request.urlretrieve file download.

Adapted from the official TqdmUpTo example. See https://github.com/tqdm/tqdm/blob/4c956c20b83be4312460fc0c4812eeb3fef5e7df/README.rst#hooks-and-callbacks

method __init__

__init__(*args: 'Any', **kwargs: 'Any') → None

Sets default values appropriate for file downloads for unit, unit_scale, unit_divisor, miniters, desc.


property format_dict

Public API for read-only member access.


method update_to

update_to(
    n_blocks: 'int' = 1,
    block_size: 'int' = 1,
    total_size: 'int | None' = None
) → bool | None

Update hook for urlretrieve.

Args:

Returns:
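`urllib.request.urlretrieve` calls its `reporthook` with the cumulative block count, the block size, and the total size (-1 if unknown), while a progress bar expects increments. The sketch below shows that conversion, following the pattern of the TqdmUpTo example. `ProgressTracker` is a hypothetical stand-in with no tqdm dependency, so its `update` just accumulates bytes.

```python
from __future__ import annotations


class ProgressTracker:
    """Minimal stand-in for a tqdm bar driven by urlretrieve's reporthook."""

    def __init__(self) -> None:
        self.n = 0  # bytes reported so far
        self.total: int | None = None

    def update(self, delta: int) -> None:
        self.n += delta

    def update_to(self, n_blocks: int = 1, block_size: int = 1, total_size: int | None = None) -> None:
        # urlretrieve passes total_size=-1 when the server sends no Content-Length
        if total_size is not None and total_size > 0:
            self.total = total_size
        # reporthook arguments are cumulative, so convert them to an increment
        self.update(n_blocks * block_size - self.n)


# usage (needs network access, so not run here):
# urllib.request.urlretrieve(url, filename, reporthook=tracker.update_to)
tracker = ProgressTracker()
for block_num in range(3):  # simulate the first three reporthook calls
    tracker.update_to(block_num, 1024, 3000)
```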

module phonons

Plotting functions for pymatgen phonon band structures and density of states.

Global Variables


function pretty_sym_point

pretty_sym_point(symbol: 'str') → str

Convert a symbol to a pretty-printed version.


function get_band_xaxis_ticks

get_band_xaxis_ticks(
    band_struct: 'PhononBands',
    branches: 'Sequence[str] | set[str]' = ()
) → tuple[list[float], list[str]]

Get all ticks and labels for a band structure plot.

Returns:


function phonon_bands

phonon_bands(
    band_structs: 'PhononBands | dict[str, PhononBands]',
    line_kwargs: "dict[str, Any] | dict[Literal['acoustic', 'optical'], dict[str, Any]] | Callable[[ndarray, int], dict[str, Any]] | None" = None,
    branches: 'Sequence[str]' = (),
    branch_mode: 'BranchMode' = 'union',
    shaded_ys: 'dict[tuple[YMin | YMax, YMin | YMax], dict[str, Any]] | bool | None' = None,
    **kwargs: 'Any'
) → Figure

Plot single or multiple pymatgen band structures using Plotly.

Warning: Only tested with phonon band structures so far, but the plan is to extend it to electronic band structures.

Args: band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]): Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict with labels mapped to multiple such objects.

Returns:


function phonon_dos

phonon_dos(
    doses: 'PhononDos | dict[str, PhononDos]',
    stack: 'bool' = False,
    sigma: 'float' = 0,
    units: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz',
    normalize: "Literal['max', 'sum', 'integral'] | None" = None,
    last_peak_anno: 'str | None' = None,
    **kwargs: 'Any'
) → Figure

Plot phonon DOS using Plotly.

Args:

Returns:


function convert_frequencies

convert_frequencies(
    frequencies: 'ndarray',
    unit: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz'
) → ndarray

Convert frequencies from THz to the specified unit.

Args:

Returns:
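Each supported unit is a fixed multiple of THz: energies follow from E = h·f, and wavenumbers from f / c. The sketch below derives the factors from the exact SI constants and the CODATA 2018 Hartree energy; it is an independent illustration, and the helper `convert_thz` is hypothetical (pymatviz may use slightly different constant values).

```python
H_PLANCK = 6.62607015e-34  # J s (exact in the 2019 SI)
E_CHARGE = 1.602176634e-19  # C (exact), converts J to eV
C_LIGHT = 299_792_458  # m/s (exact)
HARTREE_EV = 27.211386245988  # eV (CODATA 2018)

THZ_TO_EV = H_PLANCK * 1e12 / E_CHARGE  # E = h * f with f = 1 THz
THZ_TO = {
    "THz": 1.0,
    "eV": THZ_TO_EV,
    "meV": THZ_TO_EV * 1e3,
    "Ha": THZ_TO_EV / HARTREE_EV,
    "cm-1": 1e12 / (C_LIGHT * 100),  # wavenumber = f / c, with c in cm/s
}


def convert_thz(frequencies: list[float], unit: str = "THz") -> list[float]:
    """Convert frequencies given in THz to `unit`."""
    if unit not in THZ_TO:
        raise ValueError(f"unit must be one of {list(THZ_TO)}, got {unit!r}")
    factor = THZ_TO[unit]
    return [freq * factor for freq in frequencies]
```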


function phonon_bands_and_dos

phonon_bands_and_dos(
    band_structs: 'PhononBands | dict[str, PhononBands]',
    doses: 'PhononDos | dict[str, PhononDos]',
    bands_kwargs: 'dict[str, Any] | None' = None,
    dos_kwargs: 'dict[str, Any] | None' = None,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    all_line_kwargs: 'dict[str, Any] | None' = None,
    per_line_kwargs: 'dict[str, dict[str, Any]] | None' = None,
    **kwargs: 'Any'
) → Figure

Plot phonon DOS and band structure using Plotly.

Args:

Returns:


class PhononDBDoc

Dataclass for phonon DB docs.

method __init__

__init__(
    structure: 'Structure',
    phonon_bandstructure: 'PhononBands',
    phonon_dos: 'PhononDos',
    free_energies: 'list[float]',
    internal_energies: 'list[float]',
    heat_capacities: 'list[float]',
    entropies: 'list[float]',
    temps: 'list[float] | None' = None,
    has_imaginary_modes: 'bool | None' = None,
    primitive: 'Structure | None' = None,
    supercell: 'list[list[int]] | None' = None,
    nac_params: 'dict[str, Any] | None' = None,
    thermal_displacement_data: 'dict[str, Any] | None' = None,
    mp_id: 'str | None' = None,
    formula: 'str | None' = None
) → None

module powerups.both

Powerups that can be applied to both matplotlib and plotly figures.

Global Variables


function annotate_metrics

annotate_metrics(
    xs: 'ArrayLike',
    ys: 'ArrayLike',
    fig: 'AxOrFig | None' = None,
    metrics: 'dict[str, float] | Sequence[str]' = ('MAE', 'R2'),
    prefix: 'str' = '',
    suffix: 'str' = '',
    fmt: 'str' = '.3',
    **kwargs: 'Any'
) → AnchoredText

Given x and y values of equal length and an optional matplotlib Axes or plotly Figure, print the values' mean absolute error and R^2 coefficient of determination (the default metrics) on the figure.

Args:

Returns:
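The two default metrics are computed as MAE = mean(|x - y|) and R² = 1 - SS_res / SS_tot. The sketch below assumes xs are the targets and ys the predictions (as in scikit-learn's `r2_score(y_true, y_pred)`), then formats the text with the default fmt=".3". The helper `mae_r2` is hypothetical, not the pymatviz implementation.

```python
def mae_r2(xs: list[float], ys: list[float]) -> tuple[float, float]:
    """Mean absolute error and R^2, treating xs as targets and ys as predictions."""
    if len(xs) != len(ys) or not xs:
        raise ValueError("xs and ys must be non-empty and of equal length")
    n_vals = len(xs)
    mae = sum(abs(x - y) for x, y in zip(xs, ys)) / n_vals
    mean_x = sum(xs) / n_vals
    ss_res = sum((x - y) ** 2 for x, y in zip(xs, ys))  # residual sum of squares
    ss_tot = sum((x - mean_x) ** 2 for x in xs)  # total sum of squares about the mean
    return mae, 1 - ss_res / ss_tot


mae, r2 = mae_r2([1, 2, 3, 4], [1.5, 2, 2.5, 4.5])
anno_text = f"MAE = {mae:.3}\nR² = {r2:.3}"  # ".3" mirrors the default fmt
```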


function add_identity_line

add_identity_line(
    fig: 'Figure | Figure | Axes',
    line_kwargs: 'dict[str, Any] | None' = None,
    trace_idx: 'int' = 0,
    retain_xy_limits: 'bool' = False,
    **kwargs: 'Any'
) → Figure

Add a line shape to the background layer of a plotly figure, spanning from the smallest to the largest x/y values in the trace specified by trace_idx.

Args:

Raises:

Returns:


function add_best_fit_line

add_best_fit_line(
    fig: 'Figure | Figure | Axes',
    xs: 'ArrayLike' = (),
    ys: 'ArrayLike' = (),
    trace_idx: 'int | None' = None,
    line_kwargs: 'dict[str, Any] | None' = None,
    annotate_params: 'bool | dict[str, Any]' = True,
    warn: 'bool' = True,
    **kwargs: 'Any'
) → Figure

Add a least-squares line of best fit to a plotly or matplotlib figure.

Args:

Raises:

Returns:
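The ordinary least-squares line has closed-form parameters: slope = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)² and intercept = ȳ - slope·x̄. A plain-Python sketch of that fit (the helper `least_squares_line` is hypothetical; the drawing and annotation steps of add_best_fit_line are omitted):

```python
def least_squares_line(xs: list[float], ys: list[float]) -> tuple[float, float]:
    """Slope and intercept of the ordinary least-squares line y = slope * x + intercept."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need at least two (x, y) pairs of equal length")
    n_vals = len(xs)
    mean_x, mean_y = sum(xs) / n_vals, sum(ys) / n_vals
    sxx = sum((x - mean_x) ** 2 for x in xs)
    if sxx == 0:
        raise ValueError("all x values are identical, slope is undefined")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


slope, intercept = least_squares_line([0, 1, 2, 3], [1, 3, 5, 7])
```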

module powerups.matplotlib

Powerups for matplotlib figures.

Global Variables


function with_marginal_hist

with_marginal_hist(
    xs: 'ArrayLike',
    ys: 'ArrayLike',
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    fig: 'Figure | Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Call this before creating a matplotlib figure and use the returned ax_main for all subsequent plotting operations. This creates a grid of plots with the main plot in the lower left and narrow histograms along its x- and/or y-axes, displayed above it and along its right edge.

Args:

Returns:


function annotate_bars

annotate_bars(
    ax: 'Axes | None' = None,
    v_offset: 'float' = 10,
    h_offset: 'float' = 0,
    labels: 'Sequence[str | int | float] | None' = None,
    fontsize: 'int' = 14,
    y_max_headroom: 'float' = 1.2,
    adjust_test_pos: 'bool' = False,
    **kwargs: 'Any'
) → None

Annotate each bar in a bar plot with a label.

Args:

module powerups

Powerups such as parity lines, annotations, marginals, menu buttons, etc. for matplotlib and plotly figures.

Global Variables

module powerups.plotly

Powerups for plotly figures.

Global Variables


function add_ecdf_line

add_ecdf_line(
    fig: 'Figure',
    values: 'ArrayLike' = (),
    trace_idx: 'int' = 0,
    trace_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → Figure

Add an empirical cumulative distribution function (ECDF) line to a plotly figure.

Support for matplotlib is planned but not yet implemented. PRs welcome.

Args:

Returns:

module process_data

pymatviz utility functions.

Global Variables


function count_elements

count_elements(
    values: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    exclude_elements: 'Sequence[str]' = (),
    fill_value: 'float | None' = None
) → Series

Count element occurrence in list of formula strings or dict-like compositions.

If the passed values are already a map from element symbols to counts, ensure the data is a pd.Series filled with fill_value for missing elements.

Provided as a standalone function for external use or to cache long computations. Caching long element counts is done by refactoring

    ptable_heatmap(long_list_of_formulas)  # slow

to

    elem_counts = count_elements(long_list_of_formulas)  # slow
    ptable_heatmap(elem_counts)  # fast, only rerun this line to update the plot

Args:

Returns:
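Counting elements boils down to parsing each formula into element amounts and summing them. The sketch below handles only flat formulas (no parentheses or hydrates) and two illustrative modes: summing stoichiometric amounts ("composition") and counting formulas that contain an element ("occurrence"). It is a simplification; pymatviz relies on pymatgen's `Composition` for real parsing. `parse_formula` and `count_elements_sketch` are hypothetical names.

```python
import re
from collections import Counter

# element symbol followed by an optional (possibly fractional) amount
FORMULA_TOKEN = re.compile(r"([A-Z][a-z]?)(\d*\.?\d*)")


def parse_formula(formula: str) -> dict[str, float]:
    """Parse flat formulas like 'Fe2O3' into {element: amount}."""
    amounts: Counter = Counter()
    for elem, amt in FORMULA_TOKEN.findall(formula):
        amounts[elem] += float(amt) if amt else 1.0
    return dict(amounts)


def count_elements_sketch(
    formulas: list[str], count_mode: str = "composition", exclude_elements: tuple = ()
) -> dict[str, float]:
    """Sum element amounts ('composition') or count formulas containing each element ('occurrence')."""
    if count_mode not in ("composition", "occurrence"):
        raise ValueError(f"unsupported count_mode {count_mode!r}")
    totals: Counter = Counter()
    for formula in formulas:
        for elem, amt in parse_formula(formula).items():
            if elem in exclude_elements:
                continue
            totals[elem] += amt if count_mode == "composition" else 1
    return dict(totals)
```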

module ptable

matplotlib and plotly periodic table figures.

module ptable.ptable_matplotlib

Various periodic table heatmaps with matplotlib and plotly.

Global Variables


function ptable_heatmap

ptable_heatmap(
    data: 'DataFrame | Series | dict[str, list[list[float]]] | PTableData',
    colormap: 'str' = 'viridis',
    exclude_elements: 'Sequence[str]' = (),
    overwrite_tiles: 'dict[ElemStr, OverwriteTileValueColor] | None' = None,
    infty_color: 'ColorType' = 'lightskyblue',
    nan_color: 'ColorType' = 'lightgrey',
    log: 'bool' = False,
    sci_notation: 'bool' = False,
    tile_size: 'tuple[float, float]' = (0.75, 0.75),
    on_empty: "Literal['hide', 'show']" = 'show',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    f_block_voffset: 'float' = 0,
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    text_colors: "Literal['auto'] | ColorType | dict[ElemStr, ColorType]" = 'auto',
    symbol_pos: 'tuple[float, float] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
    anno_kwargs: 'dict[str, Any] | None' = None,
    value_show_mode: "Literal['value', 'fraction', 'percent', 'off']" = 'value',
    value_pos: 'tuple[float, float] | None' = None,
    value_fmt: 'str' = 'auto',
    value_kwargs: 'dict[str, Any] | None' = None,
    show_cbar: 'bool' = True,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.05),
    cbar_range: 'tuple[float | None, float | None]' = (None, None),
    cbar_label_fmt: 'str' = 'auto',
    cbar_title: 'str' = 'Element Count',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_kwargs: 'dict[str, Any] | None' = None,
    return_type: "Literal['figure', 'axes']" = 'axes',
    colorscale: 'str | None' = None,
    heat_mode: "Literal['value', 'fraction', 'percent'] | None" = None,
    show_values: 'bool | None' = None,
    fmt: 'str | None' = None,
    cbar_fmt: 'str | None' = None,
    show_scale: 'bool | None' = None
) → Axes

Plot a heatmap across the periodic table.

Args: data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

Returns:


function ptable_heatmap_ratio

ptable_heatmap_ratio(
    values_num: 'ElemValues',
    values_denom: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    normalize: 'bool' = False,
    infty_color: 'ColorType' = 'lightskyblue',
    zero_color: 'ColorType' = 'lightgrey',
    zero_tol: 'float' = 1e-06,
    zero_symbol: 'str' = 'ZERO',
    not_in_numerator: 'tuple[str, str] | None' = ('lightgray', 'gray: not in 1st list'),
    not_in_denominator: 'tuple[str, str] | None' = ('lightskyblue', 'blue: not in 2nd list'),
    not_in_either: 'tuple[str, str] | None' = ('white', 'white: not in either'),
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
    anno_kwargs: 'dict[str, Any] | None' = None,
    cbar_title: 'str' = 'Element Ratio',
    **kwargs: 'Any'
) → figure

Display the ratio of two maps from element symbols to heat values or of two sets of compositions.

Args:

--- Ratio heatmap specific --- 

Returns:
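Each tile of the ratio heatmap holds either a numerator/denominator ratio or one of the three "not in ..." categories that the not_in_numerator, not_in_denominator, and not_in_either colors refer to. A minimal sketch of that classification, assuming the two inputs are already element-count maps. Treating normalize as "convert each map to fractions first" is an assumption, and `element_ratios` is a hypothetical helper.

```python
from __future__ import annotations

import math


def element_ratios(
    num: dict[str, float],
    denom: dict[str, float],
    elements: list[str],
    normalize: bool = False,
) -> dict[str, float | str]:
    """Per-element num/denom ratio, or a category string when an element is missing."""
    if normalize:
        # assumed meaning: compare fractions so datasets of different size are comparable
        num_total, denom_total = sum(num.values()), sum(denom.values())
        num = {elem: val / num_total for elem, val in num.items()}
        denom = {elem: val / denom_total for elem, val in denom.items()}
    ratios: dict[str, float | str] = {}
    for elem in elements:
        in_num, in_denom = elem in num, elem in denom
        if in_num and in_denom:
            # a zero denominator yields inf (the case infty_color is for)
            ratios[elem] = num[elem] / denom[elem] if denom[elem] else math.inf
        elif in_denom:
            ratios[elem] = "not_in_numerator"
        elif in_num:
            ratios[elem] = "not_in_denominator"
        else:
            ratios[elem] = "not_in_either"
    return ratios


ratios = element_ratios({"Fe": 2, "O": 6}, {"Fe": 1, "O": 3, "Si": 1}, ["Fe", "O", "Si", "Li"])
```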


function ptable_heatmap_splits

ptable_heatmap_splits(
    data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
    start_angle: 'float' = 135,
    colormap: 'str | Colormap' = 'viridis',
    on_empty: "Literal['hide', 'show']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.5),
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_text_color: 'ColorType | dict[ElemStr, ColorType]' = 'black',
    anno_kwargs: 'dict[str, Any] | None' = None,
    cbar_title: 'str' = 'Values',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot evenly-split heatmaps, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

Notes:

Default figsize is set to (0.75 * n_groups, 0.75 * n_periods).

Returns:


function ptable_hists

ptable_hists(
    data: 'DataFrame | Series | dict[ElemStr, list[float]]',
    bins: 'int' = 20,
    x_range: 'tuple[float | None, float | None] | None' = None,
    log: 'bool' = False,
    colormap: 'str | Colormap | None' = 'viridis',
    on_empty: "Literal['show', 'hide']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    child_kwargs: 'dict[str, Any] | None' = None,
    cbar_axis: "Literal['x', 'y']" = 'x',
    cbar_title: 'str' = 'Values',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_kwargs: 'dict[str, Any] | None' = None,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_kwargs: 'dict[str, Any] | None' = None,
    color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
    elem_type_colors: 'dict[str, str] | None' = None,
    add_elem_type_legend: 'bool' = False,
    elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot histograms for each element laid out in a periodic table.

Args:

Returns:


function ptable_scatters

ptable_scatters(
    data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
    colormap: 'str | Colormap | None' = None,
    on_empty: "Literal['hide', 'show']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    child_kwargs: 'dict[str, Any] | None' = None,
    cbar_title: 'str' = 'Values',
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_kwargs: 'dict[str, Any] | None' = None,
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    symbol_kwargs: 'dict[str, Any] | None' = None,
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_kwargs: 'dict[str, Any] | None' = None,
    color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
    elem_type_colors: 'dict[str, str] | None' = None,
    add_elem_type_legend: 'bool' = False,
    elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure

Make scatter plots for each element, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,


function ptable_lines

ptable_lines(
    data: 'DataFrame | Series | dict[ElemStr, list[list[float]]]',
    on_empty: "Literal['hide', 'show']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    plot_kwargs: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    child_kwargs: 'dict[str, Any] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    symbol_text: 'str | Callable[[Element], str]' = <lambda>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    anno_pos: 'tuple[float, float]' = (0.75, 0.75),
    anno_text: 'dict[ElemStr, str] | None' = None,
    anno_kwargs: 'dict[str, Any] | None' = None,
    color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
    elem_type_colors: 'dict[str, str] | None' = None,
    add_elem_type_legend: 'bool' = False,
    elem_type_legend_kwargs: 'dict[str, Any] | None' = None
) → Figure

Line plots for each element, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[ElemStr, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

module ptable.ptable_plotly

Periodic table plots powered by plotly.

Global Variables


function ptable_heatmap_plotly

ptable_heatmap_plotly(
    values: 'ElemValues',
    count_mode: 'ElemCountMode' = Composition,
    colorscale: 'str | Sequence[str] | Sequence[tuple[float, str]]' = 'viridis',
    show_scale: 'bool' = True,
    show_values: 'bool' = True,
    heat_mode: "Literal['value', 'fraction', 'percent']" = 'value',
    fmt: 'str | None' = None,
    hover_props: 'Sequence[str] | dict[str, str] | None' = None,
    hover_data: 'dict[str, str | int | float] | Series | None' = None,
    font_colors: 'Sequence[str]' = (),
    gap: 'float' = 5,
    font_size: 'int | None' = None,
    bg_color: 'str | None' = None,
    nan_color: 'str' = '#eff',
    colorbar: 'dict[str, Any] | None' = None,
    cscale_range: 'tuple[float | None, float | None]' = (None, None),
    exclude_elements: 'Sequence[str]' = (),
    log: 'bool' = False,
    fill_value: 'float | None' = None,
    element_symbol_map: 'dict[str, str] | None' = None,
    label_map: 'dict[str, str] | Callable[[str], str] | Literal[False] | None' = None,
    border: 'dict[str, Any] | None | Literal[False]' = None,
    scale: 'float' = 1.0,
    **kwargs: 'Any'
) → Figure

Create a Plotly figure with an interactive heatmap of the periodic table. Supports hover tooltips with custom data or atomic reference data like electronegativity, atomic_radius, etc. See the hover_data and hover_props kwargs, respectively.

Args:

Returns:


function ptable_hists_plotly

ptable_hists_plotly(
    data: 'DataFrame | Series | dict[str, list[float]]',
    bins: 'int' = 20,
    x_range: 'tuple[float | None, float | None] | None' = None,
    log: 'bool' = False,
    colorscale: 'str' = 'RdBu',
    colorbar: 'dict[str, Any] | Literal[False] | None' = None,
    font_size: 'int | None' = None,
    scale: 'float' = 1.0,
    element_symbol_map: 'dict[str, str] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    annotations: 'dict[str, str | dict[str, Any]] | Callable[[Sequence[float]], str | dict[str, Any] | list[dict[str, Any]]] | None' = None,
    color_elem_strategy: 'ColorElemTypeStrategy' = 'background',
    elem_type_colors: 'dict[str, str] | None' = None,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    x_axis_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plotly figure with histograms for each element laid out in a periodic table.

Args:

Returns:


function ptable_heatmap_splits_plotly

ptable_heatmap_splits_plotly(
    data: 'DataFrame | Series | dict[str, list[float]]',
    orientation: "Literal['diagonal', 'horizontal', 'vertical', 'grid']" = 'diagonal',
    colorscale: 'str | Sequence[str] | Sequence[tuple[float, str]]' = 'viridis',
    colorbar: 'dict[str, Any] | Literal[False] | None' = None,
    on_empty: "Literal['hide', 'show']" = 'hide',
    hide_f_block: "bool | Literal['auto']" = 'auto',
    font_size: 'int | None' = None,
    scale: 'float' = 1.0,
    element_symbol_map: 'dict[str, str] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    annotations: 'dict[str, str | dict[str, Any]] | Callable[[ndarray], str | dict[str, Any]] | None' = None,
    nan_color: 'str' = '#eff',
    hover_data: 'dict[str, str | int | float] | Series | None' = None,
    subplot_kwargs: 'dict[str, Any] | None' = None
) → Figure

Create a Plotly figure with an interactive heatmap of the periodic table, where each element tile is split into sections representing different values.

Args:

Returns:

Raises:

module rainclouds

Raincloud plots.

Global Variables


function rainclouds

rainclouds(
    data: 'dict[str, Sequence[float] | tuple[DataFrame, str]]',
    orientation: "Literal['h', 'v']" = 'h',
    alpha: 'float' = 0.7,
    width_viol: 'float' = 0.3,
    width_box: 'float' = 0.05,
    jitter: 'float' = 0.01,
    point_size: 'float' = 3,
    bw: 'float' = 0.2,
    cut: 'float' = 0.0,
    scale: "Literal['area', 'count', 'width']" = 'area',
    rain_offset: 'float' = -0.25,
    offset: 'float | None' = None,
    hover_data: 'Sequence[str] | dict[str, Sequence[str]] | None' = None,
    show_violin: 'bool' = True,
    show_box: 'bool' = True,
    show_points: 'bool' = True
) → Figure

Create a raincloud plot for multiple datasets using Plotly.

This plot type was proposed in https://wellcomeopenresearch.org/articles/4-63/v2. It is a vertical stack of:

  1. violin plot (the cloud)
  2. box plot (the umbrella)
  3. strip plot (the rain)

Args:

Returns:

module rdf.helpers

Helper functions for radial distribution functions (RDFs) of pymatgen structures.

Global Variables


function calculate_rdf

calculate_rdf(
    structure: 'Structure',
    center_species: 'str | None' = None,
    neighbor_species: 'str | None' = None,
    cutoff: 'float' = 15,
    n_bins: 'int' = 75,
    pbc: 'PbcLike' = (True, True, True)
) → tuple[ndarray, ndarray]

Calculate the radial distribution function (RDF) for a given structure.

If center_species and neighbor_species are provided, calculates the partial RDF for the specified element pair. Otherwise, calculates the full RDF.

Args:

Returns:

module rdf

Radial distribution function plots (RDFs).

module rdf.plotly

Radial distribution functions (RDFs) of pymatgen structures using plotly.

The main function, element_pair_rdfs, generates a plotly figure with facets for each pair of elements in the given structure. It supports customization of the cutoff distance, bin size, specific element pairs to plot, and a reference line.

Example usage:

    structure = Structure(...)  # create or load a pymatgen Structure
    fig = element_pair_rdfs(structure, bin_size=0.1)
    fig.show()

Global Variables


function element_pair_rdfs

element_pair_rdfs(
    structures: 'Structure | Sequence[Structure] | dict[str, Structure]',
    cutoff: 'float | None' = None,
    n_bins: 'int' = 75,
    bin_size: 'float | None' = None,
    element_pairs: 'list[tuple[str, str]] | None' = None,
    reference_line: 'dict[str, Any] | None' = None,
    colors: 'Sequence[str] | None' = None,
    line_styles: 'Sequence[str] | None' = None,
    subplot_kwargs: 'dict[str, Any] | None' = None
) → Figure

Generate a plotly figure of pairwise radial distribution functions (RDFs) for all (or a subset of) element pairs in one or multiple structures.

Args:

Returns:

Raises:


function full_rdf

full_rdf(
    structures: 'Structure | Sequence[Structure] | dict[str, Structure]',
    cutoff: 'float' = 15,
    n_bins: 'int' = 75,
    bin_size: 'float | None' = None,
    reference_line: 'dict[str, Any] | None' = None,
    colors: 'Sequence[str] | None' = None,
    line_styles: 'Sequence[str] | None' = None
) → Figure

Generate a plotly figure of full radial distribution functions (RDFs) for one or multiple structures.

Args:

Returns:

Raises:

module relevance

Plots for evaluating classifier performance.

Global Variables


function roc_curve

roc_curve(
    targets: 'ArrayLike | str',
    proba_pos: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None
) → tuple[float, Axes]

Plot the receiver operating characteristic curve of a binary classifier given target labels and predicted probabilities for the positive class.
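A minimal pure-Python sketch of what such a curve is built from (not pymatviz's implementation): sweep the decision threshold from the highest predicted probability down, record the false- and true-positive rates at each distinct score, and integrate with the trapezoidal rule to get the area under the curve (AUC). The helper names roc_points and trapezoid_auc are hypothetical.

```python
def roc_points(targets: list[int], proba_pos: list[float]) -> list[tuple[float, float]]:
    """(FPR, TPR) points obtained by lowering the threshold one distinct score at a time."""
    n_pos = sum(targets)
    n_neg = len(targets) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both positive and negative samples")
    pairs = sorted(zip(proba_pos, targets), reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    for idx, (score, label) in enumerate(pairs):
        tp += label
        fp += 1 - label
        # emit a point only after all samples tied at this score are counted
        if idx == len(pairs) - 1 or pairs[idx + 1][0] != score:
            points.append((fp / n_neg, tp / n_pos))
    return points


def trapezoid_auc(points: list[tuple[float, float]]) -> float:
    """Area under a piecewise-linear curve given as (x, y) points sorted by x."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))


points = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
auc = trapezoid_auc(points)  # 0.75
```

An AUC of 1 means every positive is ranked above every negative; 0.5 corresponds to random ranking.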

Args:

Returns:


function precision_recall_curve

precision_recall_curve(
    targets: 'ArrayLike | str',
    proba_pos: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None
) → tuple[float, Axes]

Plot the precision-recall curve of a binary classifier.
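The points on such a curve can be sketched in plain Python (hypothetical helpers pr_points and average_precision, not pymatviz's code): lower the threshold one distinct score at a time, compute recall and precision, then summarize the curve as the step-wise sum of (R_n - R_{n-1}) * P_n, which is the average-precision definition scikit-learn documents for average_precision_score.

```python
def pr_points(targets: list[int], proba_pos: list[float]) -> list[tuple[float, float]]:
    """(recall, precision) at each distinct threshold, from the highest score down."""
    n_pos = sum(targets)
    if n_pos == 0:
        raise ValueError("need at least one positive sample")
    pairs = sorted(zip(proba_pos, targets), reverse=True)
    points = []
    tp = fp = 0
    for idx, (score, label) in enumerate(pairs):
        tp += label
        fp += 1 - label
        # emit a point only after all samples tied at this score are counted
        if idx == len(pairs) - 1 or pairs[idx + 1][0] != score:
            points.append((tp / n_pos, tp / (tp + fp)))
    return points


def average_precision(points: list[tuple[float, float]]) -> float:
    """Step-wise area under the PR curve: sum of (R_n - R_{n-1}) * P_n."""
    ap, prev_recall = 0.0, 0.0
    for recall, precision in points:
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


points = pr_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
ap = average_precision(points)
```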

Args:

Returns:

module sankey

Sankey diagram for comparing distributions in two dataframe columns.

Global Variables


function sankey_from_2_df_cols

sankey_from_2_df_cols(
    df: 'DataFrame',
    cols: 'list[str]',
    labels_with_counts: "bool | Literal['percent']" = True,
    annotate_columns: 'bool | dict[str, Any]' = True,
    **kwargs: 'Any'
) → Figure

Plot two columns of a dataframe as a Plotly Sankey diagram.
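A Sankey diagram of two columns boils down to counting how often each (left value, right value) pair occurs. The sketch below is a hypothetical helper, sankey_links, written in plain Python rather than pymatviz's code. It builds node labels plus source/target/value index lists, the shape plotly's go.Sankey accepts via node=dict(label=...) and link=dict(source=..., target=..., value=...).

```python
from collections import Counter


def sankey_links(
    col_a: list[str], col_b: list[str]
) -> tuple[list[str], list[int], list[int], list[int]]:
    """Node labels and (source, target, value) index lists for a 2-column Sankey."""
    flows = Counter(zip(col_a, col_b))
    left, right = sorted(set(col_a)), sorted(set(col_b))
    # left and right nodes stay distinct even when they share a label
    labels = left + right
    left_idx = {label: idx for idx, label in enumerate(left)}
    right_idx = {label: len(left) + idx for idx, label in enumerate(right)}
    sources, targets, values = [], [], []
    for (a, b), count in sorted(flows.items()):
        sources.append(left_idx[a])
        targets.append(right_idx[b])
        values.append(count)
    return labels, sources, targets, values


labels, sources, targets, values = sankey_links(
    ["cubic", "cubic", "hexagonal"], ["cubic", "tetragonal", "hexagonal"]
)
```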

Args:

Raises:

Returns:

module scatter

Parity, residual and density plots.

Global Variables


function density_scatter

density_scatter(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    log_density: 'bool' = True,
    hist_density_kwargs: 'dict[str, Any] | None' = None,
    color_bar: 'bool | dict[str, Any]' = True,
    xlabel: 'str | None' = None,
    ylabel: 'str | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    stats: 'bool | dict[str, Any]' = True,
    **kwargs: 'Any'
) → Axes

Scatter plot colored by density using matplotlib backend.

Args:

Returns:


function density_scatter_plotly

density_scatter_plotly(
    df: 'DataFrame',
    x: 'str',
    y: 'str',
    density: "Literal['kde', 'empirical'] | None" = None,
    log_density: 'bool | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any] | None' = None,
    stats: 'bool | dict[str, Any]' = True,
    n_bins: 'int | None | Literal[False]' = None,
    bin_counts_col: 'str | None' = None,
    facet_col: 'str | None' = None,
    **kwargs: 'Any'
) → Figure

Scatter plot colored by density using plotly backend.

This function uses binning as implemented in bin_df_cols() to reduce the number of points plotted, which makes it possible to plot millions of data points and reduces the file size of interactive plots. Outlier points are plotted as is, but overlapping points (the tolerance for overlap is determined by n_bins) are merged into a single point, with a new column bin_counts_col counting the number of points in that bin.
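The merging idea can be sketched in plain Python on a regular n_bins x n_bins grid. This is a hypothetical bin_points helper, not the actual bin_df_cols() implementation: points alone in their grid cell are kept at their original position, while points sharing a cell collapse to the cell center together with their count.

```python
from collections import defaultdict


def bin_points(
    xs: list[float], ys: list[float], n_bins: int = 100
) -> list[tuple[float, float, int]]:
    """Return (x, y, count) per occupied grid cell; singletons keep their coordinates."""
    x_min, y_min = min(xs), min(ys)
    # "or 1.0" guards against a zero-width data range
    x_step = (max(xs) - x_min) / n_bins or 1.0
    y_step = (max(ys) - y_min) / n_bins or 1.0
    cells: dict[tuple[int, int], list[tuple[float, float]]] = defaultdict(list)
    for x, y in zip(xs, ys):
        ix = min(int((x - x_min) / x_step), n_bins - 1)
        iy = min(int((y - y_min) / y_step), n_bins - 1)
        cells[(ix, iy)].append((x, y))
    merged = []
    for (ix, iy), pts in cells.items():
        if len(pts) == 1:  # isolated points (e.g. outliers) are kept as is
            merged.append((*pts[0], 1))
        else:  # overlapping points become one point at the cell center
            merged.append((x_min + (ix + 0.5) * x_step, y_min + (iy + 0.5) * y_step, len(pts)))
    return merged


merged = bin_points([0, 0.001, 0.002, 10], [0, 0.001, 0.002, 10], n_bins=10)
```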

Args:

Returns:


function scatter_with_err_bar

scatter_with_err_bar(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    xerr: 'ArrayLike | None' = None,
    yerr: 'ArrayLike | None' = None,
    ax: 'Axes | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    xlabel: 'str' = 'Actual',
    ylabel: 'str' = 'Predicted',
    title: 'str | None' = None,
    **kwargs: 'Any'
) → Axes

Scatter plot with optional x- and/or y-error bars. Useful when passing model uncertainties as yerr=y_std to check whether uncertainty correlates with error, i.e. whether points farther from the parity line have larger uncertainty.

Args:

Returns:


function density_hexbin

density_hexbin(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    weights: 'ArrayLike | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    xlabel: 'str' = 'Actual',
    ylabel: 'str' = 'Predicted',
    cbar_label: 'str | None' = 'Density',
    cbar_coords: 'tuple[float, float, float, float]' = (0.95, 0.03, 0.03, 0.7),
    **kwargs: 'Any'
) → Axes

Hexagonal-grid scatter plot colored by point density or by a third dimension passed as weights.

Args:

Returns:


function density_scatter_with_hist

density_scatter_with_hist(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Scatter plot colored (and optionally sorted) by density with histograms along each dimension.


function density_hexbin_with_hist

density_hexbin_with_hist(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    **kwargs: 'Any'
) → Axes

Hexagonal-grid scatter plot colored by density or by a third dimension passed as color_by, with histograms along each dimension.


function residual_vs_actual

residual_vs_actual(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    xlabel: 'str' = 'Actual value',
    ylabel: 'str' = 'Residual ($y_\\mathrm{true} - y_\\mathrm{pred}$)',
    **kwargs: 'Any'
) → Axes

Plot targets on the x-axis vs residuals (y_err = y_true - y_pred) on the y-axis.

Args:

Returns:

module structure_viz.helpers

Helper functions for 2D and 3D plots of pymatgen structures with plotly.

Global Variables


function get_image_sites

get_image_sites(
    site: 'PeriodicSite',
    lattice: 'Lattice',
    tol: 'float' = 0.02
) → ndarray

Get images for a given site in a lattice.

Images are sites that are integer translations of the given site that are within a tolerance of the unit cell edges.
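The idea can be sketched in fractional coordinates with plain Python. This is a hypothetical image_frac_coords helper, not pymatviz's get_image_sites (which takes a PeriodicSite and a Lattice and returns an ndarray): a coordinate within tol of 0 gets a +1 translation along that axis, one within tol of 1 gets a -1 translation, and all non-zero combinations of these shifts are returned.

```python
from itertools import product


def image_frac_coords(
    frac_coords: tuple[float, float, float], tol: float = 0.02
) -> list[tuple[float, ...]]:
    """Integer translations of a site that land within tol of the unit cell edges."""
    shifts_per_axis = []
    for coord in frac_coords:
        shifts = [0]
        if abs(coord) < tol:  # near the 0 face: the image sits near the 1 face
            shifts.append(1)
        if abs(coord - 1) < tol:  # near the 1 face: the image sits near the 0 face
            shifts.append(-1)
        shifts_per_axis.append(shifts)
    return [
        tuple(c + s for c, s in zip(frac_coords, shift))
        for shift in product(*shifts_per_axis)
        if any(shift)  # skip the zero translation, i.e. the site itself
    ]


corner_images = image_frac_coords((0.0, 0.0, 0.0))  # 7 images at the other cube corners
```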

Args:

Returns:


function unit_cell_to_lines

unit_cell_to_lines(cell: 'ArrayLike') → tuple[ArrayLike, ArrayLike, ArrayLike]

Convert lattice vectors to plot lines.

Args:

Returns:
 tuple[np.array, np.array, np.array]:
 - Lines
 - z-indices that sort plot elements into out-of-plane layers
 - lines used to plot the unit cell
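The geometric core of this conversion is that a cell spanned by three lattice vectors has 8 corners and 12 edges. A minimal plain-Python sketch (hypothetical cell_edges helper; it omits the z-ordering that unit_cell_to_lines also returns) connects every pair of corners that differ by exactly one lattice vector:

```python
from itertools import product


def cell_edges(lattice: list[list[float]]) -> list[tuple[tuple[float, ...], tuple[float, ...]]]:
    """(start, end) Cartesian points of the 12 edges of the cell spanned by 3 row vectors."""
    a, b, c = lattice

    def corner_point(i: int, j: int, k: int) -> tuple[float, ...]:
        return tuple(i * a[n] + j * b[n] + k * c[n] for n in range(3))

    edges = []
    for corner in product((0, 1), repeat=3):
        for axis in range(3):
            if corner[axis] == 0:  # each edge is emitted once, from its lower corner
                end = list(corner)
                end[axis] = 1
                edges.append((corner_point(*corner), corner_point(*end)))
    return edges


edges = cell_edges([[2, 0, 0], [0, 2, 0], [0, 0, 2]])
```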


function get_elem_colors

get_elem_colors(
    elem_colors: 'ElemColorScheme | dict[str, str]'
) → dict[str, str]

Get element colors based on the provided scheme or custom dictionary.


function get_atomic_radii

get_atomic_radii(
    atomic_radii: 'float | dict[str, float] | None'
) → dict[str, float]

Get atomic radii based on the provided input.


function generate_site_label

generate_site_label(
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]",
    site_idx: 'int',
    majority_species: 'Species'
) → str

Generate a label for a site based on the provided labeling scheme.


function get_subplot_title

get_subplot_title(
    struct_i: 'Structure',
    struct_key: 'Any',
    idx: 'int',
    subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None'
) → dict[str, Any]

Generate a subplot title based on the provided function or default logic.


function get_site_hover_text

get_site_hover_text(
    site: 'PeriodicSite',
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]',
    majority_species: 'Species'
) → str

Generate hover text for a site based on the hover template.


function draw_site

draw_site(
    fig: 'Figure',
    site: 'PeriodicSite',
    coords: 'ndarray',
    site_idx: 'int',
    site_labels: 'Any',
    _elem_colors: 'dict[str, str]',
    _atomic_radii: 'dict[str, float]',
    atom_size: 'float',
    scale: 'float',
    site_kwargs: 'dict[str, Any]',
    is_image: 'bool' = False,
    is_3d: 'bool' = False,
    row: 'int | None' = None,
    col: 'int | None' = None,
    scene: 'str | None' = None,
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
    **kwargs: 'Any'
) → None

Add a site (regular or image) to the plot.


function get_structures

get_structures(
    struct: 'Structure | Sequence[Structure] | Series | dict[Any, Structure]'
) → dict[Any, Structure]

Convert various input types to a dictionary of structures.


function draw_unit_cell

draw_unit_cell(
    fig: 'Figure',
    structure: 'Structure',
    unit_cell_kwargs: 'dict[str, Any]',
    is_3d: 'bool' = True,
    row: 'int | None' = None,
    col: 'int | None' = None,
    scene: 'str | None' = None
) → Figure

Draw the unit cell of a structure in a 2D or 3D Plotly figure.


function draw_vector

draw_vector(
    fig: 'Figure',
    start: 'ndarray',
    vector: 'ndarray',
    is_3d: 'bool' = False,
    arrow_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → None

Add an arrow to represent a vector quantity on a Plotly figure.

This function adds an arrow to a 2D or 3D Plotly figure to represent a vector quantity. In 3D, it uses a cone for the arrowhead and a line for the shaft. In 2D, it uses a scatter plot with an arrow marker.

Args:

Note:

For 3D arrows, this function adds two traces to the figure: a cone for the arrowhead and a line for the shaft. For 2D arrows, it adds a single scatter trace with an arrow marker.


function get_first_matching_site_prop

get_first_matching_site_prop(
    structures: 'Sequence[Structure]',
    prop_keys: 'Sequence[str]',
    warn_if_none: 'bool' = True,
    filter_callback: 'Callable[[str, Any], bool] | None' = None
) → str | None

Find the first property key that exists in any of the passed structures' properties or site properties. Will look in site.properties first, then structure.properties.

Args:

Returns:


function draw_bonds

draw_bonds(
    fig: 'Figure',
    structure: 'Structure',
    nn: 'NearNeighbors',
    is_3d: 'bool' = True,
    bond_kwargs: 'dict[str, Any] | None' = None,
    row: 'int | None' = None,
    col: 'int | None' = None,
    scene: 'str | None' = None,
    visible_image_atoms: 'set[tuple[float, float, float]] | None' = None
) → None

Draw bonds between atoms in the structure.

module structure_viz

2D and 3D plots of Structures.

module structure_viz.mpl

2D plots of pymatgen structures with matplotlib.

structure_2d() and its helpers get_rot_matrix() and unit_cell_to_lines() were inspired by ASE https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib.

Global Variables


function structure_2d

structure_2d(
    struct: 'Structure | Sequence[Structure]',
    ax: 'Axes | None' = None,
    rotation: 'str' = '10x,8y,3z',
    atomic_radii: 'float | dict[str, float] | None' = None,
    elem_colors: 'ElemColorScheme | dict[str, str | ColorType]' = Jmol,
    scale: 'float' = 1,
    show_unit_cell: 'bool' = True,
    show_bonds: 'bool | NearNeighbors' = False,
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
    label_kwargs: 'dict[str, Any] | None' = None,
    bond_kwargs: 'dict[str, Any] | None' = None,
    standardize_struct: 'bool | None' = None,
    axis: 'bool | str' = 'off',
    n_cols: 'int' = 4,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    subplot_title: 'Callable[[Structure, str | int], str] | None' = None
) → Axes | tuple[Figure, ndarray[Axes]]

Plot pymatgen structures in 2D with matplotlib.

structure_2d is not deprecated but structure_(2d|3d)_plotly() have more features and are recommended replacements.

Inspired by ASE's ase.visualize.plot.plot_atoms() (https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib). pymatviz aims to give similar output to ASE but supports disordered structures and avoids the conversion hassle of AseAtomsAdaptor.get_atoms(pmg_struct).

For example, these two snippets should give very similar output:

from pymatgen.ext.matproj import MPRester

mp_19017 = MPRester().get_structure_by_material_id("mp-19017")

# ASE
from ase.visualize.plot import plot_atoms
from pymatgen.io.ase import AseAtomsAdaptor

plot_atoms(AseAtomsAdaptor().get_atoms(mp_19017), rotation="10x,8y,3z", radii=0.5)

# pymatviz
from pymatviz import structure_2d

structure_2d(mp_19017)

Example with multiple structures in a single figure:

from pymatgen.ext.matproj import MPRester
from pymatviz import structure_2d

structures = {
     (mp_id := f"mp-{idx}"): MPRester().get_structure_by_material_id(mp_id)
     for idx in range(1, 5)
}
structure_2d(structures)

Args:

Raises:

Returns:

module structure_viz.plotly

Create interactive hoverable 2D and 3D plots of pymatgen structures with plotly.

Global Variables


function structure_2d_plotly

structure_2d_plotly(
    struct: 'Structure | Sequence[Structure]',
    rotation: 'str' = '10x,8y,3z',
    atomic_radii: 'float | dict[str, float] | None' = None,
    atom_size: 'float' = 40,
    elem_colors: 'ElemColorScheme | dict[str, str]' = Jmol,
    scale: 'float' = 1,
    show_unit_cell: 'bool | dict[str, Any]' = True,
    show_sites: 'bool | dict[str, Any]' = True,
    show_image_sites: 'bool | dict[str, Any]' = True,
    show_bonds: 'bool | NearNeighbors' = False,
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
    standardize_struct: 'bool | None' = None,
    n_cols: 'int' = 3,
    subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None' = None,
    show_site_vectors: 'str | Sequence[str]' = ('force', 'magmom'),
    vector_kwargs: 'dict[str, dict[str, Any]] | None' = None,
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
    bond_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot pymatgen structures in 2D with Plotly.

Args:

Returns:


function structure_3d_plotly

structure_3d_plotly(
    struct: 'Structure | Sequence[Structure]',
    atomic_radii: 'float | dict[str, float] | None' = None,
    atom_size: 'float' = 20,
    elem_colors: 'ElemColorScheme | dict[str, str]' = Jmol,
    scale: 'float' = 1,
    show_unit_cell: 'bool | dict[str, Any]' = True,
    show_sites: 'bool | dict[str, Any]' = True,
    show_image_sites: 'bool' = True,
    show_bonds: 'bool | NearNeighbors' = False,
    site_labels: "Literal['symbol', 'species', False] | dict[str, str] | Sequence[str]" = 'species',
    standardize_struct: 'bool | None' = None,
    n_cols: 'int' = 3,
    subplot_title: 'Callable[[Structure, str | int], str | dict[str, Any]] | None | Literal[False]' = None,
    show_site_vectors: 'str | Sequence[str]' = ('force', 'magmom'),
    vector_kwargs: 'dict[str, dict[str, Any]] | None' = None,
    hover_text: 'SiteCoords | Callable[[PeriodicSite], str]' = Cartesian and Fractional,
    bond_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot pymatgen structures in 3D with Plotly.

Args:

Returns:

module sunburst

Hierarchical multi-level pie charts (i.e. sunbursts).

E.g. for crystal symmetry distributions.

Global Variables


function spacegroup_sunburst

spacegroup_sunburst(
    data: 'Sequence[int | str] | Series',
    show_counts: "Literal['value', 'percent', False]" = False,
    **kwargs: 'Any'
) → Figure

Generate a sunburst plot with crystal systems as the inner ring for a list of international space group numbers.
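The inner ring relies on the fixed mapping from international space group numbers to the seven crystal systems. A small sketch of that mapping and of the counts behind the two rings (hypothetical names crystal_system, LAST_SPG_NUM and CRYSTAL_SYSTEMS, not pymatviz's internals):

```python
from bisect import bisect_left
from collections import Counter

# last international space group number of each crystal system
LAST_SPG_NUM = [2, 15, 74, 142, 167, 194, 230]
CRYSTAL_SYSTEMS = [
    "triclinic", "monoclinic", "orthorhombic", "tetragonal", "trigonal", "hexagonal", "cubic",
]


def crystal_system(spg_num: int) -> str:
    """Map an international space group number (1-230) to its crystal system."""
    if not 1 <= spg_num <= 230:
        raise ValueError(f"invalid space group number: {spg_num}")
    return CRYSTAL_SYSTEMS[bisect_left(LAST_SPG_NUM, spg_num)]


spg_nums = [225, 225, 62, 14, 194]
inner_ring = Counter(crystal_system(num) for num in spg_nums)  # crystal systems
outer_ring = Counter(spg_nums)  # individual space groups
```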

Hint: To hide very small labels, set a uniformtext minsize and mode='hide':

    fig.update_layout(uniformtext=dict(minsize=9, mode="hide"))

Args:

Returns:

module templates

Define custom pymatviz templates (default styles) for plotly and matplotlib.

Global Variables


function set_plotly_template

set_plotly_template(
    template: "Literal['pymatviz_white', 'pymatviz_dark'] | str | Template"
) → None

Set the default plotly express and graph objects template.

Args:

Raises:

module typing

Typing related: TypeAlias, generic types and so on.

Global Variables


function runtime_checkable

runtime_checkable(cls)

Mark a protocol class as a runtime protocol.

Such a protocol can be used with isinstance() and issubclass(). Raises TypeError if applied to a non-protocol class. This allows a simple-minded structural check very similar to the one-trick ponies in collections.abc such as Iterable.

For example:

    @runtime_checkable
    class Closable(Protocol):
        def close(self): ...

    assert isinstance(open('/some/file'), Closable)

Warning: this will check only the presence of the required methods, not their type signatures!
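The warning can be demonstrated directly. In this sketch (all class names are illustrative), a class whose close() has a completely different signature still passes the isinstance() check, because only the presence of the method is tested:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Closable(Protocol):
    def close(self) -> None: ...


class Resource:
    def close(self) -> None:
        pass


class WrongSignature:
    # parameters and return type differ from Closable.close, but the method exists
    def close(self, force: bool, timeout: float) -> int:
        return 0


class NoClose:
    pass
```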


function cast

cast(typ, val)

Cast a value to a type.

This returns the value unchanged. To the type checker this signals that the return value has the designated type, but at runtime we intentionally don't check anything (we want this to be as fast as possible).
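A short sketch of that runtime behavior (parse_count is an illustrative name): cast() hands back the very same object, even when it does not match the declared type.

```python
from typing import cast


def parse_count(raw: object) -> int:
    # the caller promises raw is already an int; cast() only informs the type checker
    return cast(int, raw)


value = parse_count("not actually an int")  # no runtime error: the argument comes back as is
```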


function assert_type

assert_type(val, typ)

Ask a static type checker to confirm that the value is of the given type.

At runtime this does nothing: it returns the first argument unchanged with no checks or side effects, no matter the actual type of the argument.

When a static type checker encounters a call to assert_type(), it emits an error if the value is not of the specified type:


     def greet(name: str) -> None:          assert_type(name, str)  # OK          assert_type(name, int)  # type checker error 


function get_type_hints

get_type_hints(obj, globalns=None, localns=None, include_extras=False)

Return type hints for an object.

This is often the same as obj.__annotations__, but it handles forward references encoded as string literals and recursively replaces all Annotated[T, ...] with T (unless include_extras=True).

The argument may be a module, class, method, or function. The annotations are returned as a dictionary. For classes, annotations include also inherited members.

TypeError is raised if the argument is not of a type that can contain annotations, and an empty dictionary is returned if no annotations are present.

BEWARE -- the behavior of globalns and localns is counterintuitive (unless you are familiar with how eval() and exec() work). The search order is locals first, then globals.
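A small example of both behaviors described above (the function scale is illustrative): the string annotation "int" is resolved to the int class, and Annotated metadata is stripped unless include_extras=True.

```python
from typing import Annotated, get_type_hints


def scale(x: Annotated[float, "unit: eV"], factor: "int" = 2) -> float:
    return x * factor


raw = scale.__annotations__  # "factor" is still the string "int" here
hints = get_type_hints(scale)  # forward reference resolved, Annotated stripped
hints_with_extras = get_type_hints(scale, include_extras=True)  # metadata kept
```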


function get_origin

get_origin(tp)

Get the unsubscripted version of a type.

This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar, Annotated, and others. Return None for unsupported types.

Examples:

    >>> P = ParamSpec('P')
    >>> assert get_origin(Literal[42]) is Literal
    >>> assert get_origin(int) is None
    >>> assert get_origin(ClassVar[int]) is ClassVar
    >>> assert get_origin(Generic) is Generic
    >>> assert get_origin(Generic[T]) is Generic
    >>> assert get_origin(Union[T, int]) is Union
    >>> assert get_origin(List[Tuple[T, T]][int]) is list
    >>> assert get_origin(P.args) is P


function get_args

get_args(tp)

Get type arguments with all substitutions performed.

For unions, basic simplifications used by Union constructor are performed.

Examples:

    >>> T = TypeVar('T')
    >>> assert get_args(Dict[str, int]) == (str, int)
    >>> assert get_args(int) == ()
    >>> assert get_args(Union[int, Union[T, int], str][int]) == (int, str)
    >>> assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
    >>> assert get_args(Callable[[], T][int]) == ([], int)


function is_typeddict

is_typeddict(tp)

Check if an annotation is a TypedDict class.

For example:

    >>> from typing import TypedDict
    >>> class Film(TypedDict):
    ...     title: str
    ...     year: int
    ...
    >>> is_typeddict(Film)
    True
    >>> is_typeddict(dict)
    False


function assert_never

assert_never(arg: Never) → Never

Statically assert that a line of code is unreachable.

Example:

    def int_or_str(arg: int | str) -> None:
        match arg:
            case int():
                print("It's an int")
            case str():
                print("It's a str")
            case _:
                assert_never(arg)

If a type checker finds that a call to assert_never() is reachable, it will emit an error.

At runtime, this throws an exception when called.


function no_type_check

no_type_check(arg)

Decorator to indicate that annotations are not type hints.

The argument must be a class or function; if it is a class, it applies recursively to all methods and classes defined in that class (but not to methods defined in its superclasses or subclasses).

This mutates the function(s) or class(es) in place.
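A sketch of the effect (legacy is an illustrative name): the decorator marks the function in place with __no_type_check__, and get_type_hints() then returns an empty dict instead of trying to evaluate annotations that are not types.

```python
from typing import get_type_hints, no_type_check


@no_type_check
def legacy(x: "metres", y: "seconds") -> "metres per second":
    # annotations here are plain documentation strings, not types
    return x / y
```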


function no_type_check_decorator

no_type_check_decorator(decorator)

Decorator to give another decorator the @no_type_check effect.

This wraps the decorator with something that wraps the decorated function in @no_type_check.


function overload

overload(func)

Decorator for overloaded functions/methods.

In a stub file, place two or more stub definitions for the same function in a row, each decorated with @overload.

For example:

    @overload
    def utf8(value: None) -> None: ...
    @overload
    def utf8(value: bytes) -> bytes: ...
    @overload
    def utf8(value: str) -> bytes: ...

In a non-stub file (i.e. a regular .py file), do the same but follow it with an implementation. The implementation should not be decorated with @overload:

    @overload
    def utf8(value: None) -> None: ...
    @overload
    def utf8(value: bytes) -> bytes: ...
    @overload
    def utf8(value: str) -> bytes: ...
    def utf8(value):
        ...  # implementation goes here

The overloads for a function can be retrieved at runtime using the get_overloads() function.


function get_overloads

get_overloads(func)

Return all defined overloads for func as a sequence.


function clear_overloads

clear_overloads()

Clear all overloads in the registry.


function final

final(f)

Decorator to indicate final methods and final classes.

Use this decorator to indicate to type checkers that the decorated method cannot be overridden, and decorated class cannot be subclassed.

For example:

    class Base:
        @final
        def done(self) -> None:
            ...

    class Sub(Base):
        def done(self) -> None:  # Error reported by type checker
            ...

    @final
    class Leaf:
        ...

    class Other(Leaf):  # Error reported by type checker
        ...

There is no runtime checking of these properties. The decorator attempts to set the __final__ attribute to True on the decorated object to allow runtime introspection.


function NamedTuple

NamedTuple(typename, fields=None, **kwargs)

Typed version of namedtuple.

Usage:

    class Employee(NamedTuple):
        name: str
        id: int

This is equivalent to:

    Employee = collections.namedtuple('Employee', ['name', 'id'])

The resulting class has an extra __annotations__ attribute, giving a dict that maps field names to types. (The field names are also in the _fields attribute, which is part of the namedtuple API.) An alternative equivalent functional syntax is also accepted:

    Employee = NamedTuple('Employee', [('name', str), ('id', int)])
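A quick check that the class syntax, collections.namedtuple, and the functional syntax produce equivalent field layouts (the PlainEmployee and FunctionalEmployee names are illustrative):

```python
import collections
from typing import NamedTuple


class Employee(NamedTuple):
    name: str
    id: int


PlainEmployee = collections.namedtuple("PlainEmployee", ["name", "id"])
FunctionalEmployee = NamedTuple("FunctionalEmployee", [("name", str), ("id", int)])

alice = Employee("Alice", 1)  # still a regular tuple at runtime
```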

function TypedDict

TypedDict(typename, fields=None, total=True, **kwargs)

A simple typed namespace. At runtime it is equivalent to a plain dict.

TypedDict creates a dictionary type such that a type checker will expect all instances to have a certain set of keys, where each key is associated with a value of a consistent type. This expectation is not checked at runtime.

Usage:

    >>> class Point2D(TypedDict):
    ...     x: int
    ...     y: int
    ...     label: str
    ...
    >>> a: Point2D = {'x': 1, 'y': 2, 'label': 'good'}  # OK
    >>> b: Point2D = {'z': 3, 'label': 'bad'}           # Fails type check
    >>> Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
    True

The type info can be accessed via the Point2D.__annotations__ dict, and the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. TypedDict supports an additional equivalent form:

    Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})

By default, all keys must be present in a TypedDict. It is possible to override this by specifying totality:


    class Point2D(TypedDict, total=False):
        x: int
        y: int

This means that a Point2D TypedDict can have any of the keys omitted. A type checker is only expected to support a literal False or True as the value of the total argument. True is the default, and makes all items defined in the class body be required.

The Required and NotRequired special forms can also be used to mark individual keys as being required or not required:


    class Point2D(TypedDict):
        x: int               # the "x" key must always be present (Required is the default)
        y: NotRequired[int]  # the "y" key can be omitted

See PEP 655 for more details on Required and NotRequired.


function reveal_type

reveal_type(obj: ~T) → ~T

Ask a static type checker to reveal the inferred type of an expression.

When a static type checker encounters a call to reveal_type(), it will emit the inferred type of the argument:


    x: int = 1
    reveal_type(x)

Running a static type checker (e.g., mypy) on this example will produce output similar to 'Revealed type is "builtins.int"'.

At runtime, the function prints the runtime type of the argument and returns the argument unchanged.


function dataclass_transform

dataclass_transform(
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_specifiers: tuple[Union[type[Any], Callable[..., Any]], ...] = (),
    **kwargs: Any
) → Callable[[~T], ~T]

Decorator to mark an object as providing dataclass-like behaviour.

The decorator can be applied to a function, class, or metaclass.

Example usage with a decorator function:

    T = TypeVar("T")

    @dataclass_transform()
    def create_model(cls: type[T]) -> type[T]:
        ...
        return cls

    @create_model
    class CustomerModel:
        id: int
        name: str

On a base class:

    @dataclass_transform()
    class ModelBase: ...

    class CustomerModel(ModelBase):
        id: int
        name: str

On a metaclass:

    @dataclass_transform()
    class ModelMeta(type): ...

    class ModelBase(metaclass=ModelMeta): ...

    class CustomerModel(ModelBase):
        id: int
        name: str

The CustomerModel classes defined above will be treated by type checkers similarly to classes created with @dataclasses.dataclass. For example, type checkers will assume these classes have __init__ methods that accept id and name.

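The arguments to this decorator can be used to customize this behavior:

- eq_default indicates whether the eq parameter is assumed to be True or False if it is omitted by the caller.
- order_default indicates whether the order parameter is assumed to be True or False if it is omitted by the caller.
- kw_only_default indicates whether the kw_only parameter is assumed to be True or False if it is omitted by the caller.
- field_specifiers specifies a static list of supported classes or functions that describe fields, similar to dataclasses.field().
- Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.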

At runtime, this decorator records its arguments in the __dataclass_transform__ attribute on the decorated object. It has no other runtime effect.

See PEP 681 for more details.


class Annotated

Add context-specific metadata to a type.

Example: Annotated[int, runtime_check.Unsigned] indicates to the hypothetical runtime_check module that this type is an unsigned int. Every other consumer of this type can ignore this metadata and treat this type as int.

The first argument to Annotated must be a valid type.

Details:

- Access the metadata via the __metadata__ attribute:

    assert Annotated[int, '$'].__metadata__ == ('$',)

- Nested Annotated types are flattened:

    assert Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3]

- Instantiating an annotated type is equivalent to instantiating the underlying type:

    assert Annotated[C, Ann1](5) == C(5)

- Annotated can be used as a generic type alias:

    Optimized: TypeAlias = Annotated[T, runtime.Optimize()]
    assert Optimized[int] == Annotated[int, runtime.Optimize()]

    OptimizedList: TypeAlias = Annotated[list[T], runtime.Optimize()]
    assert OptimizedList[int] == Annotated[list[int], runtime.Optimize()]

- Annotated cannot be used with an unpacked TypeVarTuple:

    Variadic: TypeAlias = Annotated[*Ts, Ann1]  # NOT valid

  This would be equivalent to:

    Annotated[T1, T2, T3, ..., Ann1]

  where T1, T2 etc. are TypeVars, which would be invalid, because only one type should be passed to Annotated.


class Any

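Special type indicating an unconstrained type.

- Any is compatible with every type.
- Any is assumed to have all methods.
- All values are assumed to be instances of Any.

Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.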


class BinaryIO

Typed version of the return of open() in binary mode.


property closed


property mode


property name


method close

close() → None

method fileno

fileno() → int

method flush

flush() → None

method isatty

isatty() → bool

method read

read(n: int = -1) → ~AnyStr

method readable

readable() → bool

method readline

readline(limit: int = -1) → ~AnyStr

method readlines

readlines(hint: int = -1) → List[~AnyStr]

method seek

seek(offset: int, whence: int = 0) → int

method seekable

seekable() → bool

method tell

tell() → int

method truncate

truncate(size: int = None) → int

method writable

writable() → bool

method write

write(s: Union[bytes, bytearray]) → int

method writelines

writelines(lines: List[~AnyStr]) → None

class ForwardRef

Internal wrapper to hold a forward reference.

method __init__

__init__(arg, is_argument=True, module=None, is_class=False)

class Generic

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as:


    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.

This class can then be used as follows:

    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
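A runnable variant of the Mapping example, assuming a hypothetical dict-backed class `Registry` in place of the abstract `Mapping`:

```python
from typing import Dict, Generic, TypeVar

KT = TypeVar("KT")
VT = TypeVar("VT")


class Registry(Generic[KT, VT]):
    """Hypothetical dict-backed generic mapping; KT and VT are fixed per annotation."""

    def __init__(self) -> None:
        self._data: Dict[KT, VT] = {}

    def __setitem__(self, key: KT, value: VT) -> None:
        self._data[key] = value

    def __getitem__(self, key: KT) -> VT:
        return self._data[key]


def lookup_name(mapping: Registry[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default


ages: Registry[str, int] = Registry()
ages["alice"] = 30
```

A type checker infers `int` for `lookup_name(ages, "bob", -1)` because `VT` is bound to `int` by the `Registry[str, int]` annotation.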


class IO

Generic base class for TextIO and BinaryIO.

This is an abstract, generic version of the return of open().

NOTE: This does not distinguish between the different possible classes (text vs. binary, read vs. write vs. read/write, append-only, unbuffered). The TextIO and BinaryIO subclasses below capture the distinctions between text vs. binary, which is pervasive in the interface; however we currently do not offer a way to track the other distinctions in the type system.


property closed


property mode


property name


method close

close() → None

method fileno

fileno() → int

method flush

flush() → None

method isatty

isatty() → bool

method read

read(n: int = -1) → ~AnyStr

method readable

readable() → bool

method readline

readline(limit: int = -1) → ~AnyStr

method readlines

readlines(hint: int = -1) → List[~AnyStr]

method seek

seek(offset: int, whence: int = 0) → int

method seekable

seekable() → bool

method tell

tell() → int

method truncate

truncate(size: int = None) → int

method writable

writable() → bool

method write

write(s: ~AnyStr) → int

method writelines

writelines(lines: List[~AnyStr]) → None
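Because `IO` is generic in `AnyStr`, one annotation can cover both text and binary streams while still requiring that source and destination agree. A minimal sketch with a hypothetical `copy_stream` helper:

```python
import io
from typing import IO, AnyStr


def copy_stream(src: IO[AnyStr], dst: IO[AnyStr]) -> int:
    """Copy everything from src to dst and return the number of items written.

    AnyStr ties the two parameters together: both must be text (str) or both
    binary (bytes); a type checker rejects mixing IO[str] with IO[bytes].
    """
    return dst.write(src.read())
```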

class NamedTupleMeta


class NewType

NewType creates simple unique types with almost zero runtime overhead.

NewType(name, tp) is considered a subtype of tp by static type checkers. At runtime, NewType(name, tp) returns a dummy callable that simply returns its argument.

Usage:


    UserId = NewType('UserId', int)

    def name_by_id(user_id: UserId) -> str:
        ...

    UserId('user')          # Fails type check

    name_by_id(42)          # Fails type check
    name_by_id(UserId(42))  # OK

    num = UserId(5) + 1     # type: int
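At runtime the calls in the usage block are cheap: `UserId(42)` simply returns the int 42, and the distinction exists only for the type checker. A minimal sketch; the lookup table inside `name_by_id` is hypothetical:

```python
from typing import NewType

UserId = NewType("UserId", int)


def name_by_id(user_id: UserId) -> str:
    # Hypothetical lookup table, used only for illustration.
    names = {1: "alice", 2: "bob"}
    return names.get(user_id, "unknown")


uid = UserId(2)  # at runtime this is the plain int 2
num = UserId(5) + 1  # arithmetic yields a plain int
```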

method __init__

__init__(name, tp)

class ParamSpec

Parameter specification variable.

Usage:


    P = ParamSpec('P') 

Parameter specification variables exist primarily for the benefit of static type checkers. They are used to forward the parameter types of one callable to another callable, a pattern commonly found in higher order functions and decorators. They are only valid when used in Concatenate, or as the first argument to Callable, or as parameters for user-defined Generics. See class Generic for more information on generic types. An example for annotating a decorator:


    T = TypeVar('T')
    P = ParamSpec('P')

    def add_logging(f: Callable[P, T]) -> Callable[P, T]:
        '''A type-safe decorator to add logging to a function.'''
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            logging.info(f'{f.__name__} was called')
            return f(*args, **kwargs)
        return inner

    @add_logging
    def add_two(x: float, y: float) -> float:
        '''Add two numbers together.'''
        return x + y

Parameter specification variables can be introspected. e.g.:

    P.__name__ == 'P'

Note that only parameter specification variables defined in global scope can be pickled.

method __init__

__init__(name, bound=None, covariant=False, contravariant=False)

property args


property kwargs


class ParamSpecArgs

The args for a ParamSpec object.

Given a ParamSpec object P, P.args is an instance of ParamSpecArgs.

ParamSpecArgs objects have a reference back to their ParamSpec:

    P.args.__origin__ is P

This type is meant for runtime introspection and has no special meaning to static type checkers.

method __init__

__init__(origin)

class ParamSpecKwargs

The kwargs for a ParamSpec object.

Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs.

ParamSpecKwargs objects have a reference back to their ParamSpec:

    P.kwargs.__origin__ is P

This type is meant for runtime introspection and has no special meaning to static type checkers.

method __init__

__init__(origin)

class Protocol

Base class for protocol classes.

Protocol classes are defined as:


    class Proto(Protocol):
        def meth(self) -> int:
            ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example:

    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:


    class GenProto(Protocol[T]):
        def meth(self) -> T:
            ...
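A runnable sketch of the structural check, with `@runtime_checkable` added so `isinstance` works; `NoMeth` is a hypothetical counter-example. As described above, the runtime check only looks for the presence of `meth`, not its signature.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Proto(Protocol):
    def meth(self) -> int: ...


class C:
    # C never inherits from Proto; it matches structurally because it defines meth.
    def meth(self) -> int:
        return 0


class NoMeth:
    pass


def func(x: Proto) -> int:
    return x.meth()
```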





class SupportsAbs

An ABC with one abstract method __abs__ that is covariant in its return type.

function _no_init_or_replace_init

_no_init_or_replace_init(*args, **kwargs)

class SupportsBytes

An ABC with one abstract method __bytes__.

function _no_init_or_replace_init

_no_init_or_replace_init(*args, **kwargs)

class SupportsComplex

An ABC with one abstract method __complex__.

function _no_init_or_replace_init

_no_init_or_replace_init(*args, **kwargs)

class SupportsFloat

An ABC with one abstract method __float__.

function _no_init_or_replace_init

_no_init_or_replace_init(*args, **kwargs)

class SupportsIndex

An ABC with one abstract method __index__.

function _no_init_or_replace_init

_no_init_or_replace_init(*args, **kwargs)

class SupportsInt

An ABC with one abstract method __int__.

function _no_init_or_replace_init

_no_init_or_replace_init(*args, **kwargs)

class SupportsRound

An ABC with one abstract method __round__ that is covariant in its return type.

function _no_init_or_replace_init

_no_init_or_replace_init(*args, **kwargs)
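These ABCs are decorated with `@runtime_checkable`, so besides annotating parameters they can be used in `isinstance` checks, which only test whether the corresponding dunder method exists. A minimal sketch with hypothetical helpers:

```python
from typing import SupportsAbs, SupportsFloat, SupportsIndex, SupportsInt, TypeVar

T = TypeVar("T")


def magnitude(x: SupportsAbs[T]) -> T:
    """Return abs(x); the result type follows the (covariant) return type of __abs__."""
    return abs(x)


def to_float(x: SupportsFloat) -> float:
    return float(x)


def supported_protocols(value: object) -> list[str]:
    """Runtime check: which protocols does value satisfy (dunder presence only)?"""
    checks = {
        "SupportsFloat": SupportsFloat,
        "SupportsInt": SupportsInt,
        "SupportsIndex": SupportsIndex,
    }
    return [name for name, proto in checks.items() if isinstance(value, proto)]
```

For example, `float` has `__float__` and `__int__` but no `__index__`, so a float is not accepted where `SupportsIndex` is required.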

class TextIO

Typed version of the return of open() in text mode.


property buffer


property closed


property encoding


property errors


property line_buffering


property mode


property name


property newlines


method close

close() → None

method fileno

fileno() → int

method flush

flush() → None

method isatty

isatty() → bool

method read

read(n: int = -1) → ~AnyStr

method readable

readable() → bool

method readline

readline(limit: int = -1) → ~AnyStr

method readlines

readlines(hint: int = -1) → List[~AnyStr]

method seek

seek(offset: int, whence: int = 0) → int

method seekable

seekable() → bool

method tell

tell() → int

method truncate

truncate(size: int = None) → int

method writable

writable() → bool

method write

write(s: ~AnyStr) → int

method writelines

writelines(lines: List[~AnyStr]) → None
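`TextIO` is the usual annotation for functions that write text to a file, `sys.stdout`, or an in-memory `io.StringIO`. A minimal sketch with a hypothetical `write_table` helper:

```python
import io
from typing import TextIO


def write_table(rows: list[tuple[str, float]], out: TextIO) -> int:
    """Write tab-separated rows to any text stream; return the number of lines written."""
    for name, value in rows:
        out.write(f"{name}\t{value:.2f}\n")
    return len(rows)
```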

class TypeVar

Type variable.

Usage:


    T = TypeVar('T')  # Can be anything
    A = TypeVar('A', str, bytes)  # Must be str or bytes

Type variables exist primarily for the benefit of static type checkers. They serve as the parameters for generic types as well as for generic function definitions. See class Generic for more information on generic types. Generic functions work as follows:

    def repeat(x: T, n: int) -> List[T]:
        '''Return a list containing n references to x.'''
        return [x]*n

    def longest(x: A, y: A) -> A:
        '''Return the longer of two strings.'''
        return x if len(x) >= len(y) else y

The latter example's signature is essentially the overloading of (str, str) -> str and (bytes, bytes) -> bytes. Also note that if the arguments are instances of some subclass of str, the return type is still plain str.

At runtime, isinstance(x, T) and issubclass(C, T) will raise TypeError.

Type variables defined with covariant=True or contravariant=True can be used to declare covariant or contravariant generic types. See PEP 484 for more details. By default generic types are invariant in all type variables.

Type variables can be introspected, e.g.:

    T.__name__ == 'T'
    T.__constraints__ == ()
    T.__covariant__ == False
    T.__contravariant__ == False
    A.__constraints__ == (str, bytes)

Note that only type variables defined in global scope can be pickled.
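Runnable versions of `repeat` and `longest`, using only the standard `typing` API. The introspection attributes listed above can be checked against these definitions:

```python
from typing import List, TypeVar

T = TypeVar("T")  # Can be anything
A = TypeVar("A", str, bytes)  # Must be str or bytes


def repeat(x: T, n: int) -> List[T]:
    """Return a list containing n references to x."""
    return [x] * n


def longest(x: A, y: A) -> A:
    """Return the longer of two strings (both str or both bytes)."""
    return x if len(x) >= len(y) else y
```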

method __init__

__init__(name, *constraints, bound=None, covariant=False, contravariant=False)

class TypeVarTuple

Type variable tuple.

Usage:

    Ts = TypeVarTuple('Ts')  # Can be given any name

Just as a TypeVar (type variable) is a placeholder for a single type, a TypeVarTuple is a placeholder for an arbitrary number of types. For example, if we define a generic class using a TypeVarTuple:

    class C(Generic[*Ts]): ...

Then we can parameterize that class with an arbitrary number of type arguments:

    C[int]       # Fine
    C[int, str]  # Also fine
    C[()]        # Even this is fine

For more details, see PEP 646.

Note that only TypeVarTuples defined in global scope can be pickled.

method __init__

__init__(name)

class typing.io

Wrapper namespace for IO generic classes.


class typing.re

Wrapper namespace for re type aliases.

module uncertainty

Visualizations for assessing the quality of model uncertainty estimates.

Global Variables


function qq_gaussian

qq_gaussian(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    identity_line: 'bool | dict[str, Any]' = True
) → Axes

Create a Gaussian quantile-quantile (Q-Q) plot of one (passed as array) or multiple (passed as dict) sets of uncertainty estimates for a single pair of ground truth targets y_true and model predictions y_pred.

Overconfidence relative to a Gaussian distribution is visualized as shaded areas below the parity line, underconfidence (oversized uncertainties) as shaded areas above the parity line.

The measure of calibration is how well the uncertainty percentiles conform to those of a normal distribution.
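The calibration measure can be sketched with the standard library. This is a minimal illustration, not the pymatviz implementation. It assumes standardized absolute errors |y_pred - y_true| / y_std and, for each expected proportion p, counts the observed fraction of samples inside the central p-interval of a standard normal. The function name `gaussian_calibration_curve` and the `n_points` grid are hypothetical.

```python
from statistics import NormalDist


def gaussian_calibration_curve(y_true, y_pred, y_std, n_points=11):
    """Return (expected, observed) proportions for a Gaussian Q-Q calibration check.

    For each expected proportion p, observed is the fraction of samples whose
    standardized absolute error |y_pred - y_true| / y_std lies inside the
    central p-interval of a standard normal, i.e. at most Phi^-1(0.5 + p / 2).
    """
    z_abs = [abs(yp - yt) / ys for yt, yp, ys in zip(y_true, y_pred, y_std)]
    n = len(z_abs)
    std_normal = NormalDist()  # mean 0, sigma 1
    expected, observed = [], []
    for i in range(n_points):
        p = i / (n_points - 1)
        # inv_cdf is undefined at 1.0, so the full interval is treated as unbounded.
        bound = float("inf") if p >= 1 else std_normal.inv_cdf(0.5 + p / 2)
        expected.append(p)
        observed.append(sum(z <= bound for z in z_abs) / n)
    return expected, observed
```

With this convention, observed proportions below the parity line mean too few errors fall inside the predicted intervals (overconfidence), and proportions above it mean the intervals are oversized (underconfidence).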

Inspired by https://git.io/JufOz. Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot

Args:

Returns:


function get_err_decay

get_err_decay(
    y_true: 'ArrayLike',
    y_pred: 'ArrayLike',
    n_rand: 'int' = 100
) → tuple[ArrayLike, ArrayLike]

Calculate the model's error curve as samples are excluded from the calculation based on their absolute error.

Use in combination with get_std_decay to see what the error drop curve would look like if model error and uncertainty were perfectly rank-correlated.
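The idea behind this oracle curve can be sketched in plain Python. This illustration assumes mean absolute error as the metric and one sample excluded per step; it is not the pymatviz implementation, and `error_decay_oracle` is a hypothetical name.

```python
from itertools import accumulate


def error_decay_oracle(y_true, y_pred):
    """Entry k is the MAE after excluding the k samples with the largest absolute error."""
    abs_err = sorted(abs(p - t) for t, p in zip(y_true, y_pred))  # ascending
    prefix = list(accumulate(abs_err))  # prefix[m - 1] = sum of the m smallest errors
    n = len(abs_err)
    # Excluding the k worst samples keeps the n - k smallest errors.
    return [prefix[n - k - 1] / (n - k) for k in range(n)]
```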

Args:

Returns:


function get_std_decay

get_std_decay(
    y_true: 'ArrayLike',
    y_pred: 'ArrayLike',
    y_std: 'ArrayLike'
) → ArrayLike

Calculate the drop in model error as samples are excluded from the calculation based on the model's uncertainty.

For models able to estimate their own uncertainty well, meaning predictions of larger error are associated with larger uncertainty, the error curve should fall off sharply at first as the highest-error points are discarded, and slowly towards the end where only small-error samples with little uncertainty remain.

Note that even perfect model uncertainties would not make this error drop curve coincide exactly with the one returned by get_err_decay. In some cases, the model may make an accurate prediction purely by chance: the error is then small, yet a good uncertainty estimate would still be large. The same sample is therefore excluded at different x-axis locations in the two curves, and the get_std_decay curve lies higher.
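A matching sketch for the uncertainty-ranked curve, under the same assumptions (MAE metric, one sample excluded per step). `error_decay_by_std` is a hypothetical name, not the pymatviz implementation. When the ranking by y_std matches the true errors, the curve equals the oracle curve; when it is inverted, the curve rises instead of falling.

```python
def error_decay_by_std(y_true, y_pred, y_std):
    """Entry k is the MAE after excluding the k samples with the largest predicted std."""
    abs_err = (abs(p - t) for t, p in zip(y_true, y_pred))
    # Order absolute errors by their std, most uncertain first (stable for ties).
    ranked = sorted(zip(y_std, abs_err), key=lambda pair: pair[0], reverse=True)
    errs = [err for _, err in ranked]
    n = len(errs)
    # O(n^2) for clarity; a suffix sum would make this linear.
    return [sum(errs[k:]) / (n - k) for k in range(n)]
```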

Args:

Returns:


function error_decay_with_uncert

error_decay_with_uncert(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
    df: 'DataFrame | None' = None,
    n_rand: 'int' = 100,
    percentiles: 'bool' = True,
    ax: 'Axes | None' = None
) → Axes

Plot for assessing the quality of uncertainty estimates. If a model's uncertainty is well calibrated, i.e. strongly correlated with its error, removing the most uncertain predictions should make the mean error decay similarly to how it decays when removing the predictions of largest error.

Args:

Note: If you're not happy with the default y_max of 1.1 * rand_mean, where rand_mean is the mean error under random sample exclusion, use ax.set(ylim=[None, some_value * ax.get_ylim()[1]]).
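The random-exclusion baseline (rand_mean in the note) can be sketched the same way: average the MAE curve over n_rand random exclusion orders. This illustration assumes MAE and a seeded `random.Random`; `random_exclusion_baseline` is a hypothetical name, not the pymatviz implementation.

```python
import random


def random_exclusion_baseline(y_true, y_pred, n_rand=100, seed=0):
    """Entry k is the MAE after excluding k random samples, averaged over n_rand orders."""
    abs_err = [abs(p - t) for t, p in zip(y_true, y_pred)]
    n = len(abs_err)
    rng = random.Random(seed)  # seeded so the baseline is reproducible
    totals = [0.0] * n
    for _ in range(n_rand):
        order = abs_err[:]
        rng.shuffle(order)  # one random exclusion order
        for k in range(n):
            kept = order[k:]
            totals[k] += sum(kept) / len(kept)
    return [total / n_rand for total in totals]
```

Since random exclusion carries no information about the errors, this curve stays roughly flat around the overall MAE, which is why it serves as the reference for judging the uncertainty-ranked curve.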

Returns:

module utils.data

Data processing utils:

- df_ptable (DataFrame): Periodic table.
- atomic_numbers (dict[str, int]): Map elements to atomic numbers.
- element_symbols (dict[int, str]): Map atomic numbers to elements.

Global Variables


function bin_df_cols

bin_df_cols(
    df_in: 'DataFrame',
    bin_by_cols: 'Sequence[str]',
    group_by_cols: 'Sequence[str]' = (),
    n_bins: 'int | Sequence[int]' = 100,
    bin_counts_col: 'str' = 'bin_counts',
    density_col: 'str' = '',
    verbose: 'bool' = True
) → DataFrame

Bin columns of a DataFrame.

Args:

Returns:


function crystal_sys_from_spg_num

crystal_sys_from_spg_num(spg: 'float') → CrystalSystem

Get the crystal system for an international space group number.
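Crystal systems occupy contiguous ranges of international space group numbers (1-2 triclinic, 3-15 monoclinic, 16-74 orthorhombic, 75-142 tetragonal, 143-167 trigonal, 168-194 hexagonal, 195-230 cubic), so the lookup reduces to a range search. A minimal sketch; `crystal_system` is a hypothetical stand-in that returns lowercase strings and rejects non-integral input, which may differ from the pymatviz function:

```python
import bisect

# Last international space group number of each crystal system (inclusive).
_UPPER_BOUNDS = (2, 15, 74, 142, 167, 194, 230)
_SYSTEMS = ("triclinic", "monoclinic", "orthorhombic", "tetragonal", "trigonal", "hexagonal", "cubic")


def crystal_system(spg: float) -> str:
    """Map an international space group number (1-230) to its crystal system."""
    if not 1 <= spg <= 230 or int(spg) != spg:
        raise ValueError(f"Invalid space group number: {spg}")
    # bisect_left finds the first range whose upper bound is >= spg.
    return _SYSTEMS[bisect.bisect_left(_UPPER_BOUNDS, spg)]
```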


function df_to_arrays

df_to_arrays(
    df: 'DataFrame | None',
    *args: 'str | Sequence[str] | Sequence[ArrayLike]',
    strict: 'bool' = True
) → list[ArrayLike | dict[str, ArrayLike]]

If df is None, this is a no-op: args are returned as-is. If df is a dataframe, all following args are used as column names and the column data returned as arrays (after dropping rows with NaNs in any column).

Args:

Raises:

Returns:


function html_tag

html_tag(
    text: 'str',
    tag: 'str' = 'span',
    style: 'str' = '',
    title: 'str' = ''
) → str

Wrap text in an HTML tag (a span by default) with custom style.

Style defaults to decreased font size and weight, e.g. to display units in plotly labels and annotations.

Args:

Returns:


function normalize_to_dict

normalize_to_dict(
    inputs: 'T | Sequence[T] | dict[str, T]',
    cls: 'type[T]' = <class 'pymatgen.core.structure.Structure'>,
    key_gen: 'Callable[[T], str]' = <function <lambda>>
) → dict[str, T]

Normalize any kind of object or dict/list/tuple of them into a dictionary.

Args:

Returns: A dictionary of objects with keys as object formulas or given keys.

Raises:


function patch_dict

patch_dict(
    dct: 'dict[Any, Any]',
    *args: 'Any',
    **kwargs: 'Any'
) → Generator[dict[Any, Any], None, None]

Context manager to temporarily patch the specified keys in a dictionary and restore it to its original state on context exit.

Useful e.g. for temporary plotly fig.layout mutations:

    with patch_dict(fig.layout, showlegend=False):
        fig.write_image("plot.pdf")
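The mechanism can be sketched with `contextlib.contextmanager` for plain dicts. Assumptions: `*args` and `**kwargs` are merged with `dict(*args, **kwargs)`, and keys that did not exist before are removed on exit. `patch_dict_sketch` is a hypothetical name and does not handle plotly layout objects.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def patch_dict_sketch(dct: dict[Any, Any], *args: Any, **kwargs: Any) -> Iterator[dict[Any, Any]]:
    """Temporarily update dct in place and restore its original state on exit."""
    patch = dict(*args, **kwargs)
    missing = object()  # sentinel marking keys absent before patching
    originals = {key: dct.get(key, missing) for key in patch}
    dct.update(patch)
    try:
        yield dct
    finally:
        # Runs on normal exit and when the with-block raises.
        for key, old in originals.items():
            if old is missing:
                dct.pop(key, None)
            else:
                dct[key] = old
```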

Args:

Yields:


function si_fmt

si_fmt(
    val: 'float',
    fmt: 'str' = '.1f',
    sep: 'str' = '',
    binary: 'bool' = False,
    decimal_threshold: 'float' = 0.01
) → str

Convert large numbers into human-readable format using SI prefixes.

Supports binary (1024) and metric (1000) modes.
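A minimal sketch of the scaling idea for large numbers only. The prefix list, the IEC-style `Ki`/`Mi` labels in binary mode, and the name `si_format` are illustrative assumptions and may not match pymatviz's output; the decimal_threshold parameter is not modeled.

```python
def si_format(val: float, fmt: str = ".1f", sep: str = "", binary: bool = False) -> str:
    """Scale val by powers of 1000 (or 1024) and append the matching prefix."""
    base = 1024 if binary else 1000
    prefixes = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei"] if binary else ["", "k", "M", "G", "T", "P", "E"]
    idx = 0
    scaled = float(val)
    # Divide until the magnitude drops below one base step or prefixes run out.
    while abs(scaled) >= base and idx < len(prefixes) - 1:
        scaled /= base
        idx += 1
    return f"{scaled:{fmt}}{sep}{prefixes[idx]}"
```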

https://nist.gov/pml/weights-and-measures/metric-si-prefixes

Args:

Returns:

module utils

pymatviz utility functions.

Global Variables


class ExperimentalWarning

Warning for experimental features.

module utils.plotting

Plotting-related utility functions.

Available functions:

- annotate: Annotate a matplotlib or plotly figure with text.
- apply_matplotlib_template: Set default matplotlib configurations for consistency.
- get_cbar_label_formatter: Generate colorbar tick label formatter.
- get_font_color: Get the font color used in a Matplotlib or Plotly figure.
- get_fig_xy_range: Get the x and y range of a plotly or matplotlib figure.
- luminance: Compute the luminance of a color.
- pick_bw_for_contrast: Choose black or white text color for contrast.
- pretty_label: Map metric keys to their pretty labels.
- validate_fig: Decorator to validate the type of fig keyword argument.

Global Variables


function annotate

annotate(text: 'str | Sequence[str]', fig: 'AxOrFig', **kwargs: 'Any') → AxOrFig

Annotate a matplotlib or plotly figure. Supports faceted plotly figures; traces whose annotation text is an empty string are skipped.

Args:

Returns:

Raises:


function apply_matplotlib_template

apply_matplotlib_template() → None

Set default matplotlib configurations for consistency.


function get_cbar_label_formatter

get_cbar_label_formatter(
    cbar_label_fmt: 'str',
    values_fmt: 'str',
    values_show_mode: "Literal['value', 'fraction', 'percent', 'off']",
    sci_notation: 'bool',
    default_decimal_places: 'int' = 1
) → Formatter

Generate colorbar tick label formatter.

Works differently depending on values_show_mode:

- "value" and "fraction" modes: Use cbar_label_fmt (or values_fmt) as is.
- "percent" mode: Get the number of decimal places to keep from the format string, for example 1 from ".1%".
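For percent mode, extracting the decimal places from a format string can be sketched with a regular expression. `decimal_places_from_fmt` is a hypothetical helper, and falling back to `default` when no precision is given only mirrors the default_decimal_places parameter by assumption.

```python
import re


def decimal_places_from_fmt(fmt: str, default: int = 1) -> int:
    """Return the precision in a format spec such as ".1%" or ",.2f", else default."""
    match = re.search(r"\.(\d+)", fmt)
    return int(match.group(1)) if match else default
```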

Args:

Returns: PercentFormatter or FormatStrFormatter.


function get_font_color

get_font_color(fig: 'AxOrFig') → str

Get the font color used in a Matplotlib figure/axes or a Plotly figure.

Args:

Returns:

Raises:


function luminance

luminance(color: 'str | tuple[float, float, float]') → float

Compute the luminance of a color as in https://stackoverflow.com/a/596243.

Args:

Returns:


function pick_bw_for_contrast

pick_bw_for_contrast(
    color: 'tuple[float, float, float] | str',
    text_color_threshold: 'float' = 0.7
) → Literal['black', 'white']

Choose black or white text color for a given background color based on luminance.
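A sketch of both luminance and contrast-based text color selection. Assumptions: luminance uses the perceived-brightness weights 0.299, 0.587, 0.114 (one of the formulas in the linked Stack Overflow answer), colors are RGB tuples in [0, 1] or 6-digit hex strings, and text is black when luminance exceeds the threshold. The helper names are hypothetical and may not match pymatviz's exact behavior.

```python
from typing import Literal, Tuple, Union

RGB = Tuple[float, float, float]


def to_rgb(color: Union[RGB, str]) -> RGB:
    """Accept an (r, g, b) tuple in [0, 1] or a 6-digit hex string like "#ffcc00"."""
    if isinstance(color, str):
        hex_str = color.lstrip("#")
        r, g, b = (int(hex_str[i : i + 2], 16) / 255 for i in (0, 2, 4))
        return (r, g, b)
    return color


def luminance_sketch(color: Union[RGB, str]) -> float:
    """Perceived brightness in [0, 1] using BT.601 weights."""
    r, g, b = to_rgb(color)
    return 0.299 * r + 0.587 * g + 0.114 * b


def pick_bw_sketch(color: Union[RGB, str], threshold: float = 0.7) -> Literal["black", "white"]:
    """Dark text on bright backgrounds, light text on dark ones."""
    return "black" if luminance_sketch(color) > threshold else "white"
```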

Args:

Returns:


function pretty_label

pretty_label(key: 'str', backend: 'Backend') → str

Map metric keys to their pretty labels.


function validate_fig

validate_fig(func: 'Callable[P, R]') → Callable[P, R]

Decorator to validate the type of fig keyword argument in a function. fig MUST be a keyword argument, not a positional argument.


function get_fig_xy_range

get_fig_xy_range(
    fig: 'Figure | Figure | Axes',
    trace_idx: 'int' = 0
) → tuple[tuple[float, float], tuple[float, float]]

Get the x and y range of a plotly or matplotlib figure.

Args:

Returns:

module utils.testing

Testing related utils.

Global Variables

module xrd

Module for plotting XRD patterns using plotly.

Global Variables


function format_hkl

format_hkl(hkl: 'tuple[int, int, int]', format_type: 'HklFormat') → str

Format hkl indices as a string.

Args:

Raises:


function xrd_pattern

xrd_pattern(
    patterns: 'PatternOrStruct | dict[str, PatternOrStruct | tuple[PatternOrStruct, dict[str, Any]]]',
    peak_width: 'float' = 0.5,
    annotate_peaks: 'float' = 5,
    hkl_format: 'HklFormat' = 'compact',
    show_angles: 'bool | None' = None,
    wavelength: 'float' = 1.54184,
    stack: "Literal['horizontal', 'vertical'] | None" = None,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    subtitle_kwargs: 'dict[str, Any] | None' = None
) → Figure

Create a plotly figure of XRD patterns from a pymatgen DiffractionPattern, a pymatgen Structure, or a dictionary of either.

Args: patterns (PatternOrStruct | dict[str, PatternOrStruct | tuple[PatternOrStruct,

Raises:

Returns: