mirror of https://github.com/phonopy/phono3py.git
Merge pull request #420 from phonopy/refactor
Merge CUI for pypolymlp to phonopy
This commit is contained in:
commit
ff95c4dd2e
|
@ -40,9 +40,13 @@ import copy
|
|||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Literal
|
||||
from typing import Literal, cast
|
||||
|
||||
import numpy as np
|
||||
from phonopy import Phonopy
|
||||
from phonopy.cui.load_helper import (
|
||||
develop_or_load_pypolymlp as develop_or_load_pypolymlp_phonopy,
|
||||
)
|
||||
from phonopy.file_IO import get_dataset_type2
|
||||
from phonopy.interface.calculator import get_calculator_physical_units
|
||||
from phonopy.interface.pypolymlp import PypolymlpParams, parse_mlp_params
|
||||
|
@ -185,66 +189,19 @@ def _read_FORCES_FC3_or_FC2(
|
|||
return dataset
|
||||
|
||||
|
||||
def develop_pypolymlp(
|
||||
def develop_or_load_pypolymlp(
|
||||
ph3py: Phono3py,
|
||||
mlp_params: str | dict | PypolymlpParams | None = None,
|
||||
mlp_filename: str | os.PathLike | None = None,
|
||||
log_level: int = 0,
|
||||
):
|
||||
"""Run pypolymlp to compute forces."""
|
||||
if log_level:
|
||||
import pypolymlp
|
||||
|
||||
print("-" * 29 + " pypolymlp start " + "-" * 30)
|
||||
print("Pypolymlp version", pypolymlp.__version__)
|
||||
print("Pypolymlp is a generator of polynomial machine learning potentials.")
|
||||
print("Please cite the paper: A. Seko, J. Appl. Phys. 133, 011101 (2023).")
|
||||
print("Pypolymlp is developed at https://github.com/sekocha/pypolymlp.")
|
||||
|
||||
mlp_loaded = False
|
||||
for mlp_filename in ["polymlp.yaml", "phono3py.pmlp"]:
|
||||
_mlp_filename_list = list(pathlib.Path().glob(f"{mlp_filename}*"))
|
||||
if _mlp_filename_list:
|
||||
_mlp_filename = _mlp_filename_list[0]
|
||||
if _mlp_filename.suffix not in [
|
||||
".yaml",
|
||||
".pmlp",
|
||||
".xz",
|
||||
".gz",
|
||||
".bz2",
|
||||
".lzma",
|
||||
]:
|
||||
continue
|
||||
if log_level:
|
||||
print(f'Load MLPs from "{_mlp_filename}".')
|
||||
ph3py.load_mlp(_mlp_filename)
|
||||
mlp_loaded = True
|
||||
if log_level and mlp_filename == "phono3py.pmlp":
|
||||
print(f'Loading MLPs from "{_mlp_filename}" is obsolete.')
|
||||
break
|
||||
|
||||
mlp_filename = "polymlp.yaml"
|
||||
if not mlp_loaded:
|
||||
if forces_in_dataset(ph3py.mlp_dataset):
|
||||
if log_level:
|
||||
if mlp_params is None:
|
||||
pmlp_params = PypolymlpParams()
|
||||
else:
|
||||
pmlp_params = parse_mlp_params(mlp_params)
|
||||
print("Parameters:")
|
||||
for k, v in dataclasses.asdict(pmlp_params).items():
|
||||
if v is not None:
|
||||
print(f" {k}: {v}")
|
||||
print("Developing MLPs by pypolymlp...", flush=True)
|
||||
ph3py.develop_mlp(params=mlp_params)
|
||||
ph3py.save_mlp(filename=mlp_filename)
|
||||
if log_level:
|
||||
print(f'MLPs were written into "{mlp_filename}"', flush=True)
|
||||
else:
|
||||
raise RuntimeError(f'"{mlp_filename}" is not found.')
|
||||
|
||||
if log_level:
|
||||
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)
|
||||
develop_or_load_pypolymlp_phonopy(
|
||||
cast(Phonopy, ph3py),
|
||||
mlp_params=mlp_params,
|
||||
mlp_filename=mlp_filename,
|
||||
log_level=log_level,
|
||||
)
|
||||
|
||||
|
||||
def generate_displacements_and_evaluate_pypolymlp(
|
||||
|
|
|
@ -51,7 +51,7 @@ from phonopy.structure.cells import determinant
|
|||
|
||||
from phono3py import Phono3py
|
||||
from phono3py.cui.create_force_constants import (
|
||||
develop_pypolymlp,
|
||||
develop_or_load_pypolymlp,
|
||||
parse_forces,
|
||||
)
|
||||
from phono3py.file_IO import read_fc2_from_hdf5, read_fc3_from_hdf5
|
||||
|
@ -374,7 +374,7 @@ def load(
|
|||
|
||||
if produce_fc:
|
||||
if ph3py.fc3 is None and use_pypolymlp:
|
||||
develop_pypolymlp(ph3py, mlp_params=mlp_params, log_level=log_level)
|
||||
develop_or_load_pypolymlp(ph3py, mlp_params=mlp_params, log_level=log_level)
|
||||
|
||||
compute_force_constants_from_datasets(
|
||||
ph3py,
|
||||
|
|
|
@ -60,6 +60,8 @@ from phonopy.cui.settings import PhonopySettings
|
|||
from phonopy.exception import (
|
||||
CellNotFoundError,
|
||||
ForceCalculatorRequiredError,
|
||||
PypolymlpDevelopmentError,
|
||||
PypolymlpFileNotFoundError,
|
||||
PypolymlpRelaxationError,
|
||||
)
|
||||
from phonopy.file_IO import is_file_phonopy_yaml
|
||||
|
@ -75,7 +77,7 @@ from phonopy.structure.cells import isclose as cells_isclose
|
|||
|
||||
from phono3py import Phono3py, Phono3pyIsotope, Phono3pyJointDos
|
||||
from phono3py.cui.create_force_constants import (
|
||||
develop_pypolymlp,
|
||||
develop_or_load_pypolymlp,
|
||||
generate_displacements_and_evaluate_pypolymlp,
|
||||
)
|
||||
from phono3py.cui.create_force_sets import (
|
||||
|
@ -599,11 +601,17 @@ def _run_pypolymlp(
|
|||
ph3py.mlp_dataset = ph3py.dataset
|
||||
ph3py.dataset = None
|
||||
|
||||
develop_pypolymlp(
|
||||
ph3py,
|
||||
mlp_params=settings.mlp_params,
|
||||
log_level=log_level,
|
||||
)
|
||||
try:
|
||||
develop_or_load_pypolymlp(
|
||||
ph3py,
|
||||
mlp_params=settings.mlp_params,
|
||||
log_level=log_level,
|
||||
)
|
||||
except (PypolymlpDevelopmentError, PypolymlpFileNotFoundError) as e:
|
||||
print_error_message(str(e))
|
||||
if log_level:
|
||||
print_error()
|
||||
sys.exit(1)
|
||||
|
||||
_ph3py = ph3py
|
||||
if settings.relax_atomic_positions:
|
||||
|
|
Loading…
Reference in New Issue