Refactor settings.py following that of phonopy

This commit is contained in:
Atsushi Togo 2025-06-17 20:03:31 +09:00
parent 36ae0c1ffb
commit 1441cbc155
5 changed files with 391 additions and 640 deletions

View File

@ -39,7 +39,7 @@ import sys
from phonopy.cui.phonopy_argparse import fix_deprecated_option_names
def get_parser(fc_symmetry=False, is_nac=False, load_phono3py_yaml=False):
def get_parser(load_phono3py_yaml: bool = False):
"""Return ArgumentParser instance."""
deprecated = fix_deprecated_option_names(sys.argv)
import argparse
@ -305,7 +305,7 @@ def get_parser(fc_symmetry=False, is_nac=False, load_phono3py_yaml=False):
"string with the style of key = values"
),
)
if not fc_symmetry:
if not load_phono3py_yaml:
parser.add_argument(
"--fc-symmetry",
"--sym-fc",
@ -487,7 +487,7 @@ def get_parser(fc_symmetry=False, is_nac=False, load_phono3py_yaml=False):
default=None,
help="Mass variance parameters for isotope scattering",
)
if not is_nac:
if not load_phono3py_yaml:
parser.add_argument(
"--nac",
dest="is_nac",
@ -501,7 +501,7 @@ def get_parser(fc_symmetry=False, is_nac=False, load_phono3py_yaml=False):
default=None,
help="Non-analytical term correction method: Gonze (default) or Wang",
)
if fc_symmetry:
if load_phono3py_yaml:
parser.add_argument(
"--no-fc-symmetry",
"--no-sym-fc",
@ -532,7 +532,7 @@ def get_parser(fc_symmetry=False, is_nac=False, load_phono3py_yaml=False):
default=False,
help="No symmetrization of triplets is made.",
)
if is_nac:
if load_phono3py_yaml:
parser.add_argument(
"--nonac",
dest="is_nac",

View File

@ -43,6 +43,7 @@ import warnings
from typing import Optional
import numpy as np
from numpy.typing import NDArray
from phonopy.cui.collect_cell_info import collect_cell_info
from phonopy.cui.phonopy_argparse import show_deprecated_option_warnings
from phonopy.cui.phonopy_script import (
@ -97,7 +98,7 @@ from phono3py.file_IO import (
)
from phono3py.interface.fc_calculator import determine_cutoff_pair_distance
from phono3py.interface.phono3py_yaml import Phono3pyYaml
from phono3py.phonon.grid import get_grid_point_from_address, get_ir_grid_points
from phono3py.phonon.grid import BZGrid, get_grid_point_from_address, get_ir_grid_points
from phono3py.phonon3.dataset import forces_in_dataset
from phono3py.phonon3.fc3 import show_drift_fc3
from phono3py.phonon3.gruneisen import run_gruneisen_parameters
@ -131,10 +132,10 @@ def print_end_phono3py():
def finalize_phono3py(
phono3py: Phono3py,
confs_dict,
log_level,
write_displacements=False,
filename=None,
confs_dict: dict,
log_level: int,
write_displacements: bool = False,
filename: str | None = None,
):
"""Write phono3py.yaml and then exit.
@ -183,7 +184,7 @@ def finalize_phono3py(
sys.exit(0)
def get_run_mode(settings):
def get_run_mode(settings: Phono3pySettings):
"""Extract run mode from settings."""
run_mode = None
if settings.is_gruneisen:
@ -213,7 +214,7 @@ def get_run_mode(settings):
def start_phono3py(**argparse_control) -> tuple[argparse.Namespace, int]:
"""Parse arguments and set some basic parameters."""
parser, deprecated = get_parser(**argparse_control)
parser, deprecated = get_parser(argparse_control.get("load_phono3py_yaml", False))
args = parser.parse_args()
# Log level
@ -242,9 +243,9 @@ def start_phono3py(**argparse_control) -> tuple[argparse.Namespace, int]:
import spglib
try: # spglib.get_version() is deprecated.
print(f"Spglib version {spglib.spg_get_version()}")
print(f"Spglib version {spglib.spg_get_version()}") # type: ignore
except AttributeError:
print("Spglib version %d.%d.%d" % spglib.get_version())
print("Spglib version %d.%d.%d" % spglib.get_version()) # type: ignore
if deprecated:
show_deprecated_option_warnings(deprecated)
@ -252,7 +253,9 @@ def start_phono3py(**argparse_control) -> tuple[argparse.Namespace, int]:
return args, log_level
def read_phono3py_settings(args, argparse_control, log_level):
def read_phono3py_settings(
args: argparse.Namespace, argparse_control: dict, log_level: int
):
"""Read phono3py settings.
From:
@ -269,20 +272,20 @@ def read_phono3py_settings(args, argparse_control, log_level):
phono3py_conf_parser = Phono3pyConfParser(
filename=args.conf_filename,
args=args,
default_settings=argparse_control,
load_phono3py_yaml=load_phono3py_yaml,
)
cell_filename = args.filename[0]
else:
if is_file_phonopy_yaml(args.filename[0], keyword="phono3py"):
phono3py_conf_parser = Phono3pyConfParser(
args=args, default_settings=argparse_control
args=args, load_phono3py_yaml=load_phono3py_yaml
)
cell_filename = args.filename[0]
else: # args.filename[0] is assumed to be phono3py-conf file.
phono3py_conf_parser = Phono3pyConfParser(
filename=args.filename[0],
args=args,
default_settings=argparse_control,
load_phono3py_yaml=load_phono3py_yaml,
)
cell_filename = phono3py_conf_parser.settings.cell_filename
else:
@ -290,11 +293,11 @@ def read_phono3py_settings(args, argparse_control, log_level):
phono3py_conf_parser = Phono3pyConfParser(
args=args,
filename=args.conf_filename,
default_settings=argparse_control,
load_phono3py_yaml=load_phono3py_yaml,
)
else:
phono3py_conf_parser = Phono3pyConfParser(
args=args, default_settings=argparse_control
args=args, load_phono3py_yaml=load_phono3py_yaml
)
cell_filename = phono3py_conf_parser.settings.cell_filename
@ -304,7 +307,7 @@ def read_phono3py_settings(args, argparse_control, log_level):
return settings, confs_dict, cell_filename
def get_input_output_filenames_from_args(args):
def get_input_output_filenames_from_args(args: argparse.Namespace):
"""Return strings inserted to input and output filenames."""
if args.input_filename is not None:
warnings.warn(
@ -360,7 +363,7 @@ def get_cell_info(
return cell_info
def get_default_values(settings):
def get_default_values(settings: Phono3pySettings):
"""Set default values."""
# Brillouin zone integration: Tetrahedron (default) or smearing method
sigma = settings.sigma
@ -434,7 +437,9 @@ def get_default_values(settings):
return params
def check_supercell_in_yaml(cell_info, ph3, distance_to_A, log_level):
def check_supercell_in_yaml(
cell_info: dict, ph3: Phono3py, distance_to_A: float | None, log_level: int
):
"""Check consistency between generated cells and cells in yaml."""
if cell_info["phonopy_yaml"] is not None:
if distance_to_A is None:
@ -472,7 +477,11 @@ def check_supercell_in_yaml(cell_info, ph3, distance_to_A, log_level):
def init_phono3py(
settings, cell_info, interface_mode, symprec, log_level
settings: Phono3pySettings,
cell_info: dict,
interface_mode: str | None,
symprec: float,
log_level: int,
) -> tuple[Phono3py, dict]:
"""Initialize phono3py and update settings by default values."""
physical_units = get_calculator_physical_units(interface_mode)
@ -516,7 +525,7 @@ def init_phono3py(
return phono3py, updated_settings
def settings_to_grid_points(settings, bz_grid):
def settings_to_grid_points(settings: Phono3pySettings, bz_grid: BZGrid):
"""Read or set grid point indices."""
if settings.grid_addresses is not None:
grid_points = grid_addresses_to_grid_points(settings.grid_addresses, bz_grid)
@ -527,7 +536,7 @@ def settings_to_grid_points(settings, bz_grid):
return grid_points
def grid_addresses_to_grid_points(grid_addresses, bz_grid):
def grid_addresses_to_grid_points(grid_addresses: NDArray, bz_grid: BZGrid):
"""Return grid point indices from grid addresses."""
grid_points = [
get_grid_point_from_address(ga, bz_grid.D_diag) for ga in grid_addresses
@ -697,7 +706,7 @@ def _store_force_constants(ph3py: Phono3py, settings: Phono3pySettings, log_leve
print('fc2 was written into "fc2.hdf5".')
def _show_fc_calculator_not_found(log_level):
def _show_fc_calculator_not_found(log_level: int):
if log_level:
print("")
print(
@ -710,7 +719,9 @@ def _show_fc_calculator_not_found(log_level):
sys.exit(1)
def run_gruneisen_then_exit(phono3py, settings, output_filename, log_level):
def run_gruneisen_then_exit(
phono3py: Phono3py, settings: Phono3pySettings, output_filename: str, log_level: int
):
"""Run mode Grueneisen parameter calculation from fc3."""
if (
settings.mesh_numbers is None
@ -763,7 +774,11 @@ def run_gruneisen_then_exit(phono3py, settings, output_filename, log_level):
def run_jdos_then_exit(
phono3py: Phono3py, settings, updated_settings, output_filename, log_level
phono3py: Phono3py,
settings: Phono3pySettings,
updated_settings: dict,
output_filename: str | None,
log_level: int,
):
"""Run joint-DOS calculation."""
joint_dos = Phono3pyJointDos(
@ -801,7 +816,12 @@ def run_jdos_then_exit(
sys.exit(0)
def run_isotope_then_exit(phono3py, settings, updated_settings, log_level):
def run_isotope_then_exit(
phono3py: Phono3py,
settings: Phono3pySettings,
updated_settings: dict,
log_level: int,
):
"""Run isotope scattering calculation."""
mass_variances = settings.mass_variances
if settings.band_indices is not None:
@ -842,11 +862,11 @@ def run_isotope_then_exit(phono3py, settings, updated_settings, log_level):
def init_phph_interaction(
phono3py: Phono3py,
settings,
updated_settings,
input_filename,
output_filename,
log_level,
settings: Phono3pySettings,
updated_settings: dict,
input_filename: str | None,
output_filename: str | None,
log_level: int,
):
"""Initialize ph-ph interaction and phonons on grid."""
if log_level:

File diff suppressed because it is too large Load Diff

View File

@ -37,9 +37,5 @@ from phono3py.cui.phono3py_script import main
def run():
"""Run phono3py script."""
argparse_control = {
"fc_symmetry": False,
"is_nac": False,
"load_phono3py_yaml": False,
}
argparse_control = {"load_phono3py_yaml": False}
main(**argparse_control)

View File

@ -37,5 +37,5 @@ from phono3py.cui.phono3py_script import main
def run():
"""Run phono3py-load script."""
argparse_control = {"fc_symmetry": True, "is_nac": True, "load_phono3py_yaml": True}
argparse_control = {"load_phono3py_yaml": True}
main(**argparse_control)