« home
"""Locally run atomate2 PhononMaker with CHGNet, M3GNet, and MACE on MP structures."""

# %%
import os
import shutil
from collections import defaultdict
from time import perf_counter
from zipfile import BadZipFile

import atomate2.forcefields.jobs as ff_jobs
import plotly.express as px
import torch
from atomate2.common.schemas.phonons import PhononBSDOSDoc as Atomate2PhononBSDOSDoc
from atomate2.forcefields.flows.phonons import PhononMaker
from jobflow import run_locally
from mp_api.client import MPRester
from tqdm import tqdm

import pymatviz as pmv
from pymatviz.enums import Key, Model
from pymatviz.phonons import PhononDBDoc


__author__ = "Janosh Riebesell"
__date__ = "2023-11-19"

px.defaults.template = "pymatviz_white"
# %%

ROOT = os.path.dirname(__file__)

# All atomate2/jobflow run directories are written below this folder.
RUNS_DIR = f"{ROOT}/tmp/runs"  # noqa: S108
# Wipe output of previous runs to save disk space, then recreate the empty dir.
shutil.rmtree(RUNS_DIR, ignore_errors=True)
os.makedirs(RUNS_DIR, exist_ok=True)

# Very tight force-convergence threshold shared by all MLFF relax makers.
common_relax_kwargs = dict(fmax=0.00001)
mace_kwargs = dict(model="medium")

# Pick the best available torch device rather than hard-coding Apple's "mps",
# so the script also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    chgnet_device = "mps"
elif torch.cuda.is_available():
    chgnet_device = "cuda"
else:
    chgnet_device = "cpu"
# NOTE(review): confirm atomate2's CHGNet makers actually forward
# optimizer_kwargs/use_device from calculator_kwargs in the installed version.
chgnet_kwargs = dict(
    optimizer_kwargs=dict(use_device=chgnet_device), assign_magmoms=False
)

do_mlff_relax = True  # whether to MLFF-relax the PBE structure
# Keyword arguments for atomate2's PhononMaker, one set of makers per model.
# bulk_relax_maker=None skips the MLFF relaxation when do_mlff_relax is False.
models = {
    str(Model.mace_mp): dict(
        # The relaxation explicitly requests float64. The displacement and static
        # makers do not set default_dtype, so MACE uses its own default.
        bulk_relax_maker=ff_jobs.MACERelaxMaker(
            relax_kwargs=common_relax_kwargs,
            calculator_kwargs={"default_dtype": "float64"} | mace_kwargs,
        )
        if do_mlff_relax
        else None,
        phonon_displacement_maker=ff_jobs.MACEStaticMaker(
            calculator_kwargs=mace_kwargs
        ),
        static_energy_maker=ff_jobs.MACEStaticMaker(calculator_kwargs=mace_kwargs),
    ),
    str(Model.m3gnet_ms): dict(
        bulk_relax_maker=ff_jobs.M3GNetRelaxMaker(relax_kwargs=common_relax_kwargs)
        if do_mlff_relax
        else None,
        phonon_displacement_maker=ff_jobs.M3GNetStaticMaker(),
        static_energy_maker=ff_jobs.M3GNetStaticMaker(),
    ),
    str(Model.chgnet_030): dict(
        bulk_relax_maker=ff_jobs.CHGNetRelaxMaker(
            relax_kwargs=common_relax_kwargs, calculator_kwargs=chgnet_kwargs
        )
        if do_mlff_relax
        else None,
        phonon_displacement_maker=ff_jobs.CHGNetStaticMaker(
            calculator_kwargs=chgnet_kwargs
        ),
        static_energy_maker=ff_jobs.CHGNetStaticMaker(calculator_kwargs=chgnet_kwargs),
    ),
}
# %% fetch MP structure

mp_id = "mp-1234"  # pick your favorite MP material
# Download the structure for mp_id from the Materials Project API and tag it with
# its material ID so downstream code can look the ID up from the structure itself.
with MPRester() as mp_rester:
    struct_mp = mp_rester.materials.get_structure_by_material_id(mp_id)
struct_mp.properties[Key.mat_id] = mp_id
Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]
# %% Main loop over materials and models

# One (mat_id, model, formula, exception) entry per failed model run.
errors: list[tuple[str, str, str, Exception]] = []
last_error: Exception | None = None  # most recent failure, kept for debugging
# results[mat_id][model_key] -> atomate2 phonon band structure + DOS document
results: dict[str, dict[str, Atomate2PhononBSDOSDoc]] = defaultdict(dict)


for struct in (pbar := tqdm([struct_mp])):  # PhononDB
    formula = struct.formula.replace(" ", "")
    mat_id = struct.properties[Key.mat_id]
    pbar.set_description(f"{mat_id} {formula}")

    for model, mlff_makers in models.items():
        model_key = str(model).lower().replace(" ", "-")
        os.makedirs(root_dir := f"{RUNS_DIR}/{model_key}", exist_ok=True)

        try:
            start = perf_counter()
            phonon_flow = PhononMaker(
                **mlff_makers,
                store_force_constants=False,
                # "setyawan_curtarolo" only compatible with primitive cell!
                # use "seekpath" with non-primitive cells
                kpath_scheme="seekpath",
                create_thermal_displacements=False,
                # use_symmetrized_structure="primitive",
            ).make(structure=struct)
            # NOTE: supercell size is left at PhononMaker's default; nothing here
            # enforces a minimum supercell length.

            # phonon_flow.draw_graph().show()  # uncomment to visualize the flow

            result = run_locally(
                phonon_flow, root_dir=root_dir, log=False, ensure_success=True
            )
            print(f"\n{model} took: {perf_counter() - start:.2f} s")

            # The last job in the flow produces the final phonon document.
            last_job_id = phonon_flow[-1].uuid
            ml_phonon_doc: Atomate2PhononBSDOSDoc = result[last_job_id][1].output
            results[mat_id][model_key] = ml_phonon_doc
        except Exception as exc:  # noqa: BLE001
            # Deliberately broad so one failing model/material doesn't abort the
            # whole sweep. Known possible errors:
            # - the 2 band structures are not compatible, due to symmetry change during
            # MACE relaxation, try different PhononMaker symprec (default=1e-4). compare
            # PBE and MACE space groups to verify cause
            # - phonopy found imaginary dispersion > 1e-10 (fixed by disabling thermal
            # displacement matrices)
            # - phonopy-internal: RuntimeError: Creating primitive cell failed.
            # PRIMITIVE_AXIS may be incorrectly specified. For mp-754196 Ba2Sr1I6
            # - faulty downloads of phonondb docs raise "BadZipFile: is not a zip file"
            # - mp-984055 raised: [1] 51628 segmentation fault
            # multiprocessing/resource_tracker.py:254: UserWarning: There appear to be 1
            # leaked semaphore objects to clean up at shutdown
            last_error = exc
            errors.append((mat_id, str(model), formula, exc))
        finally:
            # MACE changes torch default dtype to float64 which breaks CHGNet and
            # M3GNet, so reset it after every model, whether it succeeded or not.
            torch.set_default_dtype(torch.float32)

if errors:
    print(f"\n{errors=}")
  0%|          | 0/1 [00:00<?, ?it/s]
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
/Users/janosh/dev/atomate2/src/atomate2/common/jobs/phonons.py:140: UserWarning: Initial magnetic moments will not be considered for the determination of the symmetry of the structure and thus will be removed now.
  warnings.warn(
Warning: Point group symmetries of supercell and primitivecell are different.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Using Materials Project MACE for MACECalculator with /Users/janosh/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.
Warning: Point group symmetries of supercell and primitivecell are different.
WARNING:matplotlib.backends.backend_ps:The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
mace-y7uhwpje took: 82.10 s
/Users/janosh/.venv/py311/lib/python3.11/site-packages/matgl/apps/pes.py:69: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.element_refs = AtomRef(property_offset=torch.tensor(element_refs, dtype=matgl.float_th))
/Users/janosh/.venv/py311/lib/python3.11/site-packages/matgl/apps/pes.py:75: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer("data_mean", torch.tensor(data_mean, dtype=matgl.float_th))
/Users/janosh/.venv/py311/lib/python3.11/site-packages/matgl/apps/pes.py:76: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer("data_std", torch.tensor(data_std, dtype=matgl.float_th))
/Users/janosh/.venv/py311/lib/python3.11/site-packages/matgl/layers/_basis.py:121: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  root = torch.tensor(roots[i])
/Users/janosh/dev/atomate2/src/atomate2/common/jobs/phonons.py:140: UserWarning: Initial magnetic moments will not be considered for the determination of the symmetry of the structure and thus will be removed now.
  warnings.warn(
Warning: Point group symmetries of supercell and primitivecell are different.
Warning: Point group symmetries of supercell and primitivecell are different.
WARNING:matplotlib.backends.backend_ps:The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
m3gnet took: 48.20 s
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
/Users/janosh/dev/atomate2/src/atomate2/common/jobs/phonons.py:140: UserWarning: Initial magnetic moments will not be considered for the determination of the symmetry of the structure and thus will be removed now.
  warnings.warn(
Warning: Point group symmetries of supercell and primitivecell are different.
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on mps
Warning: Point group symmetries of supercell and primitivecell are different.
WARNING:matplotlib.backends.backend_ps:The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
100%|██████████| 1/1 [02:50<00:00, 170.08s/it]
chgnet-v0.3.0 took: 39.77 s

# %% plot all results into one figure
# .get avoids inserting an empty entry into the `results` defaultdict.
model_docs = results.get(mp_id, {})
if not model_docs:
    # Every model failed for this material; fail loudly instead of plotting nothing.
    raise RuntimeError(f"No phonon results to plot for {mp_id}, {errors=}")

# One band structure and one DOS per model, keyed by model name.
bs_dict = {model_key: doc.phonon_bandstructure for model_key, doc in model_docs.items()}
dos_dict = {model_key: doc.phonon_dos for model_key, doc in model_docs.items()}

fig_bs_dos = pmv.phonon_bands_and_dos(bs_dict, dos_dict)
fig_bs_dos.layout.title.update(
    text=f"Phonon Bands and DOS for {mp_id} {struct_mp.formula}", x=0.5, y=0.98
)
fig_bs_dos.layout.margin.update(l=0, r=0, b=0, t=30)
fig_bs_dos.show()