mirror of https://github.com/Jittor/Jittor
add lock_scope and unlock_scope
This commit is contained in:
parent
ac4198d372
commit
9178b3459a
|
@ -7,14 +7,16 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
from . import compiler
|
||||
from .compiler import LOG, has_cuda
|
||||
from .compiler import compile_custom_ops, compile_custom_op
|
||||
import jittor_core as core
|
||||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
from .compiler import LOG, has_cuda
|
||||
from .compiler import compile_custom_ops, compile_custom_op
|
||||
import jittor_core as core
|
||||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops
|
||||
|
||||
import contextlib
|
||||
import numpy as np
|
||||
|
|
|
@ -7,7 +7,6 @@ import os, sys, shutil
|
|||
from .compiler import *
|
||||
from jittor_utils import run_cmd, get_version
|
||||
from jittor.dataset.utils import download_url_to_local
|
||||
from jittor.lock import jittor_lock
|
||||
|
||||
def search_file(dirs, name):
|
||||
for d in dirs:
|
||||
|
@ -374,8 +373,6 @@ def setup_mpi():
|
|||
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
|
||||
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
|
||||
|
||||
jittor_lock.lock()
|
||||
|
||||
setup_mpi()
|
||||
setup_nccl()
|
||||
|
||||
|
@ -383,5 +380,3 @@ setup_cutt()
|
|||
setup_mkl()
|
||||
|
||||
setup_cuda_extern()
|
||||
|
||||
jittor_lock.unlock()
|
||||
|
|
|
@ -17,7 +17,7 @@ from ctypes.util import find_library
|
|||
import jittor_utils as jit_utils
|
||||
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
|
||||
from . import pyjt_compiler
|
||||
from jittor.lock import jittor_lock
|
||||
from . import lock
|
||||
|
||||
def find_jittor_path():
|
||||
return os.path.dirname(__file__)
|
||||
|
@ -555,6 +555,7 @@ def compile_custom_op(header, source, op_name, warp=True):
|
|||
m = compile_custom_ops([hname, ccname])
|
||||
return getattr(m, op_name)
|
||||
|
||||
@lock.lock_scope()
|
||||
def compile_custom_ops(
|
||||
filenames,
|
||||
extra_flags="",
|
||||
|
@ -644,10 +645,10 @@ def compile_custom_ops(
|
|||
lib_path = os.path.join(cache_path, "custom_ops")
|
||||
if lib_path not in os.sys.path:
|
||||
os.sys.path.append(lib_path)
|
||||
jittor_lock.unlock()
|
||||
with jit_utils.import_scope(dlopen_flags):
|
||||
exec(f"import {gen_name}")
|
||||
jittor_lock.lock()
|
||||
# unlock scope when initialize
|
||||
with lock.unlock_scope():
|
||||
with jit_utils.import_scope(dlopen_flags):
|
||||
exec(f"import {gen_name}")
|
||||
mod = locals()[gen_name]
|
||||
if return_module:
|
||||
return mod
|
||||
|
@ -801,20 +802,16 @@ import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
|||
# import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
|
||||
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
|
||||
jittor_lock.lock()
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
jittor_lock.unlock()
|
||||
|
||||
jittor_path = find_jittor_path()
|
||||
check_debug_flags()
|
||||
|
||||
sys.path.append(cache_path)
|
||||
|
||||
jittor_lock.lock()
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
jittor_lock.unlock()
|
||||
|
||||
python_path = sys.executable
|
||||
py3_config_path = sys.executable+"-config"
|
||||
|
@ -856,13 +853,11 @@ make_cache_dir(os.path.join(cache_path, "jit"))
|
|||
make_cache_dir(os.path.join(cache_path, "obj_files"))
|
||||
make_cache_dir(os.path.join(cache_path, "gen"))
|
||||
|
||||
jittor_lock.lock()
|
||||
# build cache_compile
|
||||
cc_flags += pybind_include
|
||||
cc_flags += f" -I{jittor_path}/src "
|
||||
check_cache_compile()
|
||||
LOG.v(f"Get cache_compile: {jit_utils.cc}")
|
||||
jittor_lock.unlock()
|
||||
|
||||
# check cuda
|
||||
has_cuda = 0
|
||||
|
@ -887,7 +882,6 @@ if has_cuda:
|
|||
return nvcc_flags
|
||||
nvcc_flags = convert_nvcc_flags(nvcc_flags)
|
||||
|
||||
jittor_lock.lock()
|
||||
# build core
|
||||
gen_jit_flags()
|
||||
gen_jit_tests()
|
||||
|
@ -977,4 +971,5 @@ flags.jittor_path = jittor_path
|
|||
flags.gdb_path = gdb_path
|
||||
flags.addr2line_path = addr2line_path
|
||||
flags.has_pybt = has_pybt
|
||||
jittor_lock.unlock()
|
||||
|
||||
core.set_lock_path(lock.lock_path)
|
||||
|
|
|
@ -15,7 +15,6 @@ from jittor.dataset.dataset import Dataset, dataset_root
|
|||
from jittor.dataset.utils import ensure_dir, download_url_to_local
|
||||
import jittor as jt
|
||||
import jittor.transform as trans
|
||||
from jittor.lock import jittor_lock
|
||||
|
||||
class MNIST(Dataset):
|
||||
def __init__(self, data_root=dataset_root+"/mnist_data/", train=True ,download=True, transform=None):
|
||||
|
@ -65,6 +64,4 @@ class MNIST(Dataset):
|
|||
|
||||
for url, md5 in resources:
|
||||
filename = url.rpartition('/')[2]
|
||||
jittor_lock.lock()
|
||||
download_url_to_local(url, filename, self.data_root, md5)
|
||||
jittor_lock.unlock()
|
||||
|
|
|
@ -15,6 +15,7 @@ from tqdm import tqdm
|
|||
import numpy as np
|
||||
from collections.abc import Sequence, Mapping
|
||||
from PIL import Image
|
||||
from .. import lock
|
||||
|
||||
def ensure_dir(dir_path):
|
||||
if not os.path.isdir(dir_path):
|
||||
|
@ -36,7 +37,7 @@ def _progress():
|
|||
|
||||
return bar_update
|
||||
|
||||
|
||||
@lock.lock_scope()
|
||||
def download_url_to_local(url, filename, root_folder, md5):
|
||||
ensure_dir(root_folder)
|
||||
file_path = os.path.join(root_folder, filename)
|
||||
|
|
|
@ -1,24 +1,63 @@
|
|||
import fcntl
|
||||
import os
|
||||
from jittor_utils import cache_path
|
||||
from jittor_utils import cache_path, LOG
|
||||
|
||||
class Lock:
|
||||
def __init__(self, filename):
|
||||
self.handle = open(filename, 'w')
|
||||
print(f'Create lock for {filename}, PID {os.getpid()}')
|
||||
LOG.v(f'OPEN LOCK path: {filename} PID: {os.getpid()}')
|
||||
self.is_locked = False
|
||||
|
||||
def lock(self):
|
||||
ret = fcntl.flock(self.handle, fcntl.LOCK_EX)
|
||||
print(f'Add lock success {ret}, PID {os.getpid()}')
|
||||
fcntl.flock(self.handle, fcntl.LOCK_EX)
|
||||
self.is_locked = True
|
||||
LOG.vv(f'LOCK PID: {os.getpid()}')
|
||||
|
||||
def unlock(self):
|
||||
ret = fcntl.flock(self.handle, fcntl.LOCK_UN)
|
||||
print(f'Release lock success {ret}, PID {os.getpid()}')
|
||||
fcntl.flock(self.handle, fcntl.LOCK_UN)
|
||||
self.is_locked = False
|
||||
LOG.vv(f'UNLOCK PID: {os.getpid()}')
|
||||
|
||||
def __del__(self):
|
||||
self.handle.close()
|
||||
|
||||
lock_path = os.path.join(cache_path, "../jittor.lock")
|
||||
|
||||
class _base_scope:
|
||||
'''base_scope for support @xxx syntax'''
|
||||
def __enter__(self): pass
|
||||
def __exit__(self, *exc): pass
|
||||
def __call__(self, func):
|
||||
def inner(*args, **kw):
|
||||
with self:
|
||||
ret = func(*args, **kw)
|
||||
return ret
|
||||
return inner
|
||||
|
||||
class lock_scope(_base_scope):
|
||||
def __enter__(self):
|
||||
self.is_locked = jittor_lock.is_locked
|
||||
if not self.is_locked:
|
||||
jittor_lock.lock()
|
||||
|
||||
def __exit__(self, *exc):
|
||||
if not self.is_locked:
|
||||
jittor_lock.unlock()
|
||||
|
||||
class unlock_scope(_base_scope):
|
||||
def __enter__(self):
|
||||
self.is_locked = jittor_lock.is_locked
|
||||
if self.is_locked:
|
||||
jittor_lock.unlock()
|
||||
|
||||
def __exit__(self, *exc):
|
||||
if self.is_locked:
|
||||
jittor_lock.lock()
|
||||
|
||||
lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock"))
|
||||
if not os.path.exists(lock_path):
|
||||
os.mknod(lock_path)
|
||||
LOG.i("Create lock file:", lock_path)
|
||||
try:
|
||||
os.mknod(lock_path)
|
||||
except:
|
||||
pass
|
||||
jittor_lock = Lock(lock_path)
|
|
@ -11,11 +11,11 @@ import os, sys
|
|||
import jittor as jt
|
||||
from pathlib import Path
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
class TestLock(unittest.TestCase):
|
||||
def test(self):
|
||||
mpi = jt.compile_extern.mpi
|
||||
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
|
||||
if os.environ.get('lock_full_test', '0') == '1':
|
||||
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
|
||||
cmd = f"rm -rf {cache_path} && cache_name=lock {mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
|
||||
|
@ -23,7 +23,7 @@ class TestLock(unittest.TestCase):
|
|||
cache_path = os.path.join(str(Path.home()), ".cache", "jittor")
|
||||
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
|
||||
print("run cmd", cmd)
|
||||
jt.compiler.run_cmd(cmd)
|
||||
assert os.system(cmd) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -158,9 +158,10 @@ def find_cache_path():
|
|||
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
||||
stderr=sp.PIPE)
|
||||
assert r.returncode == 0
|
||||
bs = r.stdout.decode()
|
||||
bs = r.stdout.decode().splitlines()
|
||||
for b in bs:
|
||||
if b.startswith("* "): break
|
||||
|
||||
cache_name = b[2:]
|
||||
for c in " (){}": cache_name = cache_name.replace(c, "_")
|
||||
except:
|
||||
|
@ -168,6 +169,7 @@ def find_cache_path():
|
|||
for name in cache_name.split("/"):
|
||||
dirs.insert(-1, name)
|
||||
os.environ["cache_name"] = cache_name
|
||||
LOG.v("cache_name", cache_name)
|
||||
for d in dirs:
|
||||
path = os.path.join(path, d)
|
||||
if not os.path.isdir(path):
|
||||
|
|
80
src/lock.cc
80
src/lock.cc
|
@ -1,66 +1,50 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// Wenyang Zhou <576825820@qq.com>.
|
||||
// Wenyang Zhou <576825820@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <errno.h>
|
||||
|
||||
#include "lock.h"
|
||||
#include "jit_compiler.h"
|
||||
#include "utils/cache_compile.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(string, cache_path);
|
||||
static int lock_fd = -1;
|
||||
|
||||
void lock_init(struct flock *lock, short type, short whence, off_t start, off_t len)
|
||||
{
|
||||
if (lock == NULL)
|
||||
return;
|
||||
|
||||
lock->l_type = type;
|
||||
lock->l_whence = whence;
|
||||
lock->l_start = start;
|
||||
lock->l_len = len;
|
||||
void set_lock_path(string path) {
|
||||
lock_fd = open(path.c_str(), O_RDWR);
|
||||
ASSERT(lock_fd >= 0);
|
||||
LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid();
|
||||
}
|
||||
|
||||
int lock()
|
||||
{
|
||||
auto lock_path = jittor::jit_compiler::join(cache_path, "../jittor.lock");
|
||||
const char* lockfilepath = lock_path.c_str();
|
||||
int fd = open(lockfilepath, O_RDWR);
|
||||
if (fd < 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
struct flock lock;
|
||||
lock_init(&lock, F_WRLCK, SEEK_SET, 0, 0);
|
||||
if (fcntl(fd, F_SETLKW, &lock) != 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
// printf("Pid: %ld process lock to write the file.\n", (long)getpid());
|
||||
return 0;
|
||||
void lock() {
|
||||
ASSERT(lock_fd >= 0);
|
||||
struct flock lock = {
|
||||
.l_type = F_WRLCK,
|
||||
.l_whence = SEEK_SET,
|
||||
.l_start = 0,
|
||||
.l_len = 0
|
||||
};
|
||||
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
|
||||
LOGvv << "LOCK Pid:" << getpid();
|
||||
}
|
||||
|
||||
int unlock()
|
||||
{
|
||||
auto lock_path = jittor::jit_compiler::join(cache_path, "../jittor.lock");
|
||||
const char* lockfilepath = lock_path.c_str();
|
||||
int fd = open(lockfilepath, O_RDWR);
|
||||
if (fd < 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
struct flock lock;
|
||||
lock_init(&lock, F_UNLCK, SEEK_SET, 0, 0);
|
||||
if (fcntl(fd, F_SETLKW, &lock) != 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
// printf("Pid: %ld process release the file.\n", (long)getpid());
|
||||
return 0;
|
||||
void unlock() {
|
||||
ASSERT(lock_fd >= 0);
|
||||
struct flock lock = {
|
||||
.l_type = F_UNLCK,
|
||||
.l_whence = SEEK_SET,
|
||||
.l_start = 0,
|
||||
.l_len = 0
|
||||
};
|
||||
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
|
||||
LOGvv << "UNLOCK Pid:" << getpid();
|
||||
}
|
||||
|
||||
} // jittor
|
22
src/lock.h
22
src/lock.h
|
@ -1,20 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// Wenyang Zhou <576825820@qq.com>.
|
||||
// Wenyang Zhou <576825820@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <errno.h>
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
int lock();
|
||||
|
||||
int unlock();
|
||||
// @pyjt(set_lock_path)
|
||||
void set_lock_path(string path);
|
||||
|
||||
void lock();
|
||||
|
||||
void unlock();
|
||||
|
||||
struct lock_guard {
|
||||
inline lock_guard() { lock(); }
|
||||
inline ~lock_guard() { unlock(); }
|
||||
};
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -673,7 +673,7 @@ string OpCompiler::get_jit_src(Op* op) {
|
|||
else
|
||||
after_include_src += src;
|
||||
}
|
||||
ASSERT(file_exist(src_path));
|
||||
ASSERT(file_exist(src_path)) << src_path;
|
||||
LOGvvv << "Read from" << src_path;
|
||||
string src = read_all(src_path);
|
||||
ASSERT(src.size()) << "Source read failed:" << src_path;
|
||||
|
@ -946,7 +946,7 @@ jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
|
|||
}
|
||||
|
||||
jit_op_entry_t OpCompiler::do_compile(Op* op) {
|
||||
lock();
|
||||
jittor::lock_guard lg;
|
||||
OpCompiler oc(op);
|
||||
string* src = &oc.src;
|
||||
string src_after_passes;
|
||||
|
@ -957,7 +957,6 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
|
|||
src = &src_after_passes;
|
||||
}
|
||||
auto ret = oc.compile(op->get_jit_key(), *src);
|
||||
unlock();
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue