This commit is contained in:
Dun Liang 2020-04-11 18:08:16 +08:00
commit d2ae3c05ff
18 changed files with 263 additions and 44 deletions

View File

@ -16,6 +16,8 @@ Jittor前端语言为Python。前端使用了模块化的设计这是目前
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np
class Model(Module):
def __init__(self):
self.layer1 = nn.Linear(1, 10)
@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model()
learning_rate = 0.1
batch_size = 50
n = 1000
model = Model()
optim = nn.SGD(model.parameters(), learning_rate)
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x)
loss = ((pred_y - y)**2)
dy = pred_y - y
loss = dy * dy
loss_mean = loss.mean()
optim.step(loss_mean)
print(f"step {i}, loss = {loss_mean.data.sum()}")

View File

@ -16,6 +16,8 @@ The following example shows how to model a two-layer neural network step by step
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np
class Model(Module):
def __init__(self):
self.layer1 = nn.Linear(1, 10)
@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model()
learning_rate = 0.1
batch_size = 50
n = 1000
model = Model()
optim = nn.SGD(model.parameters(), learning_rate)
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x)
loss = ((pred_y - y)**2)
dy = pred_y - y
loss = dy * dy
loss_mean = loss.mean()
optim.step(loss_mean)
print(f"step {i}, loss = {loss_mean.data.sum()}")

View File

@ -21,6 +21,8 @@ The following example shows how to model a two-layer neural network step by step
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np
class Model(Module):
def __init__(self):
self.layer1 = nn.Linear(1, 10)
@ -38,13 +40,18 @@ def get_data(n): # generate random data for training test.
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model()
learning_rate = 0.1
batch_size = 50
n = 1000
model = Model()
optim = nn.SGD(model.parameters(), learning_rate)
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x)
loss = ((pred_y - y)**2)
dy = pred_y - y
loss = dy * dy
loss_mean = loss.mean()
optim.step(loss_mean)
print(f"step {i}, loss = {loss_mean.data.sum()}")

View File

@ -7,14 +7,16 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from . import compiler
from .compiler import LOG, has_cuda
from .compiler import compile_custom_ops, compile_custom_op
import jittor_core as core
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops
from . import lock
with lock.lock_scope():
from . import compiler
from .compiler import LOG, has_cuda
from .compiler import compile_custom_ops, compile_custom_op
import jittor_core as core
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops
import contextlib
import numpy as np

View File

@ -17,6 +17,7 @@ from ctypes.util import find_library
import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler
from . import lock
def find_jittor_path():
return os.path.dirname(__file__)
@ -518,6 +519,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
"""
return jit_src
@lock.lock_scope()
def compile_custom_op(header, source, op_name, warp=True):
"""Compile a single custom op
header: code of op header, not path
@ -554,6 +556,7 @@ def compile_custom_op(header, source, op_name, warp=True):
m = compile_custom_ops([hname, ccname])
return getattr(m, op_name)
@lock.lock_scope()
def compile_custom_ops(
filenames,
extra_flags="",
@ -643,8 +646,10 @@ def compile_custom_ops(
lib_path = os.path.join(cache_path, "custom_ops")
if lib_path not in os.sys.path:
os.sys.path.append(lib_path)
with jit_utils.import_scope(dlopen_flags):
exec(f"import {gen_name}")
# unlock scope when initialize
with lock.unlock_scope():
with jit_utils.import_scope(dlopen_flags):
exec(f"import {gen_name}")
mod = locals()[gen_name]
if return_module:
return mod
@ -800,6 +805,7 @@ dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
jittor_path = find_jittor_path()
check_debug_flags()
@ -966,3 +972,5 @@ flags.jittor_path = jittor_path
flags.gdb_path = gdb_path
flags.addr2line_path = addr2line_path
flags.has_pybt = has_pybt
core.set_lock_path(lock.lock_path)

View File

@ -15,6 +15,7 @@ from tqdm import tqdm
import numpy as np
from collections.abc import Sequence, Mapping
from PIL import Image
from .. import lock
def ensure_dir(dir_path):
if not os.path.isdir(dir_path):
@ -36,7 +37,7 @@ def _progress():
return bar_update
@lock.lock_scope()
def download_url_to_local(url, filename, root_folder, md5):
ensure_dir(root_folder)
file_path = os.path.join(root_folder, filename)

63
python/jittor/lock.py Normal file
View File

@ -0,0 +1,63 @@
import fcntl
import os
from jittor_utils import cache_path, LOG
class Lock:
def __init__(self, filename):
self.handle = open(filename, 'w')
LOG.v(f'OPEN LOCK path: {filename} PID: {os.getpid()}')
self.is_locked = False
def lock(self):
fcntl.flock(self.handle, fcntl.LOCK_EX)
self.is_locked = True
LOG.vv(f'LOCK PID: {os.getpid()}')
def unlock(self):
fcntl.flock(self.handle, fcntl.LOCK_UN)
self.is_locked = False
LOG.vv(f'UNLOCK PID: {os.getpid()}')
def __del__(self):
self.handle.close()
class _base_scope:
'''base_scope for support @xxx syntax'''
def __enter__(self): pass
def __exit__(self, *exc): pass
def __call__(self, func):
def inner(*args, **kw):
with self:
ret = func(*args, **kw)
return ret
return inner
class lock_scope(_base_scope):
def __enter__(self):
self.is_locked = jittor_lock.is_locked
if not self.is_locked:
jittor_lock.lock()
def __exit__(self, *exc):
if not self.is_locked:
jittor_lock.unlock()
class unlock_scope(_base_scope):
def __enter__(self):
self.is_locked = jittor_lock.is_locked
if self.is_locked:
jittor_lock.unlock()
def __exit__(self, *exc):
if self.is_locked:
jittor_lock.lock()
lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock"))
if not os.path.exists(lock_path):
LOG.i("Create lock file:", lock_path)
try:
os.mknod(lock_path)
except:
pass
jittor_lock = Lock(lock_path)

View File

@ -5,23 +5,33 @@
# ***************************************************************
if __name__ == "__main__":
import unittest
import unittest, os
suffix = "__main__.py"
assert __file__.endswith(suffix)
test_dir = __file__[:-len(suffix)]
import os
skip_l = int(os.environ.get("test_skip_l", "0"))
skip_r = int(os.environ.get("test_skip_r", "1000000"))
test_only = None
if "test_only" in os.environ:
test_only = set(os.environ.get("test_only").split(","))
test_files = os.listdir(test_dir)
for test_file in test_files:
test_files = sorted(test_files)
suite = unittest.TestSuite()
for _, test_file in enumerate(test_files):
if not test_file.startswith("test_"):
continue
if _ < skip_l or _ > skip_r:
continue
test_name = test_file.split(".")[0]
exec(f"from . import {test_name}")
test_mod = globals()[test_name]
print(test_name)
for i in dir(test_mod):
obj = getattr(test_mod, i)
if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
globals()[test_name+"_"+i] = obj
if test_only and test_name not in test_only:
continue
unittest.main()
print("Add Test", _, test_name)
suite.addTest(unittest.defaultTestLoader.loadTestsFromName(
"jittor.test."+test_name))
unittest.TextTestRunner(verbosity=3).run(suite)

View File

@ -18,6 +18,7 @@ import pickle as pk
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.nn import MaxPool2d, Sequential
except:

View File

@ -11,6 +11,7 @@ import numpy as np
class TestClone(unittest.TestCase):
def test(self):
jt.clean()
b = a = jt.array(1)
for i in range(10):
b = b.clone()

View File

@ -18,11 +18,12 @@ def test_cuda(use_cuda=1):
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestCuda(unittest.TestCase):
@jt.flag_scope(use_cuda=1)
def test_cuda_flags(self):
with jt.var_scope(use_cuda=1):
a = jt.random((10, 10))
a.sync()
a = jt.random((10, 10))
a.sync()
@jt.flag_scope(use_cuda=2)
def test_no_cuda_op(self):
no_cuda_op = jt.compile_custom_op("""
struct NoCudaOp : Op {
@ -49,10 +50,10 @@ class TestCuda(unittest.TestCase):
""",
"no_cuda")
# force use cuda
with jt.var_scope(use_cuda=2):
a = no_cuda_op([3,4,5], 'float')
expect_error(lambda: a())
a = no_cuda_op([3,4,5], 'float')
expect_error(lambda: a())
@jt.flag_scope(use_cuda=1)
def test_cuda_custom_op(self):
my_op = jt.compile_custom_op("""
struct MyCudaOp : Op {
@ -94,9 +95,8 @@ class TestCuda(unittest.TestCase):
#endif // JIT
""",
"my_cuda")
with jt.var_scope(use_cuda=1):
a = my_op([3,4,5], 'float')
na = a.data
a = my_op([3,4,5], 'float')
na = a.data
assert a.shape == [3,4,5] and a.dtype == 'float'
assert (-na.flatten() == range(3*4*5)).all(), na

View File

@ -22,7 +22,6 @@ class TestCutt(unittest.TestCase):
@jt.flag_scope(use_cuda=1)
def test(self):
t = cutt_ops.cutt_test("213")
jt.sync_all(True)
print(t.data)
assert t.data == 123
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,27 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
from pathlib import Path
class TestLock(unittest.TestCase):
def test(self):
if os.environ.get('lock_full_test', '0') == '1':
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
assert os.system(f"rm -rf {cache_path}") == 0
cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example"
else:
cmd = f"{sys.executable} -m jittor.test.test_example"
print("run cmd twice", cmd)
assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0
if __name__ == "__main__":
unittest.main()

View File

@ -11,6 +11,7 @@ import numpy as np
class TestMiscIssue(unittest.TestCase):
def test_issue4(self):
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
return
@ -42,6 +43,7 @@ b.sync()
def test_mkl_conflict1(self):
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
return
@ -67,6 +69,7 @@ m(torch.rand(*nchw))
def test_mkl_conflict2(self):
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
return

View File

@ -158,9 +158,10 @@ def find_cache_path():
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
stderr=sp.PIPE)
assert r.returncode == 0
bs = r.stdout.decode()
bs = r.stdout.decode().splitlines()
for b in bs:
if b.startswith("* "): break
cache_name = b[2:]
for c in " (){}": cache_name = cache_name.replace(c, "_")
except:
@ -168,10 +169,14 @@ def find_cache_path():
for name in cache_name.split("/"):
dirs.insert(-1, name)
os.environ["cache_name"] = cache_name
LOG.v("cache_name", cache_name)
for d in dirs:
path = os.path.join(path, d)
if not os.path.isdir(path):
os.mkdir(path)
try:
os.mkdir(path)
except:
pass
assert os.path.isdir(path)
if path not in sys.path:
sys.path.append(path)

50
src/lock.cc Normal file
View File

@ -0,0 +1,50 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Wenyang Zhou <576825820@qq.com>
// Dun Liang <randonlang@gmail.com>
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include "lock.h"
namespace jittor {
static int lock_fd = -1;
void set_lock_path(string path) {
lock_fd = open(path.c_str(), O_RDWR);
ASSERT(lock_fd >= 0);
LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid();
}
void lock() {
ASSERT(lock_fd >= 0);
struct flock lock = {
.l_type = F_WRLCK,
.l_whence = SEEK_SET,
.l_start = 0,
.l_len = 0
};
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
LOGvv << "LOCK Pid:" << getpid();
}
void unlock() {
ASSERT(lock_fd >= 0);
struct flock lock = {
.l_type = F_UNLCK,
.l_whence = SEEK_SET,
.l_start = 0,
.l_len = 0
};
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
LOGvv << "UNLOCK Pid:" << getpid();
}
} // jittor

26
src/lock.h Normal file
View File

@ -0,0 +1,26 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Wenyang Zhou <576825820@qq.com>
// Dun Liang <randonlang@gmail.com>
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
// @pyjt(set_lock_path)
void set_lock_path(string path);
void lock();
void unlock();
struct lock_guard {
inline lock_guard() { lock(); }
inline ~lock_guard() { unlock(); }
};
} // jittor

View File

@ -14,6 +14,7 @@
#include "misc/str_utils.h"
#include "ops/op_register.h"
#include "ops/array_op.h"
#include "lock.h"
namespace jittor {
@ -672,7 +673,7 @@ string OpCompiler::get_jit_src(Op* op) {
else
after_include_src += src;
}
ASSERT(file_exist(src_path));
ASSERT(file_exist(src_path)) << src_path;
LOGvvv << "Read from" << src_path;
string src = read_all(src_path);
ASSERT(src.size()) << "Source read failed:" << src_path;
@ -945,6 +946,7 @@ jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
}
jit_op_entry_t OpCompiler::do_compile(Op* op) {
jittor::lock_guard lg;
OpCompiler oc(op);
string* src = &oc.src;
string src_after_passes;
@ -954,8 +956,8 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
src_after_passes = tm.tune();
src = &src_after_passes;
}
return oc.compile(op->get_jit_key(), *src);
auto ret = oc.compile(op->get_jit_key(), *src);
return ret;
}
}