Implemented shortest three atom distance measure

This commit is contained in:
Atsushi Togo 2023-12-24 11:49:39 +09:00
parent e4ede39edd
commit 49c7be6d37
3 changed files with 104 additions and 14 deletions

View File

@ -34,7 +34,8 @@
# POSSIBILITY OF SUCH DAMAGE.
import warnings
from typing import Optional
from collections.abc import Sequence
from typing import Optional, Union
import numpy as np
from phonopy.exception import ForceCalculatorRequiredError
@ -1041,7 +1042,7 @@ class Phono3py:
return self._bz_grid.D_diag
@mesh_numbers.setter
def mesh_numbers(self, mesh_numbers):
def mesh_numbers(self, mesh_numbers: Union[int, float, Sequence, np.ndarray]):
self._set_mesh_numbers(mesh_numbers)
@property
@ -2555,7 +2556,10 @@ class Phono3py:
else:
return pmat
def _set_mesh_numbers(self, mesh):
def _set_mesh_numbers(
self,
mesh: Union[int, float, Sequence, np.ndarray],
):
# initialization related to mesh
self._interaction = None

View File

@ -179,6 +179,11 @@ class Interaction:
self._masses = np.array(self._primitive.masses, dtype="double")
self._p2s = np.array(self._primitive.p2s_map, dtype="int_")
self._s2p = np.array(self._primitive.s2p_map, dtype="int_")
n_satom, n_patom, _ = self._multi.shape
self._all_shortest = np.zeros(
(n_patom, n_satom, n_satom), dtype="byte", order="C"
)
self._get_all_shortest()
def run(
self, lang: Literal["C", "Python"] = "C", g_zero: Optional[np.ndarray] = None
@ -987,6 +992,46 @@ class Interaction:
self._eigenvectors_at_gamma = self._eigenvectors[gp_Gamma].copy()
self._phonon_done[gp_Gamma] = 0
def _get_all_shortest(self):
"""Return array indicating distances among three atoms are all shortest.
multi.shape = (n_satom, n_patom)
svecs : distance with respect to primitive cell basis
perms.shape = (n_pure_trans, n_satom)
"""
svecs = self._svecs
multi = self._multi
n_satom, n_patom, _ = multi.shape
perms = self._primitive.atomic_permutations
s2pp_map = [self._primitive.p2p_map[i] for i in self._s2p]
lattice = self._primitive.cell
for i_patom in range(n_patom):
for j_atom in range(n_satom):
j_patom = s2pp_map[j_atom]
i_perm = np.where(perms[:, j_atom] == self._p2s[j_patom])[0][0]
for k_atom in range(n_satom):
initial_vec = (
svecs[multi[k_atom, i_patom, 1]]
- svecs[multi[j_atom, i_patom, 1]]
)
d_jk_shortest = np.linalg.norm(initial_vec @ lattice)
for j_m, k_m in np.ndindex(
(multi[j_atom, i_patom, 0], multi[k_atom, i_patom, 0])
):
vec_ij = svecs[multi[j_atom, i_patom, 1] + j_m]
vec_ik = svecs[multi[k_atom, i_patom, 1] + k_m]
d_jk_attempt = np.linalg.norm((vec_ik - vec_ij) @ lattice)
if d_jk_attempt < d_jk_shortest:
d_jk_shortest = d_jk_attempt
k_atom_mapped = perms[i_perm, k_atom]
d_jk_mapped = np.linalg.norm(
svecs[multi[k_atom_mapped, j_patom, 1]] @ lattice
)
if abs(d_jk_mapped - d_jk_shortest) < self._symprec:
self._all_shortest[i_patom, j_atom, k_atom] = 1
def all_bands_exist(interaction: Interaction):
"""Return if all bands are selected or not."""

View File

@ -1,6 +1,12 @@
"""Test Interaction class."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Literal, Optional, Union
import numpy as np
import pytest
from phonopy.structure.cells import get_smallest_vectors
from phono3py import Phono3py
from phono3py.phonon3.interaction import Interaction
@ -153,8 +159,8 @@ itr_RTA_AlN_r0_ave = [
]
@pytest.mark.parametrize("lang", ["C", "Py"])
def test_interaction_RTA_si(si_pbesol, lang):
@pytest.mark.parametrize("lang", ["C", "Python"])
def test_interaction_RTA_si(si_pbesol: Phono3py, lang: Literal["C", "Python"]):
"""Test interaction_strength of Si."""
itr = _get_irt(si_pbesol, [4, 4, 4])
itr.set_grid_point(1)
@ -166,7 +172,7 @@ def test_interaction_RTA_si(si_pbesol, lang):
)
def test_interaction_RTA_AlN(aln_lda):
def test_interaction_RTA_AlN(aln_lda: Phono3py):
"""Test interaction_strength of AlN."""
itr = _get_irt(aln_lda, [7, 7, 7])
itr.set_grid_point(1)
@ -177,7 +183,7 @@ def test_interaction_RTA_AlN(aln_lda):
)
def test_interaction_RTA_AlN_r0_ave(aln_lda):
def test_interaction_RTA_AlN_r0_ave(aln_lda: Phono3py):
"""Test interaction_strength of AlN."""
itr = _get_irt(aln_lda, [7, 7, 7], make_r0_average=True)
itr.set_grid_point(1)
@ -271,7 +277,7 @@ def test_interaction_run_phonon_solver_at_gamma_NaCl(nacl_pbe: Phono3py):
)
def test_phonon_solver_expand_RTA_si(si_pbesol):
def test_phonon_solver_expand_RTA_si(si_pbesol: Phono3py):
"""Test phonon solver with eigenvector rotation of Si.
Eigenvectors can be different but frequencies must be almost the same.
@ -286,19 +292,54 @@ def test_phonon_solver_expand_RTA_si(si_pbesol):
np.testing.assert_allclose(freqs, freqs_expanded, rtol=0, atol=1e-6)
def test_get_all_shortest(aln_lda: Phono3py):
"""Test Interaction._get_all_shortest."""
ph3 = aln_lda
ph3.mesh_numbers = 30
itr = Interaction(
ph3.primitive,
ph3.grid,
ph3.primitive_symmetry,
cutoff_frequency=1e-5,
)
s_svecs, s_multi = get_smallest_vectors(
ph3.supercell.cell,
ph3.supercell.scaled_positions,
ph3.supercell.scaled_positions,
store_dense_svecs=True,
)
s_lattice = ph3.supercell.cell
p_lattice = itr.primitive.cell
shortests = itr._all_shortest
svecs, multi, _, _, _ = itr.get_primitive_and_supercell_correspondence()
n_satom, n_patom, _ = multi.shape
for i, j, k in np.ndindex((n_patom, n_satom, n_satom)):
d_jk_shortest = np.linalg.norm(s_svecs[s_multi[j, k, 1]] @ s_lattice)
is_found = 0
for m_j, m_k in np.ndindex((multi[j, i, 0], multi[k, i, 0])):
vec_ij = svecs[multi[j, i, 1] + m_j]
vec_ik = svecs[multi[k, i, 1] + m_k]
vec_jk = vec_ik - vec_ij
d_jk = np.linalg.norm(vec_jk @ p_lattice)
if abs(d_jk - d_jk_shortest) < ph3.symmetry.tolerance:
is_found = 1
break
assert shortests[i, j, k] == is_found
def _get_irt(
ph3: Phono3py,
mesh,
nac_params=None,
solve_dynamical_matrices=True,
make_r0_average=False,
mesh: Union[int, float, Sequence, np.ndarray],
nac_params: Optional[dict] = None,
solve_dynamical_matrices: bool = True,
make_r0_average: bool = False,
):
ph3.mesh_numbers = mesh
itr = Interaction(
ph3.primitive,
ph3.grid,
ph3.primitive_symmetry,
ph3.fc3,
fc3=ph3.fc3,
make_r0_average=make_r0_average,
cutoff_frequency=1e-4,
)
@ -320,7 +361,7 @@ def _get_irt(
return itr
def _show(itr):
def _show(itr: Interaction):
itr_vals = itr.interaction_strength.sum(axis=(1, 2, 3))
for i, v in enumerate(itr_vals):
print("%e, " % v, end="")