add test_lock

This commit is contained in:
zhouwy2115 2020-04-08 21:32:35 +08:00 committed by Dun Liang
parent 1bd4f9e944
commit 5bb3f60393
3 changed files with 36 additions and 8 deletions

View File

@ -644,8 +644,10 @@ def compile_custom_ops(
lib_path = os.path.join(cache_path, "custom_ops")
if lib_path not in os.sys.path:
os.sys.path.append(lib_path)
jittor_lock.unlock()
with jit_utils.import_scope(dlopen_flags):
exec(f"import {gen_name}")
jittor_lock.lock()
mod = locals()[gen_name]
if return_module:
return mod
@ -800,7 +802,6 @@ import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
jittor_lock.lock()
print('jit_utils.try_import_jit_utils_core() 799 ...')
with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
jittor_lock.unlock()
@ -811,7 +812,6 @@ check_debug_flags()
sys.path.append(cache_path)
jittor_lock.lock()
print('jit_utils.try_import_jit_utils_core() 810 ...')
with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
jittor_lock.unlock()
@ -857,7 +857,6 @@ make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))
jittor_lock.lock()
print('check_cache_compile() 856 ...')
# build cache_compile
cc_flags += pybind_include
cc_flags += f" -I{jittor_path}/src "
@ -889,7 +888,6 @@ if has_cuda:
nvcc_flags = convert_nvcc_flags(nvcc_flags)
jittor_lock.lock()
print('build core 888 ...')
# build core
gen_jit_flags()
gen_jit_tests()

View File

@ -5,15 +5,15 @@ from jittor_utils import cache_path
class Lock:
def __init__(self, filename):
self.handle = open(filename, 'w')
print(f'创建锁 {filename} PID {os.getpid()}')
print(f'Create lock for {filename}, PID {os.getpid()}')
def lock(self):
ret = fcntl.flock(self.handle, fcntl.LOCK_EX)
print(f'加锁成功 {ret} PID {os.getpid()}')
print(f'Add lock success {ret}, PID {os.getpid()}')
def unlock(self):
ret = fcntl.flock(self.handle, fcntl.LOCK_UN)
print(f'释放锁成功 {ret} PID {os.getpid()}')
print(f'Release lock success {ret}, PID {os.getpid()}')
def __del__(self):
self.handle.close()

View File

@ -0,0 +1,30 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
from pathlib import Path
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestLock(unittest.TestCase):
def test(self):
mpi = jt.compile_extern.mpi
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
if os.environ.get('lock_full_test', '0') == '1':
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
cmd = f"rm -rf {cache_path} && cache_name=lock {mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
else:
cache_path = os.path.join(str(Path.home()), ".cache", "jittor")
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
print("run cmd", cmd)
jt.compiler.run_cmd(cmd)
if __name__ == "__main__":
unittest.main()