Merge CUI for pypolymlp to phonopy

This commit is contained in:
Atsushi Togo 2025-07-29 16:11:59 +09:00
parent de2e6d3025
commit 3ec495f05a
3 changed files with 28 additions and 63 deletions

View File

@ -40,9 +40,13 @@ import copy
import dataclasses import dataclasses
import os import os
import pathlib import pathlib
from typing import Literal from typing import Literal, cast
import numpy as np import numpy as np
from phonopy import Phonopy
from phonopy.cui.load_helper import (
develop_or_load_pypolymlp as develop_or_load_pypolymlp_phonopy,
)
from phonopy.file_IO import get_dataset_type2 from phonopy.file_IO import get_dataset_type2
from phonopy.interface.calculator import get_calculator_physical_units from phonopy.interface.calculator import get_calculator_physical_units
from phonopy.interface.pypolymlp import PypolymlpParams, parse_mlp_params from phonopy.interface.pypolymlp import PypolymlpParams, parse_mlp_params
@ -185,66 +189,19 @@ def _read_FORCES_FC3_or_FC2(
return dataset return dataset
def develop_pypolymlp( def develop_or_load_pypolymlp(
ph3py: Phono3py, ph3py: Phono3py,
mlp_params: str | dict | PypolymlpParams | None = None, mlp_params: str | dict | PypolymlpParams | None = None,
mlp_filename: str | os.PathLike | None = None, mlp_filename: str | os.PathLike | None = None,
log_level: int = 0, log_level: int = 0,
): ):
"""Run pypolymlp to compute forces.""" """Run pypolymlp to compute forces."""
if log_level: develop_or_load_pypolymlp_phonopy(
import pypolymlp cast(Phonopy, ph3py),
mlp_params=mlp_params,
print("-" * 29 + " pypolymlp start " + "-" * 30) mlp_filename=mlp_filename,
print("Pypolymlp version", pypolymlp.__version__) log_level=log_level,
print("Pypolymlp is a generator of polynomial machine learning potentials.") )
print("Please cite the paper: A. Seko, J. Appl. Phys. 133, 011101 (2023).")
print("Pypolymlp is developed at https://github.com/sekocha/pypolymlp.")
mlp_loaded = False
for mlp_filename in ["polymlp.yaml", "phono3py.pmlp"]:
_mlp_filename_list = list(pathlib.Path().glob(f"{mlp_filename}*"))
if _mlp_filename_list:
_mlp_filename = _mlp_filename_list[0]
if _mlp_filename.suffix not in [
".yaml",
".pmlp",
".xz",
".gz",
".bz2",
".lzma",
]:
continue
if log_level:
print(f'Load MLPs from "{_mlp_filename}".')
ph3py.load_mlp(_mlp_filename)
mlp_loaded = True
if log_level and mlp_filename == "phono3py.pmlp":
print(f'Loading MLPs from "{_mlp_filename}" is obsolete.')
break
mlp_filename = "polymlp.yaml"
if not mlp_loaded:
if forces_in_dataset(ph3py.mlp_dataset):
if log_level:
if mlp_params is None:
pmlp_params = PypolymlpParams()
else:
pmlp_params = parse_mlp_params(mlp_params)
print("Parameters:")
for k, v in dataclasses.asdict(pmlp_params).items():
if v is not None:
print(f" {k}: {v}")
print("Developing MLPs by pypolymlp...", flush=True)
ph3py.develop_mlp(params=mlp_params)
ph3py.save_mlp(filename=mlp_filename)
if log_level:
print(f'MLPs were written into "{mlp_filename}"', flush=True)
else:
raise RuntimeError(f'"{mlp_filename}" is not found.')
if log_level:
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)
def generate_displacements_and_evaluate_pypolymlp( def generate_displacements_and_evaluate_pypolymlp(

View File

@ -51,7 +51,7 @@ from phonopy.structure.cells import determinant
from phono3py import Phono3py from phono3py import Phono3py
from phono3py.cui.create_force_constants import ( from phono3py.cui.create_force_constants import (
develop_pypolymlp, develop_or_load_pypolymlp,
parse_forces, parse_forces,
) )
from phono3py.file_IO import read_fc2_from_hdf5, read_fc3_from_hdf5 from phono3py.file_IO import read_fc2_from_hdf5, read_fc3_from_hdf5
@ -374,7 +374,7 @@ def load(
if produce_fc: if produce_fc:
if ph3py.fc3 is None and use_pypolymlp: if ph3py.fc3 is None and use_pypolymlp:
develop_pypolymlp(ph3py, mlp_params=mlp_params, log_level=log_level) develop_or_load_pypolymlp(ph3py, mlp_params=mlp_params, log_level=log_level)
compute_force_constants_from_datasets( compute_force_constants_from_datasets(
ph3py, ph3py,

View File

@ -60,6 +60,8 @@ from phonopy.cui.settings import PhonopySettings
from phonopy.exception import ( from phonopy.exception import (
CellNotFoundError, CellNotFoundError,
ForceCalculatorRequiredError, ForceCalculatorRequiredError,
PypolymlpDevelopmentError,
PypolymlpFileNotFoundError,
PypolymlpRelaxationError, PypolymlpRelaxationError,
) )
from phonopy.file_IO import is_file_phonopy_yaml from phonopy.file_IO import is_file_phonopy_yaml
@ -75,7 +77,7 @@ from phonopy.structure.cells import isclose as cells_isclose
from phono3py import Phono3py, Phono3pyIsotope, Phono3pyJointDos from phono3py import Phono3py, Phono3pyIsotope, Phono3pyJointDos
from phono3py.cui.create_force_constants import ( from phono3py.cui.create_force_constants import (
develop_pypolymlp, develop_or_load_pypolymlp,
generate_displacements_and_evaluate_pypolymlp, generate_displacements_and_evaluate_pypolymlp,
) )
from phono3py.cui.create_force_sets import ( from phono3py.cui.create_force_sets import (
@ -599,11 +601,17 @@ def _run_pypolymlp(
ph3py.mlp_dataset = ph3py.dataset ph3py.mlp_dataset = ph3py.dataset
ph3py.dataset = None ph3py.dataset = None
develop_pypolymlp( try:
ph3py, develop_or_load_pypolymlp(
mlp_params=settings.mlp_params, ph3py,
log_level=log_level, mlp_params=settings.mlp_params,
) log_level=log_level,
)
except (PypolymlpDevelopmentError, PypolymlpFileNotFoundError) as e:
print_error_message(str(e))
if log_level:
print_error()
sys.exit(1)
_ph3py = ph3py _ph3py = ph3py
if settings.relax_atomic_positions: if settings.relax_atomic_positions: