add lock_scope and unlock_scope

This commit is contained in:
Dun Liang 2020-04-10 16:10:01 +08:00
parent ac4198d372
commit 9178b3459a
11 changed files with 123 additions and 103 deletions

View File

@ -7,14 +7,16 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
from . import compiler from . import lock
from .compiler import LOG, has_cuda with lock.lock_scope():
from .compiler import compile_custom_ops, compile_custom_op from . import compiler
import jittor_core as core from .compiler import LOG, has_cuda
from jittor_core import * from .compiler import compile_custom_ops, compile_custom_op
from jittor_core.ops import * import jittor_core as core
from . import compile_extern from jittor_core import *
from .compile_extern import mkl_ops from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops
import contextlib import contextlib
import numpy as np import numpy as np

View File

@ -7,7 +7,6 @@ import os, sys, shutil
from .compiler import * from .compiler import *
from jittor_utils import run_cmd, get_version from jittor_utils import run_cmd, get_version
from jittor.dataset.utils import download_url_to_local from jittor.dataset.utils import download_url_to_local
from jittor.lock import jittor_lock
def search_file(dirs, name): def search_file(dirs, name):
for d in dirs: for d in dirs:
@ -374,8 +373,6 @@ def setup_mpi():
LOG.vv("Get mpi: "+str(mpi.__dict__.keys())) LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys())) LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
jittor_lock.lock()
setup_mpi() setup_mpi()
setup_nccl() setup_nccl()
@ -383,5 +380,3 @@ setup_cutt()
setup_mkl() setup_mkl()
setup_cuda_extern() setup_cuda_extern()
jittor_lock.unlock()

View File

@ -17,7 +17,7 @@ from ctypes.util import find_library
import jittor_utils as jit_utils import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler from . import pyjt_compiler
from jittor.lock import jittor_lock from . import lock
def find_jittor_path(): def find_jittor_path():
return os.path.dirname(__file__) return os.path.dirname(__file__)
@ -555,6 +555,7 @@ def compile_custom_op(header, source, op_name, warp=True):
m = compile_custom_ops([hname, ccname]) m = compile_custom_ops([hname, ccname])
return getattr(m, op_name) return getattr(m, op_name)
@lock.lock_scope()
def compile_custom_ops( def compile_custom_ops(
filenames, filenames,
extra_flags="", extra_flags="",
@ -644,10 +645,10 @@ def compile_custom_ops(
lib_path = os.path.join(cache_path, "custom_ops") lib_path = os.path.join(cache_path, "custom_ops")
if lib_path not in os.sys.path: if lib_path not in os.sys.path:
os.sys.path.append(lib_path) os.sys.path.append(lib_path)
jittor_lock.unlock() # unlock scope when initialize
with jit_utils.import_scope(dlopen_flags): with lock.unlock_scope():
exec(f"import {gen_name}") with jit_utils.import_scope(dlopen_flags):
jittor_lock.lock() exec(f"import {gen_name}")
mod = locals()[gen_name] mod = locals()[gen_name]
if return_module: if return_module:
return mod return mod
@ -801,20 +802,16 @@ import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
# import_flags = os.RTLD_NOW | os.RTLD_GLOBAL # import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
jittor_lock.lock()
with jit_utils.import_scope(import_flags): with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core() jit_utils.try_import_jit_utils_core()
jittor_lock.unlock()
jittor_path = find_jittor_path() jittor_path = find_jittor_path()
check_debug_flags() check_debug_flags()
sys.path.append(cache_path) sys.path.append(cache_path)
jittor_lock.lock()
with jit_utils.import_scope(import_flags): with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core() jit_utils.try_import_jit_utils_core()
jittor_lock.unlock()
python_path = sys.executable python_path = sys.executable
py3_config_path = sys.executable+"-config" py3_config_path = sys.executable+"-config"
@ -856,13 +853,11 @@ make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files")) make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen")) make_cache_dir(os.path.join(cache_path, "gen"))
jittor_lock.lock()
# build cache_compile # build cache_compile
cc_flags += pybind_include cc_flags += pybind_include
cc_flags += f" -I{jittor_path}/src " cc_flags += f" -I{jittor_path}/src "
check_cache_compile() check_cache_compile()
LOG.v(f"Get cache_compile: {jit_utils.cc}") LOG.v(f"Get cache_compile: {jit_utils.cc}")
jittor_lock.unlock()
# check cuda # check cuda
has_cuda = 0 has_cuda = 0
@ -887,7 +882,6 @@ if has_cuda:
return nvcc_flags return nvcc_flags
nvcc_flags = convert_nvcc_flags(nvcc_flags) nvcc_flags = convert_nvcc_flags(nvcc_flags)
jittor_lock.lock()
# build core # build core
gen_jit_flags() gen_jit_flags()
gen_jit_tests() gen_jit_tests()
@ -977,4 +971,5 @@ flags.jittor_path = jittor_path
flags.gdb_path = gdb_path flags.gdb_path = gdb_path
flags.addr2line_path = addr2line_path flags.addr2line_path = addr2line_path
flags.has_pybt = has_pybt flags.has_pybt = has_pybt
jittor_lock.unlock()
core.set_lock_path(lock.lock_path)

View File

@ -15,7 +15,6 @@ from jittor.dataset.dataset import Dataset, dataset_root
from jittor.dataset.utils import ensure_dir, download_url_to_local from jittor.dataset.utils import ensure_dir, download_url_to_local
import jittor as jt import jittor as jt
import jittor.transform as trans import jittor.transform as trans
from jittor.lock import jittor_lock
class MNIST(Dataset): class MNIST(Dataset):
def __init__(self, data_root=dataset_root+"/mnist_data/", train=True ,download=True, transform=None): def __init__(self, data_root=dataset_root+"/mnist_data/", train=True ,download=True, transform=None):
@ -65,6 +64,4 @@ class MNIST(Dataset):
for url, md5 in resources: for url, md5 in resources:
filename = url.rpartition('/')[2] filename = url.rpartition('/')[2]
jittor_lock.lock()
download_url_to_local(url, filename, self.data_root, md5) download_url_to_local(url, filename, self.data_root, md5)
jittor_lock.unlock()

View File

@ -15,6 +15,7 @@ from tqdm import tqdm
import numpy as np import numpy as np
from collections.abc import Sequence, Mapping from collections.abc import Sequence, Mapping
from PIL import Image from PIL import Image
from .. import lock
def ensure_dir(dir_path): def ensure_dir(dir_path):
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
@ -36,7 +37,7 @@ def _progress():
return bar_update return bar_update
@lock.lock_scope()
def download_url_to_local(url, filename, root_folder, md5): def download_url_to_local(url, filename, root_folder, md5):
ensure_dir(root_folder) ensure_dir(root_folder)
file_path = os.path.join(root_folder, filename) file_path = os.path.join(root_folder, filename)

View File

@ -1,24 +1,63 @@
import fcntl import fcntl
import os import os
from jittor_utils import cache_path from jittor_utils import cache_path, LOG
class Lock: class Lock:
def __init__(self, filename): def __init__(self, filename):
self.handle = open(filename, 'w') self.handle = open(filename, 'w')
print(f'Create lock for {filename}, PID {os.getpid()}') LOG.v(f'OPEN LOCK path: {filename} PID: {os.getpid()}')
self.is_locked = False
def lock(self): def lock(self):
ret = fcntl.flock(self.handle, fcntl.LOCK_EX) fcntl.flock(self.handle, fcntl.LOCK_EX)
print(f'Add lock success {ret}, PID {os.getpid()}') self.is_locked = True
LOG.vv(f'LOCK PID: {os.getpid()}')
def unlock(self): def unlock(self):
ret = fcntl.flock(self.handle, fcntl.LOCK_UN) fcntl.flock(self.handle, fcntl.LOCK_UN)
print(f'Release lock success {ret}, PID {os.getpid()}') self.is_locked = False
LOG.vv(f'UNLOCK PID: {os.getpid()}')
def __del__(self): def __del__(self):
self.handle.close() self.handle.close()
lock_path = os.path.join(cache_path, "../jittor.lock")
class _base_scope:
'''base_scope for support @xxx syntax'''
def __enter__(self): pass
def __exit__(self, *exc): pass
def __call__(self, func):
def inner(*args, **kw):
with self:
ret = func(*args, **kw)
return ret
return inner
class lock_scope(_base_scope):
def __enter__(self):
self.is_locked = jittor_lock.is_locked
if not self.is_locked:
jittor_lock.lock()
def __exit__(self, *exc):
if not self.is_locked:
jittor_lock.unlock()
class unlock_scope(_base_scope):
def __enter__(self):
self.is_locked = jittor_lock.is_locked
if self.is_locked:
jittor_lock.unlock()
def __exit__(self, *exc):
if self.is_locked:
jittor_lock.lock()
lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock"))
if not os.path.exists(lock_path): if not os.path.exists(lock_path):
os.mknod(lock_path) LOG.i("Create lock file:", lock_path)
jittor_lock = Lock(lock_path) try:
os.mknod(lock_path)
except:
pass
jittor_lock = Lock(lock_path)

View File

@ -11,11 +11,11 @@ import os, sys
import jittor as jt import jittor as jt
from pathlib import Path from pathlib import Path
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found") @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestLock(unittest.TestCase): class TestLock(unittest.TestCase):
def test(self): def test(self):
mpi = jt.compile_extern.mpi mpi = jt.compile_extern.mpi
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun') mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
if os.environ.get('lock_full_test', '0') == '1': if os.environ.get('lock_full_test', '0') == '1':
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock") cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
cmd = f"rm -rf {cache_path} && cache_name=lock {mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example" cmd = f"rm -rf {cache_path} && cache_name=lock {mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
@ -23,7 +23,7 @@ class TestLock(unittest.TestCase):
cache_path = os.path.join(str(Path.home()), ".cache", "jittor") cache_path = os.path.join(str(Path.home()), ".cache", "jittor")
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example" cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example"
print("run cmd", cmd) print("run cmd", cmd)
jt.compiler.run_cmd(cmd) assert os.system(cmd) == 0
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -158,9 +158,10 @@ def find_cache_path():
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE, r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
stderr=sp.PIPE) stderr=sp.PIPE)
assert r.returncode == 0 assert r.returncode == 0
bs = r.stdout.decode() bs = r.stdout.decode().splitlines()
for b in bs: for b in bs:
if b.startswith("* "): break if b.startswith("* "): break
cache_name = b[2:] cache_name = b[2:]
for c in " (){}": cache_name = cache_name.replace(c, "_") for c in " (){}": cache_name = cache_name.replace(c, "_")
except: except:
@ -168,6 +169,7 @@ def find_cache_path():
for name in cache_name.split("/"): for name in cache_name.split("/"):
dirs.insert(-1, name) dirs.insert(-1, name)
os.environ["cache_name"] = cache_name os.environ["cache_name"] = cache_name
LOG.v("cache_name", cache_name)
for d in dirs: for d in dirs:
path = os.path.join(path, d) path = os.path.join(path, d)
if not os.path.isdir(path): if not os.path.isdir(path):

View File

@ -1,66 +1,50 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2020 Jittor. Authors: // Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>. // Wenyang Zhou <576825820@qq.com>
// Wenyang Zhou <576825820@qq.com>. // Dun Liang <randonlang@gmail.com>
// All Rights Reserved. // All Rights Reserved.
// This file is subject to the terms and conditions defined in // This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include "lock.h" #include "lock.h"
#include "jit_compiler.h"
#include "utils/cache_compile.h"
namespace jittor { namespace jittor {
DECLARE_FLAG(string, cache_path); static int lock_fd = -1;
void lock_init(struct flock *lock, short type, short whence, off_t start, off_t len) void set_lock_path(string path) {
{ lock_fd = open(path.c_str(), O_RDWR);
if (lock == NULL) ASSERT(lock_fd >= 0);
return; LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid();
lock->l_type = type;
lock->l_whence = whence;
lock->l_start = start;
lock->l_len = len;
} }
int lock() void lock() {
{ ASSERT(lock_fd >= 0);
auto lock_path = jittor::jit_compiler::join(cache_path, "../jittor.lock"); struct flock lock = {
const char* lockfilepath = lock_path.c_str(); .l_type = F_WRLCK,
int fd = open(lockfilepath, O_RDWR); .l_whence = SEEK_SET,
if (fd < 0) .l_start = 0,
{ .l_len = 0
return -1; };
} ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
struct flock lock; LOGvv << "LOCK Pid:" << getpid();
lock_init(&lock, F_WRLCK, SEEK_SET, 0, 0);
if (fcntl(fd, F_SETLKW, &lock) != 0)
{
return -1;
}
// printf("Pid: %ld process lock to write the file.\n", (long)getpid());
return 0;
} }
int unlock() void unlock() {
{ ASSERT(lock_fd >= 0);
auto lock_path = jittor::jit_compiler::join(cache_path, "../jittor.lock"); struct flock lock = {
const char* lockfilepath = lock_path.c_str(); .l_type = F_UNLCK,
int fd = open(lockfilepath, O_RDWR); .l_whence = SEEK_SET,
if (fd < 0) .l_start = 0,
{ .l_len = 0
return -1; };
} ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
struct flock lock; LOGvv << "UNLOCK Pid:" << getpid();
lock_init(&lock, F_UNLCK, SEEK_SET, 0, 0);
if (fcntl(fd, F_SETLKW, &lock) != 0)
{
return -1;
}
// printf("Pid: %ld process release the file.\n", (long)getpid());
return 0;
} }
} // jittor } // jittor

View File

@ -1,20 +1,26 @@
// *************************************************************** // ***************************************************************
// Copyright (c) 2020 Jittor. Authors: // Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>. // Wenyang Zhou <576825820@qq.com>
// Wenyang Zhou <576825820@qq.com>. // Dun Liang <randonlang@gmail.com>
// All Rights Reserved. // All Rights Reserved.
// This file is subject to the terms and conditions defined in // This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package. // file 'LICENSE.txt', which is part of this source code package.
// *************************************************************** // ***************************************************************
#pragma once #pragma once
#include <stdio.h> #include "common.h"
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
namespace jittor { namespace jittor {
int lock();
int unlock(); // @pyjt(set_lock_path)
void set_lock_path(string path);
void lock();
void unlock();
struct lock_guard {
inline lock_guard() { lock(); }
inline ~lock_guard() { unlock(); }
};
} // jittor } // jittor

View File

@ -673,7 +673,7 @@ string OpCompiler::get_jit_src(Op* op) {
else else
after_include_src += src; after_include_src += src;
} }
ASSERT(file_exist(src_path)); ASSERT(file_exist(src_path)) << src_path;
LOGvvv << "Read from" << src_path; LOGvvv << "Read from" << src_path;
string src = read_all(src_path); string src = read_all(src_path);
ASSERT(src.size()) << "Source read failed:" << src_path; ASSERT(src.size()) << "Source read failed:" << src_path;
@ -946,7 +946,7 @@ jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
} }
jit_op_entry_t OpCompiler::do_compile(Op* op) { jit_op_entry_t OpCompiler::do_compile(Op* op) {
lock(); jittor::lock_guard lg;
OpCompiler oc(op); OpCompiler oc(op);
string* src = &oc.src; string* src = &oc.src;
string src_after_passes; string src_after_passes;
@ -957,7 +957,6 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
src = &src_after_passes; src = &src_after_passes;
} }
auto ret = oc.compile(op->get_jit_key(), *src); auto ret = oc.compile(op->get_jit_key(), *src);
unlock();
return ret; return ret;
} }