Merge pull request #420 from phonopy/refactor

Merge CUI for pypolymlp to phonopy
This commit is contained in:
Atsushi Togo 2025-07-29 16:22:34 +09:00 committed by GitHub
commit ff95c4dd2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 63 deletions

View File

@ -40,9 +40,13 @@ import copy
import dataclasses
import os
import pathlib
from typing import Literal
from typing import Literal, cast
import numpy as np
from phonopy import Phonopy
from phonopy.cui.load_helper import (
develop_or_load_pypolymlp as develop_or_load_pypolymlp_phonopy,
)
from phonopy.file_IO import get_dataset_type2
from phonopy.interface.calculator import get_calculator_physical_units
from phonopy.interface.pypolymlp import PypolymlpParams, parse_mlp_params
@ -185,66 +189,19 @@ def _read_FORCES_FC3_or_FC2(
return dataset
def develop_pypolymlp(
def develop_or_load_pypolymlp(
ph3py: Phono3py,
mlp_params: str | dict | PypolymlpParams | None = None,
mlp_filename: str | os.PathLike | None = None,
log_level: int = 0,
):
"""Run pypolymlp to compute forces."""
if log_level:
import pypolymlp
print("-" * 29 + " pypolymlp start " + "-" * 30)
print("Pypolymlp version", pypolymlp.__version__)
print("Pypolymlp is a generator of polynomial machine learning potentials.")
print("Please cite the paper: A. Seko, J. Appl. Phys. 133, 011101 (2023).")
print("Pypolymlp is developed at https://github.com/sekocha/pypolymlp.")
mlp_loaded = False
for mlp_filename in ["polymlp.yaml", "phono3py.pmlp"]:
_mlp_filename_list = list(pathlib.Path().glob(f"{mlp_filename}*"))
if _mlp_filename_list:
_mlp_filename = _mlp_filename_list[0]
if _mlp_filename.suffix not in [
".yaml",
".pmlp",
".xz",
".gz",
".bz2",
".lzma",
]:
continue
if log_level:
print(f'Load MLPs from "{_mlp_filename}".')
ph3py.load_mlp(_mlp_filename)
mlp_loaded = True
if log_level and mlp_filename == "phono3py.pmlp":
print(f'Loading MLPs from "{_mlp_filename}" is obsolete.')
break
mlp_filename = "polymlp.yaml"
if not mlp_loaded:
if forces_in_dataset(ph3py.mlp_dataset):
if log_level:
if mlp_params is None:
pmlp_params = PypolymlpParams()
else:
pmlp_params = parse_mlp_params(mlp_params)
print("Parameters:")
for k, v in dataclasses.asdict(pmlp_params).items():
if v is not None:
print(f" {k}: {v}")
print("Developing MLPs by pypolymlp...", flush=True)
ph3py.develop_mlp(params=mlp_params)
ph3py.save_mlp(filename=mlp_filename)
if log_level:
print(f'MLPs were written into "{mlp_filename}"', flush=True)
else:
raise RuntimeError(f'"{mlp_filename}" is not found.')
if log_level:
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)
develop_or_load_pypolymlp_phonopy(
cast(Phonopy, ph3py),
mlp_params=mlp_params,
mlp_filename=mlp_filename,
log_level=log_level,
)
def generate_displacements_and_evaluate_pypolymlp(

View File

@ -51,7 +51,7 @@ from phonopy.structure.cells import determinant
from phono3py import Phono3py
from phono3py.cui.create_force_constants import (
develop_pypolymlp,
develop_or_load_pypolymlp,
parse_forces,
)
from phono3py.file_IO import read_fc2_from_hdf5, read_fc3_from_hdf5
@ -374,7 +374,7 @@ def load(
if produce_fc:
if ph3py.fc3 is None and use_pypolymlp:
develop_pypolymlp(ph3py, mlp_params=mlp_params, log_level=log_level)
develop_or_load_pypolymlp(ph3py, mlp_params=mlp_params, log_level=log_level)
compute_force_constants_from_datasets(
ph3py,

View File

@ -60,6 +60,8 @@ from phonopy.cui.settings import PhonopySettings
from phonopy.exception import (
CellNotFoundError,
ForceCalculatorRequiredError,
PypolymlpDevelopmentError,
PypolymlpFileNotFoundError,
PypolymlpRelaxationError,
)
from phonopy.file_IO import is_file_phonopy_yaml
@ -75,7 +77,7 @@ from phonopy.structure.cells import isclose as cells_isclose
from phono3py import Phono3py, Phono3pyIsotope, Phono3pyJointDos
from phono3py.cui.create_force_constants import (
develop_pypolymlp,
develop_or_load_pypolymlp,
generate_displacements_and_evaluate_pypolymlp,
)
from phono3py.cui.create_force_sets import (
@ -599,11 +601,17 @@ def _run_pypolymlp(
ph3py.mlp_dataset = ph3py.dataset
ph3py.dataset = None
develop_pypolymlp(
ph3py,
mlp_params=settings.mlp_params,
log_level=log_level,
)
try:
develop_or_load_pypolymlp(
ph3py,
mlp_params=settings.mlp_params,
log_level=log_level,
)
except (PypolymlpDevelopmentError, PypolymlpFileNotFoundError) as e:
print_error_message(str(e))
if log_level:
print_error()
sys.exit(1)
_ph3py = ph3py
if settings.relax_atomic_positions: