Merge branch 'master' into ygy

Author: cxjyxx_me
Date:   2020-04-15 16:44:17 +08:00
Commit: fb726ad22a

27 changed files with 478 additions and 190 deletions

View File

@ -16,6 +16,8 @@ Jittor's frontend language is Python. The frontend uses a modular design, which is currently
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np
class Model(Module):
def __init__(self):
self.layer1 = nn.Linear(1, 10)
@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model()
learning_rate = 0.1
batch_size = 50
n = 1000
model = Model()
optim = nn.SGD(model.parameters(), learning_rate)
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x)
loss = ((pred_y - y)**2)
dy = pred_y - y
loss = dy * dy
loss_mean = loss.mean()
optim.step(loss_mean)
print(f"step {i}, loss = {loss_mean.data.sum()}")
@ -266,7 +273,7 @@ print(type(a), type(b), type(c))
Besides that, all the operators we use, `jt.xxx(Var, ...)`, have an alias `Var.xxx(...)`. For example:
```python
c.max() # alias of jt.max(a)
c.max() # alias of jt.max(c)
c.add(a) # alias of jt.add(c, a)
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
```
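The corrected comment above reflects the general rule: every `jt.xxx(Var, ...)` operator is also available as the method `Var.xxx(...)`. A minimal sketch checking the equivalence, using only the ops already shown above:

```python
import jittor as jt
import numpy as np

a = jt.float32([1.0, 2.0, 3.0])
c = jt.float32([4.0, 5.0, 6.0])

# Var.xxx(...) dispatches to jt.xxx(Var, ...), so both spellings agree
assert np.allclose(c.max().data, jt.max(c).data)
assert np.allclose(c.add(a).data, jt.add(c, a).data)
assert np.allclose(c.min(keepdims=True).data, jt.min(c, keepdims=True).data)
```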

View File

@ -16,6 +16,8 @@ The following example shows how to model a two-layer neural network step by step
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np
class Model(Module):
def __init__(self):
self.layer1 = nn.Linear(1, 10)
@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model()
learning_rate = 0.1
batch_size = 50
n = 1000
model = Model()
optim = nn.SGD(model.parameters(), learning_rate)
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x)
loss = ((pred_y - y)**2)
dy = pred_y - y
loss = dy * dy
loss_mean = loss.mean()
optim.step(loss_mean)
print(f"step {i}, loss = {loss_mean.data.sum()}")
@ -261,7 +268,7 @@ Besides that, all the operators we use, `jt.xxx(Var, ...)`, have an alias `Var.xxx(..
```python
c.max() # alias of jt.max(a)
c.max() # alias of jt.max(c)
c.add(a) # alias of jt.add(c, a)
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
```

View File

@ -21,6 +21,8 @@ The following example shows how to model a two-layer neural network step by step
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np
class Model(Module):
def __init__(self):
self.layer1 = nn.Linear(1, 10)
@ -38,13 +40,18 @@ def get_data(n): # generate random data for training test.
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model()
learning_rate = 0.1
batch_size = 50
n = 1000
model = Model()
optim = nn.SGD(model.parameters(), learning_rate)
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x)
loss = ((pred_y - y)**2)
dy = pred_y - y
loss = dy * dy
loss_mean = loss.mean()
optim.step(loss_mean)
print(f"step {i}, loss = {loss_mean.data.sum()}")
@ -325,7 +332,7 @@ Besides that, all the operators we use, `jt.xxx(Var, ...)`, have an alias `Var.xxx(..
Besides that, all the operators we use, `jt.xxx(Var, ...)`, have an alias `Var.xxx(...)`. For example:
```python
c.max() # alias of jt.max(a)
c.max() # alias of jt.max(c)
c.add(a) # alias of jt.add(c, a)
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
```

View File

@ -37,4 +37,9 @@ int _mpi_world_rank();
// @pyjt(local_rank)
int _mpi_local_rank();
struct ArrayArgs;
// @pyjt(broadcast)
void _mpi_broadcast(ArrayArgs&& args, int i);
} // jittor
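The new `_mpi_broadcast` entry is exported to Python as `broadcast` through the `@pyjt` annotation. A hedged usage sketch, mirroring the test added later in this commit (it assumes the process was launched with mpirun, so `jt.compile_extern.mpi` is not None):

```python
import numpy as np
import jittor as jt

mpi = jt.compile_extern.mpi   # None unless running inside mpirun
if mpi:
    # every rank starts with a different buffer ...
    a = np.zeros(100, dtype=np.int32) + mpi.world_rank()
    # ... and after the broadcast every rank holds rank 0's data
    mpi.broadcast(a, 0)
    assert (a == 0).all()
```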

View File

@ -12,6 +12,7 @@
#include "mpi_warper.h"
#include "common.h"
#include "ops/array_op.h"
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
@ -28,9 +29,9 @@ void throw_mpi_error(int result,
namespace jittor {
int mpi_world_size;
int mpi_world_rank;
int mpi_local_rank;
int mpi_world_size = 1;
int mpi_world_rank = 0;
int mpi_local_rank = 0;
int _mpi_world_size() {
return mpi_world_size;
@ -44,7 +45,12 @@ int _mpi_local_rank() {
return mpi_local_rank;
}
void _mpi_broadcast(ArrayArgs&& args, int i) {
int64 size = args.dtype.dsize();
for (auto j : args.shape)
size *= j;
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
}
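`MPI_Bcast` is called with `MPI_BYTE`, so the count is the element size multiplied by every shape dimension. A tiny Python sketch of the same computation (the `dsize` and `shape` values here are only illustrative):

```python
from functools import reduce
import operator

dsize = 4                      # bytes per element, e.g. float32 / int32
shape = (2, 100)               # illustrative ArrayArgs.shape
nbytes = reduce(operator.mul, shape, dsize)
assert nbytes == 800           # the count handed to MPI_Bcast as MPI_BYTE
```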
static uint64_t getHostHash(const char* string) {
// Based on DJB2, result = result * 33 + char
@ -69,6 +75,7 @@ static void getHostName(char* hostname, int maxlen) {
struct mpi_initer {
mpi_initer() {
LOGvv << "MPI init...";
MPI_CHECK(MPI_Init(NULL, NULL));
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
@ -84,6 +91,9 @@ mpi_initer() {
if (p == mpi_world_rank) break;
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
}
LOGv << "MPI init finished: local" << mpi_local_rank
<< "global" << mpi_world_rank
<< "size" << mpi_world_size;
}
~mpi_initer() {

View File

@ -7,14 +7,16 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from . import compiler
from .compiler import LOG, has_cuda
from .compiler import compile_custom_ops, compile_custom_op
import jittor_core as core
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops
from . import lock
with lock.lock_scope():
from . import compiler
from .compiler import LOG, has_cuda
from .compiler import compile_custom_ops, compile_custom_op
import jittor_core as core
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops
import contextlib
import numpy as np

View File

@ -5,7 +5,7 @@
# ***************************************************************
import os, sys, shutil
from .compiler import *
from jittor_utils import run_cmd
from jittor_utils import run_cmd, get_version
from jittor.dataset.utils import download_url_to_local
def search_file(dirs, name):
@ -322,6 +322,8 @@ def manual_link(flags):
ctypes.CDLL(libname, dlopen_flags)
break
def inside_mpi():
return "OMPI_COMM_WORLD_SIZE" in os.environ
def setup_mpi():
global mpi_ops, mpi, use_mpi
@ -330,7 +332,6 @@ def setup_mpi():
mpi_ops = None
mpi = None
has_mpi = False
if not use_mpi: return
mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
if mpicc_path == "":
LOG.i("mpicc not found, distribution disabled.")
@ -338,6 +339,8 @@ def setup_mpi():
else:
use_mpi = True
has_mpi = True
if not inside_mpi():
use_mpi = False
if not use_mpi:
return
@ -345,8 +348,7 @@ def setup_mpi():
mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
mpi_flags = mpi_compile_flags + " " + mpi_link_flags
LOG.i("mpi_flags: "+mpi_flags)
manual_link(mpi_flags)
LOG.v("mpi_flags: "+mpi_flags)
# find all source files
mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
@ -359,8 +361,15 @@ def setup_mpi():
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
mpi_version = get_version(mpicc_path)
if mpi_version.startswith("(1.") or mpi_version.startswith("(2."):
# mpi version 1.x needs to be linked like this
manual_link(mpi_flags)
# mpi (4.x) cannot use deepbind; it needs to
# share the 'environ' symbol.
mpi = compile_custom_ops(mpi_src_files,
extra_flags=f" {mpi_flags} ", return_module=True)
extra_flags=f" {mpi_flags} ", return_module=True,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
mpi_ops = mpi.ops
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))

View File

@ -17,6 +17,7 @@ from ctypes.util import find_library
import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler
from . import lock
def find_jittor_path():
return os.path.dirname(__file__)
@ -518,6 +519,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
"""
return jit_src
@lock.lock_scope()
def compile_custom_op(header, source, op_name, warp=True):
"""Compile a single custom op
header: code of op header, not path
@ -554,7 +556,12 @@ def compile_custom_op(header, source, op_name, warp=True):
m = compile_custom_ops([hname, ccname])
return getattr(m, op_name)
def compile_custom_ops(filenames, extra_flags="", return_module=False):
@lock.lock_scope()
def compile_custom_ops(
filenames,
extra_flags="",
return_module=False,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
"""Compile custom ops
filenames: path of op source files, filenames must be
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
@ -639,8 +646,10 @@ def compile_custom_ops(filenames, extra_flags="", return_module=False):
lib_path = os.path.join(cache_path, "custom_ops")
if lib_path not in os.sys.path:
os.sys.path.append(lib_path)
with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
exec(f"import {gen_name}")
# unlock scope when initialize
with lock.unlock_scope():
with jit_utils.import_scope(dlopen_flags):
exec(f"import {gen_name}")
mod = locals()[gen_name]
if return_module:
return mod
@ -796,6 +805,7 @@ dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
jittor_path = find_jittor_path()
check_debug_flags()
@ -962,3 +972,5 @@ flags.jittor_path = jittor_path
flags.gdb_path = gdb_path
flags.addr2line_path = addr2line_path
flags.has_pybt = has_pybt
core.set_lock_path(lock.lock_path)
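`compile_custom_ops` now takes the dlopen flags as a parameter instead of hard-coding them, which is what lets the MPI extension drop `RTLD_DEEPBIND`. A small sketch of the two flag combinations involved (Linux/glibc, where `os.RTLD_DEEPBIND` is available):

```python
import os

# Default used by compile_custom_ops: symbols stay private to the extension.
default_flags = os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND
# Flags passed for the MPI extension, which must share `environ` with the process.
mpi_flags = os.RTLD_GLOBAL | os.RTLD_NOW

assert default_flags & os.RTLD_DEEPBIND
assert not (mpi_flags & os.RTLD_DEEPBIND)
```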

View File

@ -23,6 +23,7 @@ import jittor as jt
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
mpi = jt.compile_extern.mpi
class Worker:
def __init__(self, target, args, buffer_size):
@ -69,9 +70,11 @@ class Dataset(object):
def __len__(self):
assert self.total_len >= 0
assert self.batch_size > 0
real_len = (self.total_len-1)//mpi.world_size()+1 if mpi \
else self.total_len
if self.drop_last:
return self.total_len // self.batch_size
return (self.total_len-1) // self.batch_size + 1
return real_len // self.batch_size
return (real_len-1) // self.batch_size + 1
def set_attrs(self, **kw):
'''set attributes of dataset, equivalent to setattr
@ -130,8 +133,8 @@ class Dataset(object):
self.gidc.notify()
batch = []
if mp_log_v:
print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size))
for i in range(cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size)):
print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size))
for i in range(cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)):
batch.append(self[self.index_list[i]])
batch = self.collate_batch(batch)
if mp_log_v:
@ -157,7 +160,7 @@ class Dataset(object):
w.buffer.clear()
def _init_workers(self):
self.index_list = mp.Array('i', self.total_len, lock=False)
self.index_list = mp.Array('i', self.real_len, lock=False)
workers = []
# batch id to worker id
self.idmap = mp.Array('i', self.batch_len, lock=False)
@ -174,7 +177,7 @@ class Dataset(object):
buffer_size=self.buffer_size)
workers.append(w)
self.workers = workers
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.total_len, buffer=self.index_list)
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
def __del__(self):
if mp_log_v:
@ -186,6 +189,25 @@ class Dataset(object):
index_list = get_order_list(self.total_len)
else:
index_list = get_random_list(self.total_len)
# scatter index_list across all MPI processes
# scatter rule (e.g. 8 samples, 3 processes, real_len = 3):
#   [........]
#    000111
#         222   (the last slice is shifted back so it stays in range)
# make sure each process gets a slice of the same length
if mpi:
index_list = np.int32(index_list)
mpi.broadcast(index_list, 0)
real_len = (self.total_len - 1) // mpi.world_size() + 1
offset = mpi.world_rank() * real_len
if offset + real_len > self.total_len:
offset -= offset + real_len - self.total_len
index_list = index_list[offset:offset+real_len]
self.real_len = real_len
assert real_len == len(index_list)
else:
self.real_len = self.total_len
self.batch_len = len(self)
if "batch_len" in os.environ:

View File

@ -15,6 +15,7 @@ from tqdm import tqdm
import numpy as np
from collections.abc import Sequence, Mapping
from PIL import Image
from .. import lock
def ensure_dir(dir_path):
if not os.path.isdir(dir_path):
@ -36,7 +37,7 @@ def _progress():
return bar_update
@lock.lock_scope()
def download_url_to_local(url, filename, root_folder, md5):
ensure_dir(root_folder)
file_path = os.path.join(root_folder, filename)

python/jittor/lock.py (new file, 63 lines)
View File

@ -0,0 +1,63 @@
import fcntl
import os
from jittor_utils import cache_path, LOG
class Lock:
def __init__(self, filename):
self.handle = open(filename, 'w')
LOG.v(f'OPEN LOCK path: {filename} PID: {os.getpid()}')
self.is_locked = False
def lock(self):
fcntl.flock(self.handle, fcntl.LOCK_EX)
self.is_locked = True
LOG.vv(f'LOCK PID: {os.getpid()}')
def unlock(self):
fcntl.flock(self.handle, fcntl.LOCK_UN)
self.is_locked = False
LOG.vv(f'UNLOCK PID: {os.getpid()}')
def __del__(self):
self.handle.close()
class _base_scope:
'''base scope to support the @xxx decorator syntax'''
def __enter__(self): pass
def __exit__(self, *exc): pass
def __call__(self, func):
def inner(*args, **kw):
with self:
ret = func(*args, **kw)
return ret
return inner
class lock_scope(_base_scope):
def __enter__(self):
self.is_locked = jittor_lock.is_locked
if not self.is_locked:
jittor_lock.lock()
def __exit__(self, *exc):
if not self.is_locked:
jittor_lock.unlock()
class unlock_scope(_base_scope):
def __enter__(self):
self.is_locked = jittor_lock.is_locked
if self.is_locked:
jittor_lock.unlock()
def __exit__(self, *exc):
if self.is_locked:
jittor_lock.lock()
lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock"))
if not os.path.exists(lock_path):
LOG.i("Create lock file:", lock_path)
try:
os.mknod(lock_path)
except:
pass
jittor_lock = Lock(lock_path)
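Both scopes are re-entrancy aware: `lock_scope` only takes the file lock if this process does not already hold it, and `unlock_scope` temporarily releases it (used in `compiler.py` around the generated module import so that path can lock on its own). A minimal usage sketch:

```python
from jittor import lock

@lock.lock_scope()              # decorator form, via _base_scope.__call__
def build_cached_artifact():
    pass                        # runs while holding the inter-process lock

with lock.lock_scope():         # context-manager form
    with lock.unlock_scope():   # temporarily release, e.g. around a dlopen/import
        pass
    # unlock_scope re-acquires the lock on exit, so it is held again here
```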

View File

@ -5,23 +5,33 @@
# ***************************************************************
if __name__ == "__main__":
import unittest
import unittest, os
suffix = "__main__.py"
assert __file__.endswith(suffix)
test_dir = __file__[:-len(suffix)]
import os
skip_l = int(os.environ.get("test_skip_l", "0"))
skip_r = int(os.environ.get("test_skip_r", "1000000"))
test_only = None
if "test_only" in os.environ:
test_only = set(os.environ.get("test_only").split(","))
test_files = os.listdir(test_dir)
for test_file in test_files:
test_files = sorted(test_files)
suite = unittest.TestSuite()
for _, test_file in enumerate(test_files):
if not test_file.startswith("test_"):
continue
if _ < skip_l or _ > skip_r:
continue
test_name = test_file.split(".")[0]
exec(f"from . import {test_name}")
test_mod = globals()[test_name]
print(test_name)
for i in dir(test_mod):
obj = getattr(test_mod, i)
if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
globals()[test_name+"_"+i] = obj
if test_only and test_name not in test_only:
continue
unittest.main()
print("Add Test", _, test_name)
suite.addTest(unittest.defaultTestLoader.loadTestsFromName(
"jittor.test."+test_name))
unittest.TextTestRunner(verbosity=3).run(suite)
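The rewritten entry point collects everything into one `TestSuite` and can be sliced or filtered through environment variables (`test_skip_l`, `test_skip_r`, `test_only`). A hedged sketch of invoking it with a filter:

```python
import os, subprocess, sys

# Run only two test modules; the names must match files under jittor/test/.
env = dict(os.environ, test_only="test_lock,test_example")
subprocess.run([sys.executable, "-m", "jittor.test"], env=env, check=True)
```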

View File

@ -18,6 +18,7 @@ import pickle as pk
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
from torch.nn import MaxPool2d, Sequential
except:
@ -45,7 +46,7 @@ def check(jt_model, torch_model, shape, near_data):
@unittest.skipIf(skip_this_test, "No Torch found")
class TestArgPoolOp(unittest.TestCase):
@unittest.skipIf(jt.compiler.has_cuda, "No cuda found")
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
def test_cuda(self):
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
@ -58,15 +59,18 @@ class TestArgPoolOp(unittest.TestCase):
check(jt_model, torch_model, [1,1,300,300], True)
def test_cpu_(self):
x = jt.random([32, 128, 157, 300])
# x = jt.random([32, 128, 157, 300])
x = jt.random([4, 128, 157, 300])
x = jt.nn.pool(x, 2, "maximum", 0, 2)
def test_cpu(self):
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1))
shape = [64, 64, 300, 300]
# shape = [64, 64, 300, 300]
shape = [4, 64, 300, 300]
check(jt_model, torch_model, shape, False)
shape = [32, 128, 157, 300]
# shape = [32, 128, 157, 300]
shape = [4, 128, 157, 300]
check(jt_model, torch_model, shape, False)
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)

View File

@ -11,6 +11,7 @@ import numpy as np
class TestClone(unittest.TestCase):
def test(self):
jt.clean()
b = a = jt.array(1)
for i in range(10):
b = b.clone()

View File

@ -18,11 +18,12 @@ def test_cuda(use_cuda=1):
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestCuda(unittest.TestCase):
@jt.flag_scope(use_cuda=1)
def test_cuda_flags(self):
with jt.var_scope(use_cuda=1):
a = jt.random((10, 10))
a.sync()
a = jt.random((10, 10))
a.sync()
@jt.flag_scope(use_cuda=2)
def test_no_cuda_op(self):
no_cuda_op = jt.compile_custom_op("""
struct NoCudaOp : Op {
@ -49,10 +50,10 @@ class TestCuda(unittest.TestCase):
""",
"no_cuda")
# force use cuda
with jt.var_scope(use_cuda=2):
a = no_cuda_op([3,4,5], 'float')
expect_error(lambda: a())
a = no_cuda_op([3,4,5], 'float')
expect_error(lambda: a())
@jt.flag_scope(use_cuda=1)
def test_cuda_custom_op(self):
my_op = jt.compile_custom_op("""
struct MyCudaOp : Op {
@ -94,9 +95,8 @@ class TestCuda(unittest.TestCase):
#endif // JIT
""",
"my_cuda")
with jt.var_scope(use_cuda=1):
a = my_op([3,4,5], 'float')
na = a.data
a = my_op([3,4,5], 'float')
na = a.data
assert a.shape == [3,4,5] and a.dtype == 'float'
assert (-na.flatten() == range(3*4*5)).all(), na
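The CUDA tests now set flags with the `jt.flag_scope` decorator instead of the old `jt.var_scope` context manager; both forms scope the flag change to the wrapped code. A minimal sketch (the decorated function only touches CUDA when it is actually called):

```python
import jittor as jt

@jt.flag_scope(use_cuda=1)        # decorator form, as in the updated tests
def run_on_cuda():
    a = jt.random((10, 10))
    a.sync()

with jt.flag_scope(use_cuda=0):   # context-manager form
    b = jt.random((10, 10))
    b.sync()
```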

View File

@ -22,7 +22,6 @@ class TestCutt(unittest.TestCase):
@jt.flag_scope(use_cuda=1)
def test(self):
t = cutt_ops.cutt_test("213")
jt.sync_all(True)
print(t.data)
assert t.data == 123
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,27 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
from pathlib import Path
class TestLock(unittest.TestCase):
def test(self):
if os.environ.get('lock_full_test', '0') == '1':
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
assert os.system(f"rm -rf {cache_path}") == 0
cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example"
else:
cmd = f"{sys.executable} -m jittor.test.test_example"
print("run cmd twice", cmd)
assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0
if __name__ == "__main__":
unittest.main()
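What the new test exercises: two jittor processes started at the same time must both succeed, because cache compilation is serialized by the file lock. A sketch of the same race without the shell `&`/`wait` plumbing:

```python
import subprocess, sys

cmd = [sys.executable, "-m", "jittor.test.test_example"]
p1, p2 = subprocess.Popen(cmd), subprocess.Popen(cmd)   # two concurrent runs
assert p1.wait() == 0 and p2.wait() == 0                # both must finish cleanly
```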

View File

@ -11,6 +11,7 @@ import numpy as np
class TestMiscIssue(unittest.TestCase):
def test_issue4(self):
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
return
@ -42,6 +43,7 @@ b.sync()
def test_mkl_conflict1(self):
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
return
@ -67,6 +69,7 @@ m(torch.rand(*nchw))
def test_mkl_conflict2(self):
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
return
@ -126,5 +129,13 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
assert a.min().data == a.data.min(), (a.min(), a.data.min())
assert a.max().data == a.data.max(), (a.max(), a.data.max())
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cuda_pow_grad_nan(self):
a = jt.float32([1,-1, -1000.1])
da = jt.grad(a**2, a)
assert np.isnan(da.data).sum()==0, da.data
if __name__ == "__main__":
unittest.main()

View File

@ -10,26 +10,51 @@ import unittest
import os, sys
import jittor as jt
import numpy as np
mpi = jt.compile_extern.mpi
def main():
print("test mpi_test")
assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
if jt.compile_extern.nccl_ops:
print("test test_with_mpi")
with jt.flag_scope(use_cuda=1):
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
@unittest.skipIf(mpi is None, "not inside mpirun")
class TestMpi(unittest.TestCase):
def test(self):
mpi = jt.compile_extern.mpi
if mpi.world_size() == 1:
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
print("run cmd", cmd)
jt.compiler.run_cmd(cmd)
else:
main()
def test_mpi_test_op(self):
assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
@unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl")
@jt.flag_scope(use_cuda=1)
def test_nccl_with_mpi(self):
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
def test_mpi_broadcast(self):
for i in range(mpi.world_size()):
a = np.zeros(100) + mpi.world_rank()
mpi.broadcast(a, i)
assert (a == i).all()
def test_mpi_dataset(self):
from jittor.dataset.dataset import Dataset
class ToyDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=1024)
def __getitem__(self, index):
return index, index*index
toy = ToyDataset()
offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()
for _ in range(2):
for i,(a,b) in enumerate(toy):
assert (a.data*a.data == b.data).all()
c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
assert (c==a.data).all()
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiEntry(unittest.TestCase):
def test_entry(self):
if not jt.compile_extern.inside_mpi():
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi -v"
print("run cmd:", cmd)
assert os.system(cmd)==0, "run cmd failed: "+cmd
if __name__ == "__main__":
unittest.main()

View File

@ -133,5 +133,34 @@ class TestOpCompiler(unittest.TestCase):
expect_error(lambda: jit_precompile(vars, "@if(1)"))
expect_error(lambda: jit_precompile(vars, "#define OP1(a,b) a+b\n@expand_macro(OP1,1)"))
def test_strcmp(self):
vars = {"Tx":"float"}
check = lambda expr, result: \
self.assertEqual(jit_precompile(vars, expr), result)
check("@strcmp(aaa,aaa)", "0")
check("@strcmp(aaa,bbb)", "-1")
check("@strcmp(ccc,bbb)", "1")
check("@{@strcmp(aaa,aaa)}", "0")
check("@{@strcmp(aaa,bbb)}", "-1")
check("@{@strcmp(ccc,bbb)}", "1")
code = \
"""@define(T_NCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
ncclBcast(..., @T_NCCL, ...)
"""
assert "ncclFloat" in jit_precompile({"Tx":"float"}, code)
assert "ncclFloat" in jit_precompile({"Tx":"float32"}, code)
assert "ncclFloat64" in jit_precompile({"Tx":"float64"}, code)
assert "ncclInt" in jit_precompile({"Tx":"int"}, code)
assert "ncclInt" in jit_precompile({"Tx":"int32"}, code)
assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code)
if __name__ == "__main__":
unittest.main()
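`@strcmp(s1,s2)` is evaluated at precompile time and expands to `-1`, `0`, or `1`, which is what lets the `@if(...)` lines above select an NCCL dtype from the JIT type string. A Python sketch of the mapping (a stand-in helper, not the real precompiler):

```python
def strcmp(s1: str, s2: str) -> str:
    # mirrors the @strcmp handling added in src/op_compiler.cc
    return "-1" if s1 < s2 else ("0" if s1 == s2 else "1")

assert strcmp("aaa", "aaa") == "0"
assert strcmp("aaa", "bbb") == "-1"
assert strcmp("ccc", "bbb") == "1"
```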

View File

@ -158,9 +158,10 @@ def find_cache_path():
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
stderr=sp.PIPE)
assert r.returncode == 0
bs = r.stdout.decode()
bs = r.stdout.decode().splitlines()
for b in bs:
if b.startswith("* "): break
cache_name = b[2:]
for c in " (){}": cache_name = cache_name.replace(c, "_")
except:
@ -168,10 +169,14 @@ def find_cache_path():
for name in cache_name.split("/"):
dirs.insert(-1, name)
os.environ["cache_name"] = cache_name
LOG.v("cache_name", cache_name)
for d in dirs:
path = os.path.join(path, d)
if not os.path.isdir(path):
os.mkdir(path)
try:
os.mkdir(path)
except:
pass
assert os.path.isdir(path)
if path not in sys.path:
sys.path.append(path)
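`git branch` output is now split into lines before looking for the `* ` entry, and the per-directory `mkdir` tolerates another process creating the same path concurrently. A sketch of the branch-name sanitization (the branch string here is made up):

```python
branch_output = "  master\n* feature (wip) {x}\n"   # hypothetical `git branch` output
for b in branch_output.splitlines():
    if b.startswith("* "):
        break
cache_name = b[2:]
for c in " (){}":
    cache_name = cache_name.replace(c, "_")
assert cache_name == "feature__wip___x_"
```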

View File

@ -54,17 +54,22 @@ struct EventQueue {
static void worker_caller();
void run_sync(Func func) {
// send work to worker and do something by self
std::unique_lock<std::mutex> l(mtx);
this->func = func;
run_sync_done = false;
// send func to worker
worker.run(worker_caller);
while (1) {
// check self work or worker's status
cv.wait(l);
list<Func> ts = move(tasks);
l.unlock();
// do self works
for (auto func : ts)
func();
l.lock();
// worker is finished
if (run_sync_done)
return;
}

src/lock.cc (new file, 50 lines)
View File

@ -0,0 +1,50 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Wenyang Zhou <576825820@qq.com>
// Dun Liang <randonlang@gmail.com>
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include "lock.h"
namespace jittor {
static int lock_fd = -1;
void set_lock_path(string path) {
lock_fd = open(path.c_str(), O_RDWR);
ASSERT(lock_fd >= 0);
LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid();
}
void lock() {
ASSERT(lock_fd >= 0);
struct flock lock = {
.l_type = F_WRLCK,
.l_whence = SEEK_SET,
.l_start = 0,
.l_len = 0
};
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
LOGvv << "LOCK Pid:" << getpid();
}
void unlock() {
ASSERT(lock_fd >= 0);
struct flock lock = {
.l_type = F_UNLCK,
.l_whence = SEEK_SET,
.l_start = 0,
.l_len = 0
};
ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
LOGvv << "UNLOCK Pid:" << getpid();
}
} // jittor

src/lock.h (new file, 26 lines)
View File

@ -0,0 +1,26 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Wenyang Zhou <576825820@qq.com>
// Dun Liang <randonlang@gmail.com>
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
// @pyjt(set_lock_path)
void set_lock_path(string path);
void lock();
void unlock();
struct lock_guard {
inline lock_guard() { lock(); }
inline ~lock_guard() { unlock(); }
};
} // jittor

View File

@ -14,6 +14,8 @@
#include "misc/str_utils.h"
#include "ops/op_register.h"
#include "ops/array_op.h"
#include "lock.h"
#include "opt/expr.h"
namespace jittor {
@ -103,48 +105,6 @@ int OpCompiler::total_member_count() {
return member_count;
}
#define FOR_ALL_UOPS(m) \
m(!,3) m(~,3)
#define FOR_ALL_BOPS(m) \
m(*,5) m(/,5) m(%,5) \
m(+,6) m(-,6) \
m(<<,7) m(>>,7) \
m(<,9) m(<=,9) m(>,9) m(>=,9) \
m(!=,10) m(==,10) \
m(&,11) \
m(^,12) \
m(|,13) \
m(&&,14) \
m(||,15)
#define FOR_ALL_OPS(m) FOR_ALL_UOPS(m) FOR_ALL_BOPS(m)
inline bool is_unary_op(const string& op) {
#define _u(o, _) if (op == #o) return true;
FOR_ALL_UOPS(_u);
return false;
}
inline int precedence(const string& op) {
#define _prior(o, p) if (op == #o) return p;
FOR_ALL_OPS(_prior);
return 20;
}
inline bool check_precedence(const string& op1, const string& op2) {
if (op1 == op2 && is_unary_op(op1)) return false;
return precedence(op1) <= precedence(op2);
}
inline int64_t calc_op(int64_t a, int64_t b, const string& op) {
#define _calc_b(o, _) if (op == #o) return a o b;
FOR_ALL_BOPS(_calc_b);
#define _calc_u(o, _) if (op == #o) return o b;
FOR_ALL_UOPS(_calc_u);
ASSERT(0) << "Unrecognized op " << op;
return 0;
}
int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
if (expr.find("@") != string::npos) {
string new_expr;
@ -174,6 +134,22 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
ASSERT(isvar(expr[j]));
size_t k=j+1;
while (k<expr.size() && isvar(expr[k])) k++;
if (k<expr.size() && expr[k]=='(') {
// syntax @xx(...)
// ij k l
size_t l=k+1;
int presum = 1;
while (l<expr.size() && presum) {
if (expr[l] == ')')
presum--;
else if (expr[l] == '(')
presum++;
l++;
}
new_expr += precompile(vars, expr.substr(i, l-i));
i = l-1;
continue;
}
string var = expr.substr(j, k-j);
auto iter = vars.find(var);
ASSERT(iter!=vars.end()) << "Jit var " << var << " not found." << vars;
@ -184,68 +160,18 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
}
return eval(new_expr, vars);
}
vector<int64> values = {0};
vector<string> ops;
auto pop_values_and_calc_op = [&]() {
CHECK(ops.size());
auto op = ops.back();
ops.pop_back();
CHECK(values.size());
auto val2 = values.back();
values.pop_back();
auto val1 = val2;
if (!is_unary_op(op)) {
CHECK(values.size());
val1 = values.back();
values.pop_back();
auto e = expr::make(expr);
e->dfs([&](expr::Expr* s) {
if (s->is_sym()) {
auto iter = vars.find(s->str);
ASSERT(iter!=vars.end()) << "Jit var " << s->str << " not found.";
auto e = expr::make(iter->second);
s->swap(e.get());
}
values.push_back(calc_op(val1, val2, op));
};
for (size_t i=0; i<expr.size(); i++) {
if (expr[i] == ' ')
continue;
if (expr[i] == '(')
ops.push_back(string()+expr[i]);
else if (isdigit(expr[i])) {
int64_t val = 0;
while (i<expr.length() && isdigit(expr[i])) {
val = val*10 + (expr[i]-'0');
i++;
}
i--;
values.push_back(val);
} else if (isvar(expr[i])) {
auto j=i+1;
while (j<expr.size() && isvar(expr[j])) j++;
auto var_name = expr.substr(i,j-i);
auto iter = vars.find(var_name);
ASSERT(iter!=vars.end()) << "Jit var " << var_name << " not found.";
try {
values.push_back(std::stoll(iter->second));
} catch (...) {
ASSERT(0) << "'" << iter->second << "' is not integer, expr " << expr;
}
i = j-1;
} else if (expr[i] == ')') {
while (ops.size() && ops.back() != "(")
pop_values_and_calc_op();
ops.pop_back();
} else {
auto j=i+1;
while (j<expr.size() && expr[j] != ' ' &&
expr[j] != '!' && expr[j] != '~' &&
!isdigit(expr[j]) && !isvar(expr[j]) &&
expr[j] != '(' && expr[j] != ')') j++;
auto op = expr.substr(i, j-i);
while (ops.size() && check_precedence(ops.back(), op))
pop_values_and_calc_op();
ops.push_back(op);
i = j-1;
}
}
while (ops.size())
pop_values_and_calc_op();
return values.back();
});
e = e->eval();
ASSERT(e->is(expr::_int));
return e->as_int();
}
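The hand-rolled precedence-climbing evaluator below is replaced by the `expr` module: the expression is parsed once, every symbol is substituted from the JIT vars, and the tree is folded to an integer. A rough Python stand-in for the new flow (string replacement here only stands in for the symbol-level `dfs` substitution):

```python
def eval_jit_expr(expr: str, vars: dict) -> int:
    # stand-in: replace each known symbol by its value, then fold to an int
    for name, value in vars.items():
        expr = expr.replace(name, value)
    return int(eval(expr, {"__builtins__": {}}))

assert eval_jit_expr("XDIM*stride+1", {"XDIM": "3", "stride": "2"}) == 7
```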
void load_macros(const string& src, unordered_map<string,string>& macros) {
@ -587,6 +513,19 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
i = l-1;
continue;
} else
if (expr == "strcmp") {
// syntax: @strcmp(s1,s2)
// ij k l
ASSERT(args.size()==2u)
<< "Jit error: strcmp wrong arguments.";
auto s1 = precompile(defs, args[0], macros);
auto s2 = precompile(defs, args[1], macros);
if (s1<s2) new_src += "-1"; else
if (s1==s2) new_src += "0"; else
new_src += "1";
i = l-1;
continue;
} else
if (args.size()) {
// syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
int nid=(int)expr.size();
@ -672,7 +611,7 @@ string OpCompiler::get_jit_src(Op* op) {
else
after_include_src += src;
}
ASSERT(file_exist(src_path));
ASSERT(file_exist(src_path)) << src_path;
LOGvvv << "Read from" << src_path;
string src = read_all(src_path);
ASSERT(src.size()) << "Source read failed:" << src_path;
@ -945,6 +884,7 @@ jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
}
jit_op_entry_t OpCompiler::do_compile(Op* op) {
jittor::lock_guard lg;
OpCompiler oc(op);
string* src = &oc.src;
string src_after_passes;
@ -954,8 +894,8 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
src_after_passes = tm.tune();
src = &src_after_passes;
}
return oc.compile(op->get_jit_key(), *src);
auto ret = oc.compile(op->get_jit_key(), *src);
return ret;
}
}

View File

@ -13,6 +13,8 @@
namespace jittor {
#ifndef JIT
static auto make_array = get_op_info("array")
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_broadcast_to = get_op_info("broadcast_to")
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
static auto make_binary = get_op_info("binary")
@ -122,7 +124,9 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index == 0) {
// dout * y * x^(y-1)
auto d = make_binary(dout, y, ns_multiply);
auto ones = make_number(1, dout);
// auto ones = make_number(1, dout);
int number = 1;
auto ones = make_array(&number, 1, ns_int32);
auto y_1 = make_binary(y, ones, ns_subtract);
auto x_y_1 = make_binary(x, y_1, ns_pow);
return make_binary(d, x_y_1, ns_multiply);
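The pow backward now builds the constant 1 with `make_array` as an `int32` scalar instead of `make_number` with `dout`'s dtype, so the `y - 1` exponent is no longer forced through a float constant. The `test_cuda_pow_grad_nan` test added in this commit checks the visible effect: `grad(a**2, a)` stays finite for negative inputs. A small sketch of that check (the original failure mode was CUDA-specific; on CPU this verifies the same property):

```python
import jittor as jt
import numpy as np

a = jt.float32([1, -1, -1000.1])
da = jt.grad(a ** 2, a)                 # d/da a^2 = 2a
assert np.isnan(da.data).sum() == 0     # no NaN even for negative entries
assert np.allclose(da.data, 2 * a.data)
```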

View File

@ -235,10 +235,12 @@ void ConvTuner::forwardTune(FusedOp* fop) {
LOGvvvv << "Expr not match" << src_h << expr_h;
continue;
}
LOGvvvv << "H Expr matched" << src_h << expr_h;
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
auto src_w = expr::make(riop1->indexes[xw]);
if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
return;
LOGvvvv << "W Expr matched" << src_w << expr_w;
if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
int stride_h = rh[0]->as_int();
int padding_h = -rh[1]->as_int();
@ -253,7 +255,10 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue;
}
LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
if (xformat == "bacd" && dilation_h != 1) {
LOGvvvv << "mkl not support bacd dilation, continue";
continue;
}
int stride = stride_h;
int padding = padding_h;
int dilation = dilation_h;
@ -363,6 +368,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
x = riop1->x;
y = riop2->x;
bo++;
LOGvvvv << "backward_w get stride padding and dilation" << stride << padding << dilation;
} else if (op->name_ex() == "reindex_reduce.add") {
auto rop = (ReindexReduceOp*)op;
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
@ -438,6 +444,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
w = riop1->x;
y = riop2->x;
bo+=2;
LOGvvvv << "backward_x get stride padding and dilation" << stride << padding << dilation;
}
// TODO: CUDA only support nchw(abcd)