Update for pypolymlp v0.9

This commit is contained in:
Atsushi Togo 2025-03-03 10:37:37 +09:00
parent 7a2b54faa7
commit 33827f7210
3 changed files with 48 additions and 27 deletions

View File

@ -525,7 +525,7 @@ def run_pypolymlp_to_compute_forces(
random_seed: Optional[int] = None,
prepare_dataset: bool = False,
cutoff_pair_distance: Optional[float] = None,
mlp_filename: str = "phono3py.pmlp",
mlp_filename: Optional[str] = None,
log_level: int = 0,
):
"""Run pypolymlp to compute forces."""
@ -535,27 +535,48 @@ def run_pypolymlp_to_compute_forces(
print("Please cite the paper: A. Seko, J. Appl. Phys. 133, 011101 (2023).")
print("Pypolymlp is developed at https://github.com/sekocha/pypolymlp.")
if pathlib.Path(mlp_filename).exists():
if log_level:
print(f'Load MLPs from "{mlp_filename}".')
ph3py.load_mlp(mlp_filename)
elif forces_in_dataset(ph3py.mlp_dataset):
if log_level:
if mlp_params is None:
pmlp_params = PypolymlpParams()
else:
pmlp_params = parse_mlp_params(mlp_params)
print("Parameters:")
for k, v in asdict(pmlp_params).items():
if v is not None:
print(f" {k}: {v}")
print("Developing MLPs by pypolymlp...", flush=True)
ph3py.develop_mlp(params=mlp_params)
ph3py.save_mlp(filename=mlp_filename)
if log_level:
print(f'MLPs were written into "{mlp_filename}"', flush=True)
else:
raise RuntimeError(f'"{mlp_filename}" is not found.')
mlp_loaded = False
for mlp_filename in ["polymlp.yaml", "phono3py.pmlp"]:
_mlp_filename_list = list(pathlib.Path().glob(f"{mlp_filename}*"))
if _mlp_filename_list:
_mlp_filename = _mlp_filename_list[0]
print(_mlp_filename)
if _mlp_filename.suffix not in [
".yaml",
".pmlp",
".xz",
".gz",
".bz2",
"lzma",
]:
continue
if log_level:
print(f'Load MLPs from "{mlp_filename}".')
ph3py.load_mlp(mlp_filename)
mlp_loaded = True
if log_level and mlp_filename == "phono3py.pmlp":
print(f'Loading MLPs from "{_mlp_filename}" is obsolete.')
break
mlp_filename = "polymlp.yaml"
if not mlp_loaded:
if forces_in_dataset(ph3py.mlp_dataset):
if log_level:
if mlp_params is None:
pmlp_params = PypolymlpParams()
else:
pmlp_params = parse_mlp_params(mlp_params)
print("Parameters:")
for k, v in asdict(pmlp_params).items():
if v is not None:
print(f" {k}: {v}")
print("Developing MLPs by pypolymlp...", flush=True)
ph3py.develop_mlp(params=mlp_params)
ph3py.save_mlp(filename=mlp_filename)
if log_level:
print(f'MLPs were written into "{mlp_filename}"', flush=True)
else:
raise RuntimeError(f'"{mlp_filename}" is not found.')
if log_level:
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)

View File

@ -1200,7 +1200,7 @@ def main(**argparse_control):
)
ph3py.save(mlp_eval_filename)
# pypolymlp dataset is stored in "phono3py.pmlp" and stop here.
# pypolymlp dataset is stored in "polymlp.yaml" and stop here.
if not prepare_dataset:
if log_level:
print(

View File

@ -122,11 +122,11 @@ def test_phono3py_with_QE_calculator(load_phono3py_yaml):
def test_phono3py_load_with_pypolymlp_si():
"""Test phono3py-load script with pypolymlp.
First run generates phono3py.pmlp.
Second run uses phono3py.pmlp.
First run generates polymlp.yaml.
Second run uses polymlp.yaml.
"""
pytest.importorskip("pypolymlp")
pytest.importorskip("pypolymlp", minversion="0.9.2")
pytest.importorskip("symfc")
argparse_control = _get_phono3py_load_args(
@ -159,7 +159,7 @@ def test_phono3py_load_with_pypolymlp_si():
"phono3py.yaml",
"fc2.hdf5",
"fc3.hdf5",
"phono3py.pmlp",
"polymlp.yaml",
"phono3py_mlp_eval_dataset.yaml",
):
file_path = pathlib.Path(cwd_called / created_filename)