mirror of https://github.com/Jittor/Jittor
add test_lock
This commit is contained in:
parent
1bd4f9e944
commit
5bb3f60393
|
@ -644,8 +644,10 @@ def compile_custom_ops(
|
||||||
lib_path = os.path.join(cache_path, "custom_ops")
|
lib_path = os.path.join(cache_path, "custom_ops")
|
||||||
if lib_path not in os.sys.path:
|
if lib_path not in os.sys.path:
|
||||||
os.sys.path.append(lib_path)
|
os.sys.path.append(lib_path)
|
||||||
|
jittor_lock.unlock()
|
||||||
with jit_utils.import_scope(dlopen_flags):
|
with jit_utils.import_scope(dlopen_flags):
|
||||||
exec(f"import {gen_name}")
|
exec(f"import {gen_name}")
|
||||||
|
jittor_lock.lock()
|
||||||
mod = locals()[gen_name]
|
mod = locals()[gen_name]
|
||||||
if return_module:
|
if return_module:
|
||||||
return mod
|
return mod
|
||||||
|
@ -800,7 +802,6 @@ import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||||
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||||
|
|
||||||
jittor_lock.lock()
|
jittor_lock.lock()
|
||||||
print('jit_utils.try_import_jit_utils_core() 799 ...')
|
|
||||||
with jit_utils.import_scope(import_flags):
|
with jit_utils.import_scope(import_flags):
|
||||||
jit_utils.try_import_jit_utils_core()
|
jit_utils.try_import_jit_utils_core()
|
||||||
jittor_lock.unlock()
|
jittor_lock.unlock()
|
||||||
|
@ -811,7 +812,6 @@ check_debug_flags()
|
||||||
sys.path.append(cache_path)
|
sys.path.append(cache_path)
|
||||||
|
|
||||||
jittor_lock.lock()
|
jittor_lock.lock()
|
||||||
print('jit_utils.try_import_jit_utils_core() 810 ...')
|
|
||||||
with jit_utils.import_scope(import_flags):
|
with jit_utils.import_scope(import_flags):
|
||||||
jit_utils.try_import_jit_utils_core()
|
jit_utils.try_import_jit_utils_core()
|
||||||
jittor_lock.unlock()
|
jittor_lock.unlock()
|
||||||
|
@ -857,7 +857,6 @@ make_cache_dir(os.path.join(cache_path, "obj_files"))
|
||||||
make_cache_dir(os.path.join(cache_path, "gen"))
|
make_cache_dir(os.path.join(cache_path, "gen"))
|
||||||
|
|
||||||
jittor_lock.lock()
|
jittor_lock.lock()
|
||||||
print('check_cache_compile() 856 ...')
|
|
||||||
# build cache_compile
|
# build cache_compile
|
||||||
cc_flags += pybind_include
|
cc_flags += pybind_include
|
||||||
cc_flags += f" -I{jittor_path}/src "
|
cc_flags += f" -I{jittor_path}/src "
|
||||||
|
@ -889,7 +888,6 @@ if has_cuda:
|
||||||
nvcc_flags = convert_nvcc_flags(nvcc_flags)
|
nvcc_flags = convert_nvcc_flags(nvcc_flags)
|
||||||
|
|
||||||
jittor_lock.lock()
|
jittor_lock.lock()
|
||||||
print('build core 888 ...')
|
|
||||||
# build core
|
# build core
|
||||||
gen_jit_flags()
|
gen_jit_flags()
|
||||||
gen_jit_tests()
|
gen_jit_tests()
|
||||||
|
|
|
@ -5,16 +5,16 @@ from jittor_utils import cache_path
|
||||||
class Lock:
|
class Lock:
|
||||||
def __init__(self, filename):
|
def __init__(self, filename):
|
||||||
self.handle = open(filename, 'w')
|
self.handle = open(filename, 'w')
|
||||||
print(f'创建锁 {filename} PID {os.getpid()}')
|
print(f'Create lock for {filename}, PID {os.getpid()}')
|
||||||
|
|
||||||
def lock(self):
|
def lock(self):
|
||||||
ret = fcntl.flock(self.handle, fcntl.LOCK_EX)
|
ret = fcntl.flock(self.handle, fcntl.LOCK_EX)
|
||||||
print(f'加锁成功 {ret} PID {os.getpid()}')
|
print(f'Add lock success {ret}, PID {os.getpid()}')
|
||||||
|
|
||||||
def unlock(self):
|
def unlock(self):
|
||||||
ret = fcntl.flock(self.handle, fcntl.LOCK_UN)
|
ret = fcntl.flock(self.handle, fcntl.LOCK_UN)
|
||||||
print(f'释放锁成功 {ret} PID {os.getpid()}')
|
print(f'Release lock success {ret}, PID {os.getpid()}')
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.handle.close()
|
self.handle.close()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2020 Jittor. Authors:
|
||||||
|
# Wenyang Zhou <576825820@qq.com>
|
||||||
|
# Dun Liang <randonlang@gmail.com>.
|
||||||
|
# All Rights Reserved.
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
|
import unittest
|
||||||
|
import os, sys
|
||||||
|
import jittor as jt
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||||
|
class TestLock(unittest.TestCase):
|
||||||
|
def test(self):
|
||||||
|
mpi = jt.compile_extern.mpi
|
||||||
|
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||||
|
if os.environ.get('lock_full_test', '0') == '1':
|
||||||
|
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
|
||||||
|
cmd = f"rm -rf {cache_path} && cache_name=lock {mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
|
||||||
|
else:
|
||||||
|
cache_path = os.path.join(str(Path.home()), ".cache", "jittor")
|
||||||
|
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
|
||||||
|
print("run cmd", cmd)
|
||||||
|
jt.compiler.run_cmd(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue