mirror of https://github.com/Jittor/Jittor
Merge branch 'new_lock' of https://github.com/Jittor/jittor
This commit is contained in:
commit
d2ae3c05ff
11
README.cn.md
11
README.cn.md
|
@ -16,6 +16,8 @@ Jittor前端语言为Python。前端使用了模块化的设计,这是目前
|
|||
import jittor as jt
|
||||
from jittor import Module
|
||||
from jittor import nn
|
||||
import numpy as np
|
||||
|
||||
class Model(Module):
|
||||
def __init__(self):
|
||||
self.layer1 = nn.Linear(1, 10)
|
||||
|
@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
|
|||
y = x*x
|
||||
yield jt.float32(x), jt.float32(y)
|
||||
|
||||
model = Model()
|
||||
|
||||
learning_rate = 0.1
|
||||
batch_size = 50
|
||||
n = 1000
|
||||
|
||||
model = Model()
|
||||
optim = nn.SGD(model.parameters(), learning_rate)
|
||||
|
||||
for i,(x,y) in enumerate(get_data(n)):
|
||||
pred_y = model(x)
|
||||
loss = ((pred_y - y)**2)
|
||||
dy = pred_y - y
|
||||
loss = dy * dy
|
||||
loss_mean = loss.mean()
|
||||
optim.step(loss_mean)
|
||||
print(f"step {i}, loss = {loss_mean.data.sum()}")
|
||||
|
|
11
README.md
11
README.md
|
@ -16,6 +16,8 @@ The following example shows how to model a two-layer neural network step by step
|
|||
import jittor as jt
|
||||
from jittor import Module
|
||||
from jittor import nn
|
||||
import numpy as np
|
||||
|
||||
class Model(Module):
|
||||
def __init__(self):
|
||||
self.layer1 = nn.Linear(1, 10)
|
||||
|
@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
|
|||
y = x*x
|
||||
yield jt.float32(x), jt.float32(y)
|
||||
|
||||
model = Model()
|
||||
|
||||
learning_rate = 0.1
|
||||
batch_size = 50
|
||||
n = 1000
|
||||
|
||||
model = Model()
|
||||
optim = nn.SGD(model.parameters(), learning_rate)
|
||||
|
||||
for i,(x,y) in enumerate(get_data(n)):
|
||||
pred_y = model(x)
|
||||
loss = ((pred_y - y)**2)
|
||||
dy = pred_y - y
|
||||
loss = dy * dy
|
||||
loss_mean = loss.mean()
|
||||
optim.step(loss_mean)
|
||||
print(f"step {i}, loss = {loss_mean.data.sum()}")
|
||||
|
|
|
@ -21,6 +21,8 @@ The following example shows how to model a two-layer neural network step by step
|
|||
import jittor as jt
|
||||
from jittor import Module
|
||||
from jittor import nn
|
||||
import numpy as np
|
||||
|
||||
class Model(Module):
|
||||
def __init__(self):
|
||||
self.layer1 = nn.Linear(1, 10)
|
||||
|
@ -38,13 +40,18 @@ def get_data(n): # generate random data for training test.
|
|||
y = x*x
|
||||
yield jt.float32(x), jt.float32(y)
|
||||
|
||||
model = Model()
|
||||
|
||||
learning_rate = 0.1
|
||||
batch_size = 50
|
||||
n = 1000
|
||||
|
||||
model = Model()
|
||||
optim = nn.SGD(model.parameters(), learning_rate)
|
||||
|
||||
for i,(x,y) in enumerate(get_data(n)):
|
||||
pred_y = model(x)
|
||||
loss = ((pred_y - y)**2)
|
||||
dy = pred_y - y
|
||||
loss = dy * dy
|
||||
loss_mean = loss.mean()
|
||||
optim.step(loss_mean)
|
||||
print(f"step {i}, loss = {loss_mean.data.sum()}")
|
||||
|
|
|
@ -7,14 +7,16 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
from . import compiler
|
||||
from .compiler import LOG, has_cuda
|
||||
from .compiler import compile_custom_ops, compile_custom_op
|
||||
import jittor_core as core
|
||||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
from .compiler import LOG, has_cuda
|
||||
from .compiler import compile_custom_ops, compile_custom_op
|
||||
import jittor_core as core
|
||||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops
|
||||
|
||||
import contextlib
|
||||
import numpy as np
|
||||
|
|
|
@ -17,6 +17,7 @@ from ctypes.util import find_library
|
|||
import jittor_utils as jit_utils
|
||||
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
|
||||
from . import pyjt_compiler
|
||||
from . import lock
|
||||
|
||||
def find_jittor_path():
|
||||
return os.path.dirname(__file__)
|
||||
|
@ -518,6 +519,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
|||
"""
|
||||
return jit_src
|
||||
|
||||
@lock.lock_scope()
|
||||
def compile_custom_op(header, source, op_name, warp=True):
|
||||
"""Compile a single custom op
|
||||
header: code of op header, not path
|
||||
|
@ -554,6 +556,7 @@ def compile_custom_op(header, source, op_name, warp=True):
|
|||
m = compile_custom_ops([hname, ccname])
|
||||
return getattr(m, op_name)
|
||||
|
||||
@lock.lock_scope()
|
||||
def compile_custom_ops(
|
||||
filenames,
|
||||
extra_flags="",
|
||||
|
@ -643,8 +646,10 @@ def compile_custom_ops(
|
|||
lib_path = os.path.join(cache_path, "custom_ops")
|
||||
if lib_path not in os.sys.path:
|
||||
os.sys.path.append(lib_path)
|
||||
with jit_utils.import_scope(dlopen_flags):
|
||||
exec(f"import {gen_name}")
|
||||
# unlock scope when initialize
|
||||
with lock.unlock_scope():
|
||||
with jit_utils.import_scope(dlopen_flags):
|
||||
exec(f"import {gen_name}")
|
||||
mod = locals()[gen_name]
|
||||
if return_module:
|
||||
return mod
|
||||
|
@ -800,6 +805,7 @@ dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
|||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
|
||||
jittor_path = find_jittor_path()
|
||||
check_debug_flags()
|
||||
|
||||
|
@ -966,3 +972,5 @@ flags.jittor_path = jittor_path
|
|||
flags.gdb_path = gdb_path
|
||||
flags.addr2line_path = addr2line_path
|
||||
flags.has_pybt = has_pybt
|
||||
|
||||
core.set_lock_path(lock.lock_path)
|
||||
|
|
|
@ -15,6 +15,7 @@ from tqdm import tqdm
|
|||
import numpy as np
|
||||
from collections.abc import Sequence, Mapping
|
||||
from PIL import Image
|
||||
from .. import lock
|
||||
|
||||
def ensure_dir(dir_path):
|
||||
if not os.path.isdir(dir_path):
|
||||
|
@ -36,7 +37,7 @@ def _progress():
|
|||
|
||||
return bar_update
|
||||
|
||||
|
||||
@lock.lock_scope()
|
||||
def download_url_to_local(url, filename, root_folder, md5):
|
||||
ensure_dir(root_folder)
|
||||
file_path = os.path.join(root_folder, filename)
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
import fcntl
|
||||
import os
|
||||
from jittor_utils import cache_path, LOG
|
||||
|
||||
class Lock:
|
||||
def __init__(self, filename):
|
||||
self.handle = open(filename, 'w')
|
||||
LOG.v(f'OPEN LOCK path: {filename} PID: {os.getpid()}')
|
||||
self.is_locked = False
|
||||
|
||||
def lock(self):
|
||||
fcntl.flock(self.handle, fcntl.LOCK_EX)
|
||||
self.is_locked = True
|
||||
LOG.vv(f'LOCK PID: {os.getpid()}')
|
||||
|
||||
def unlock(self):
|
||||
fcntl.flock(self.handle, fcntl.LOCK_UN)
|
||||
self.is_locked = False
|
||||
LOG.vv(f'UNLOCK PID: {os.getpid()}')
|
||||
|
||||
def __del__(self):
|
||||
self.handle.close()
|
||||
|
||||
|
||||
class _base_scope:
|
||||
'''base_scope for support @xxx syntax'''
|
||||
def __enter__(self): pass
|
||||
def __exit__(self, *exc): pass
|
||||
def __call__(self, func):
|
||||
def inner(*args, **kw):
|
||||
with self:
|
||||
ret = func(*args, **kw)
|
||||
return ret
|
||||
return inner
|
||||
|
||||
class lock_scope(_base_scope):
|
||||
def __enter__(self):
|
||||
self.is_locked = jittor_lock.is_locked
|
||||
if not self.is_locked:
|
||||
jittor_lock.lock()
|
||||
|
||||
def __exit__(self, *exc):
|
||||
if not self.is_locked:
|
||||
jittor_lock.unlock()
|
||||
|
||||
class unlock_scope(_base_scope):
|
||||
def __enter__(self):
|
||||
self.is_locked = jittor_lock.is_locked
|
||||
if self.is_locked:
|
||||
jittor_lock.unlock()
|
||||
|
||||
def __exit__(self, *exc):
|
||||
if self.is_locked:
|
||||
jittor_lock.lock()
|
||||
|
||||
lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock"))
|
||||
if not os.path.exists(lock_path):
|
||||
LOG.i("Create lock file:", lock_path)
|
||||
try:
|
||||
os.mknod(lock_path)
|
||||
except:
|
||||
pass
|
||||
jittor_lock = Lock(lock_path)
|
|
@ -5,23 +5,33 @@
|
|||
# ***************************************************************
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
import unittest, os
|
||||
|
||||
suffix = "__main__.py"
|
||||
assert __file__.endswith(suffix)
|
||||
test_dir = __file__[:-len(suffix)]
|
||||
import os
|
||||
|
||||
skip_l = int(os.environ.get("test_skip_l", "0"))
|
||||
skip_r = int(os.environ.get("test_skip_r", "1000000"))
|
||||
test_only = None
|
||||
if "test_only" in os.environ:
|
||||
test_only = set(os.environ.get("test_only").split(","))
|
||||
|
||||
test_files = os.listdir(test_dir)
|
||||
for test_file in test_files:
|
||||
test_files = sorted(test_files)
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
for _, test_file in enumerate(test_files):
|
||||
if not test_file.startswith("test_"):
|
||||
continue
|
||||
if _ < skip_l or _ > skip_r:
|
||||
continue
|
||||
test_name = test_file.split(".")[0]
|
||||
exec(f"from . import {test_name}")
|
||||
test_mod = globals()[test_name]
|
||||
print(test_name)
|
||||
for i in dir(test_mod):
|
||||
obj = getattr(test_mod, i)
|
||||
if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
|
||||
globals()[test_name+"_"+i] = obj
|
||||
if test_only and test_name not in test_only:
|
||||
continue
|
||||
|
||||
unittest.main()
|
||||
print("Add Test", _, test_name)
|
||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromName(
|
||||
"jittor.test."+test_name))
|
||||
|
||||
unittest.TextTestRunner(verbosity=3).run(suite)
|
|
@ -18,6 +18,7 @@ import pickle as pk
|
|||
skip_this_test = False
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
from torch.nn import MaxPool2d, Sequential
|
||||
except:
|
||||
|
|
|
@ -11,6 +11,7 @@ import numpy as np
|
|||
|
||||
class TestClone(unittest.TestCase):
|
||||
def test(self):
|
||||
jt.clean()
|
||||
b = a = jt.array(1)
|
||||
for i in range(10):
|
||||
b = b.clone()
|
||||
|
|
|
@ -18,11 +18,12 @@ def test_cuda(use_cuda=1):
|
|||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
class TestCuda(unittest.TestCase):
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda_flags(self):
|
||||
with jt.var_scope(use_cuda=1):
|
||||
a = jt.random((10, 10))
|
||||
a.sync()
|
||||
a = jt.random((10, 10))
|
||||
a.sync()
|
||||
|
||||
@jt.flag_scope(use_cuda=2)
|
||||
def test_no_cuda_op(self):
|
||||
no_cuda_op = jt.compile_custom_op("""
|
||||
struct NoCudaOp : Op {
|
||||
|
@ -49,10 +50,10 @@ class TestCuda(unittest.TestCase):
|
|||
""",
|
||||
"no_cuda")
|
||||
# force use cuda
|
||||
with jt.var_scope(use_cuda=2):
|
||||
a = no_cuda_op([3,4,5], 'float')
|
||||
expect_error(lambda: a())
|
||||
a = no_cuda_op([3,4,5], 'float')
|
||||
expect_error(lambda: a())
|
||||
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda_custom_op(self):
|
||||
my_op = jt.compile_custom_op("""
|
||||
struct MyCudaOp : Op {
|
||||
|
@ -94,9 +95,8 @@ class TestCuda(unittest.TestCase):
|
|||
#endif // JIT
|
||||
""",
|
||||
"my_cuda")
|
||||
with jt.var_scope(use_cuda=1):
|
||||
a = my_op([3,4,5], 'float')
|
||||
na = a.data
|
||||
a = my_op([3,4,5], 'float')
|
||||
na = a.data
|
||||
assert a.shape == [3,4,5] and a.dtype == 'float'
|
||||
assert (-na.flatten() == range(3*4*5)).all(), na
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ class TestCutt(unittest.TestCase):
|
|||
@jt.flag_scope(use_cuda=1)
|
||||
def test(self):
|
||||
t = cutt_ops.cutt_test("213")
|
||||
jt.sync_all(True)
|
||||
print(t.data)
|
||||
assert t.data == 123
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,27 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os, sys
|
||||
import jittor as jt
|
||||
from pathlib import Path
|
||||
|
||||
class TestLock(unittest.TestCase):
|
||||
def test(self):
|
||||
if os.environ.get('lock_full_test', '0') == '1':
|
||||
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
|
||||
assert os.system(f"rm -rf {cache_path}") == 0
|
||||
cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example"
|
||||
else:
|
||||
cmd = f"{sys.executable} -m jittor.test.test_example"
|
||||
print("run cmd twice", cmd)
|
||||
assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -11,6 +11,7 @@ import numpy as np
|
|||
class TestMiscIssue(unittest.TestCase):
|
||||
def test_issue4(self):
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
return
|
||||
|
@ -42,6 +43,7 @@ b.sync()
|
|||
|
||||
def test_mkl_conflict1(self):
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
return
|
||||
|
@ -67,6 +69,7 @@ m(torch.rand(*nchw))
|
|||
|
||||
def test_mkl_conflict2(self):
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
return
|
||||
|
|
|
@ -158,9 +158,10 @@ def find_cache_path():
|
|||
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
||||
stderr=sp.PIPE)
|
||||
assert r.returncode == 0
|
||||
bs = r.stdout.decode()
|
||||
bs = r.stdout.decode().splitlines()
|
||||
for b in bs:
|
||||
if b.startswith("* "): break
|
||||
|
||||
cache_name = b[2:]
|
||||
for c in " (){}": cache_name = cache_name.replace(c, "_")
|
||||
except:
|
||||
|
@ -168,10 +169,14 @@ def find_cache_path():
|
|||
for name in cache_name.split("/"):
|
||||
dirs.insert(-1, name)
|
||||
os.environ["cache_name"] = cache_name
|
||||
LOG.v("cache_name", cache_name)
|
||||
for d in dirs:
|
||||
path = os.path.join(path, d)
|
||||
if not os.path.isdir(path):
|
||||
os.mkdir(path)
|
||||
try:
|
||||
os.mkdir(path)
|
||||
except:
|
||||
pass
|
||||
assert os.path.isdir(path)
|
||||
if path not in sys.path:
|
||||
sys.path.append(path)
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Wenyang Zhou <576825820@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <errno.h>
|
||||
|
||||
#include "lock.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
static int lock_fd = -1;
|
||||
|
||||
void set_lock_path(string path) {
|
||||
lock_fd = open(path.c_str(), O_RDWR);
|
||||
ASSERT(lock_fd >= 0);
|
||||
LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid();
|
||||
}
|
||||
|
||||
void lock() {
|
||||
ASSERT(lock_fd >= 0);
|
||||
struct flock lock = {
|
||||
.l_type = F_WRLCK,
|
||||
.l_whence = SEEK_SET,
|
||||
.l_start = 0,
|
||||
.l_len = 0
|
||||
};
|
||||
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
|
||||
LOGvv << "LOCK Pid:" << getpid();
|
||||
}
|
||||
|
||||
void unlock() {
|
||||
ASSERT(lock_fd >= 0);
|
||||
struct flock lock = {
|
||||
.l_type = F_UNLCK,
|
||||
.l_whence = SEEK_SET,
|
||||
.l_start = 0,
|
||||
.l_len = 0
|
||||
};
|
||||
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
|
||||
LOGvv << "UNLOCK Pid:" << getpid();
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,26 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Wenyang Zhou <576825820@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// @pyjt(set_lock_path)
|
||||
void set_lock_path(string path);
|
||||
|
||||
void lock();
|
||||
|
||||
void unlock();
|
||||
|
||||
struct lock_guard {
|
||||
inline lock_guard() { lock(); }
|
||||
inline ~lock_guard() { unlock(); }
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -14,6 +14,7 @@
|
|||
#include "misc/str_utils.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "lock.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -672,7 +673,7 @@ string OpCompiler::get_jit_src(Op* op) {
|
|||
else
|
||||
after_include_src += src;
|
||||
}
|
||||
ASSERT(file_exist(src_path));
|
||||
ASSERT(file_exist(src_path)) << src_path;
|
||||
LOGvvv << "Read from" << src_path;
|
||||
string src = read_all(src_path);
|
||||
ASSERT(src.size()) << "Source read failed:" << src_path;
|
||||
|
@ -945,6 +946,7 @@ jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
|
|||
}
|
||||
|
||||
jit_op_entry_t OpCompiler::do_compile(Op* op) {
|
||||
jittor::lock_guard lg;
|
||||
OpCompiler oc(op);
|
||||
string* src = &oc.src;
|
||||
string src_after_passes;
|
||||
|
@ -954,8 +956,8 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
|
|||
src_after_passes = tm.tune();
|
||||
src = &src_after_passes;
|
||||
}
|
||||
return oc.compile(op->get_jit_key(), *src);
|
||||
auto ret = oc.compile(op->get_jit_key(), *src);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue