mirror of https://github.com/Jittor/Jittor
Merge branch 'master' into ygy
commit fb726ad22a
README.cn.md (13 changed lines)

@@ -16,6 +16,8 @@ Jittor's frontend language is Python. The frontend uses a modular design, which is currently
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np

class Model(Module):
    def __init__(self):
        self.layer1 = nn.Linear(1, 10)

@@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
        y = x*x
        yield jt.float32(x), jt.float32(y)

model = Model()

learning_rate = 0.1
batch_size = 50
n = 1000

model = Model()
optim = nn.SGD(model.parameters(), learning_rate)

for i,(x,y) in enumerate(get_data(n)):
    pred_y = model(x)
-    loss = ((pred_y - y)**2)
+    dy = pred_y - y
+    loss = dy * dy
    loss_mean = loss.mean()
    optim.step(loss_mean)
    print(f"step {i}, loss = {loss_mean.data.sum()}")

@@ -266,7 +273,7 @@ print(type(a), type(b), type(c))
Besides, all the operators `jt.xxx(Var,...)` we use also have the alias `Var.xxx(...)`. For example:

```python
-c.max() # alias of jt.max(a)
+c.max() # alias of jt.max(c)
c.add(a) # alias of jt.add(c, a)
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
```
README.md (13 changed lines)

@@ -16,6 +16,8 @@ The following example shows how to model a two-layer neural network step by step
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np

class Model(Module):
    def __init__(self):
        self.layer1 = nn.Linear(1, 10)

@@ -33,13 +35,18 @@ def get_data(n): # generate random data for training test.
        y = x*x
        yield jt.float32(x), jt.float32(y)

model = Model()

learning_rate = 0.1
batch_size = 50
n = 1000

model = Model()
optim = nn.SGD(model.parameters(), learning_rate)

for i,(x,y) in enumerate(get_data(n)):
    pred_y = model(x)
-    loss = ((pred_y - y)**2)
+    dy = pred_y - y
+    loss = dy * dy
    loss_mean = loss.mean()
    optim.step(loss_mean)
    print(f"step {i}, loss = {loss_mean.data.sum()}")

@@ -261,7 +268,7 @@ Besides that, all the operators we used `jt.xxx(Var, ...)` have the alias `Var.xxx(...)`

```python
-c.max() # alias of jt.max(a)
+c.max() # alias of jt.max(c)
c.add(a) # alias of jt.add(c, a)
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
```
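Both README hunks make the same substantive change to the training loop: the loss drops the `**` operator in favor of an explicit difference and multiply. A quick NumPy check (illustrative only, not part of the commit) that the two forms are numerically identical:

```python
import numpy as np

pred_y = np.array([0.5, 2.0, -1.0])
y = np.array([1.0, 1.5, -0.5])

dy = pred_y - y
# explicit multiply matches the ** 2 form elementwise
assert np.allclose(dy * dy, (pred_y - y) ** 2)
```

Avoiding `pow` here lines up with the pow-gradient hardening later in this commit (`binary_op.cc` and `test_cuda_pow_grad_nan`).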
@@ -21,6 +21,8 @@ The following example shows how to model a two-layer neural network step by step
import jittor as jt
from jittor import Module
from jittor import nn
import numpy as np

class Model(Module):
    def __init__(self):
        self.layer1 = nn.Linear(1, 10)

@@ -38,13 +40,18 @@ def get_data(n): # generate random data for training test.
        y = x*x
        yield jt.float32(x), jt.float32(y)

model = Model()

learning_rate = 0.1
batch_size = 50
n = 1000

model = Model()
optim = nn.SGD(model.parameters(), learning_rate)

for i,(x,y) in enumerate(get_data(n)):
    pred_y = model(x)
-    loss = ((pred_y - y)**2)
+    dy = pred_y - y
+    loss = dy * dy
    loss_mean = loss.mean()
    optim.step(loss_mean)
    print(f"step {i}, loss = {loss_mean.data.sum()}")

@@ -325,7 +332,7 @@ Besides that, all the operators we used `jt.xxx(Var, ...)` have the alias `Var.xxx(...)`
Besides, all the operators `jt.xxx(Var,...)` we use also have the alias `Var.xxx(...)`. For example:

```python
-c.max() # alias of jt.max(a)
+c.max() # alias of jt.max(c)
c.add(a) # alias of jt.add(c, a)
c.min(keepdims=True) # alias of jt.min(c, keepdims=True)
```
@@ -37,4 +37,9 @@ int _mpi_world_rank();
// @pyjt(local_rank)
int _mpi_local_rank();

+struct ArrayArgs;
+
+// @pyjt(broadcast)
+void _mpi_broadcast(ArrayArgs&& args, int i);
+
} // jittor
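The `// @pyjt(broadcast)` annotation exposes `_mpi_broadcast` to Python as `mpi.broadcast`. A short usage sketch, mirroring `test_mpi.py` further down in this commit:

```python
import numpy as np
import jittor as jt

mpi = jt.compile_extern.mpi
if mpi:
    # every rank starts with a different buffer...
    a = np.zeros(100) + mpi.world_rank()
    # ...and after broadcasting from rank 0, all ranks hold rank 0's data
    mpi.broadcast(a, 0)
    assert (a == 0).all()
```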
@@ -12,6 +12,7 @@

#include "mpi_warper.h"
#include "common.h"
+#include "ops/array_op.h"

char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];

@@ -28,9 +29,9 @@ void throw_mpi_error(int result,
namespace jittor {


-int mpi_world_size;
-int mpi_world_rank;
-int mpi_local_rank;
+int mpi_world_size = 1;
+int mpi_world_rank = 0;
+int mpi_local_rank = 0;

int _mpi_world_size() {
    return mpi_world_size;

@@ -44,7 +45,12 @@ int _mpi_local_rank() {
    return mpi_local_rank;
}

+void _mpi_broadcast(ArrayArgs&& args, int i) {
+    int64 size = args.dtype.dsize();
+    for (auto j : args.shape)
+        size *= j;
+    MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
+}

static uint64_t getHostHash(const char* string) {
    // Based on DJB2, result = result * 33 + char

@@ -69,6 +75,7 @@ static void getHostName(char* hostname, int maxlen) {
struct mpi_initer {

mpi_initer() {
+    LOGvv << "MPI init...";
    MPI_CHECK(MPI_Init(NULL, NULL));
    MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
    MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));

@@ -84,6 +91,9 @@ mpi_initer() {
        if (p == mpi_world_rank) break;
        if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
    }
+    LOGv << "MPI init finished: local" << mpi_local_rank
+         << "global" << mpi_world_rank
+         << "size" << mpi_world_size;
}

~mpi_initer() {
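`mpi_local_rank` is derived by hashing each rank's host name and counting earlier ranks with the same hash. A small Python sketch of that counting rule (the host list is hypothetical; the real code gathers DJB2 hashes over MPI):

```python
# hypothetical gathered data: host_hashs[p] is the host hash of world rank p
host_hashs = [hash("node0"), hash("node0"), hash("node1")]

def local_rank(world_rank):
    # count earlier world ranks that live on the same host, like the C++ loop
    return sum(1 for p in range(world_rank)
               if host_hashs[p] == host_hashs[world_rank])

assert [local_rank(r) for r in range(3)] == [0, 1, 0]
```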
@@ -7,14 +7,16 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
-from . import compiler
-from .compiler import LOG, has_cuda
-from .compiler import compile_custom_ops, compile_custom_op
-import jittor_core as core
-from jittor_core import *
-from jittor_core.ops import *
-from . import compile_extern
-from .compile_extern import mkl_ops
+from . import lock
+with lock.lock_scope():
+    from . import compiler
+    from .compiler import LOG, has_cuda
+    from .compiler import compile_custom_ops, compile_custom_op
+    import jittor_core as core
+    from jittor_core import *
+    from jittor_core.ops import *
+    from . import compile_extern
+    from .compile_extern import mkl_ops

import contextlib
import numpy as np
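Wrapping the core imports in `lock.lock_scope()` serializes the JIT compilation they trigger when several processes import jittor at once. Per `lock.py` (added below), the scope works both ways; a usage sketch:

```python
from jittor import lock

# as a context manager: hold the inter-process lock for a block
with lock.lock_scope():
    pass  # e.g. compile a module

# as a decorator: the whole function runs under the lock
@lock.lock_scope()
def build_something():
    pass
```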
@@ -5,7 +5,7 @@
# ***************************************************************
import os, sys, shutil
from .compiler import *
-from jittor_utils import run_cmd
+from jittor_utils import run_cmd, get_version
from jittor.dataset.utils import download_url_to_local

def search_file(dirs, name):

@@ -322,6 +322,8 @@ def manual_link(flags):
            ctypes.CDLL(libname, dlopen_flags)
            break

+def inside_mpi():
+    return "OMPI_COMM_WORLD_SIZE" in os.environ
+
def setup_mpi():
    global mpi_ops, mpi, use_mpi

@@ -330,7 +332,6 @@ def setup_mpi():
    mpi_ops = None
    mpi = None
    has_mpi = False
-    if not use_mpi: return
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
    if mpicc_path == "":
        LOG.i("mpicc not found, distribution disabled.")

@@ -338,6 +339,8 @@ def setup_mpi():
    else:
        use_mpi = True
        has_mpi = True
+    if not inside_mpi():
+        use_mpi = False
    if not use_mpi:
        return

@@ -345,8 +348,7 @@ def setup_mpi():
    mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path+" --showme:link")
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
-    LOG.i("mpi_flags: "+mpi_flags)
-    manual_link(mpi_flags)
+    LOG.v("mpi_flags: "+mpi_flags)

    # find all source files
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")

@@ -359,8 +361,15 @@ def setup_mpi():
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
    mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")

+    mpi_version = get_version(mpicc_path)
+    if mpi_version.startswith("(1.") or mpi_version.startswith("(2."):
+        # mpi version 1.x need to link like this
+        manual_link(mpi_flags)
+    # mpi(4.x) cannot use deepbind, it need to
+    # share the 'environ' symbol.
    mpi = compile_custom_ops(mpi_src_files,
-        extra_flags=f" {mpi_flags} ", return_module=True)
+        extra_flags=f" {mpi_flags} ", return_module=True,
+        dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
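`setup_mpi` now links MPI eagerly only for OpenMPI 1.x/2.x; newer versions are loaded with `RTLD_GLOBAL | RTLD_NOW` and, per the comment, without `RTLD_DEEPBIND` so the `environ` symbol stays shared. A condensed sketch of that decision, assuming `get_version` yields a parenthesized string such as `"(1.10.2)"`:

```python
def needs_manual_link(mpi_version):
    # OpenMPI 1.x / 2.x are linked up front; others rely on dlopen flags
    return mpi_version.startswith("(1.") or mpi_version.startswith("(2.")

assert needs_manual_link("(1.10.2)")
assert not needs_manual_link("(4.0.1)")
```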
@@ -17,6 +17,7 @@ from ctypes.util import find_library
import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler
+from . import lock

def find_jittor_path():
    return os.path.dirname(__file__)

@@ -518,6 +519,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
    """
    return jit_src

+@lock.lock_scope()
def compile_custom_op(header, source, op_name, warp=True):
    """Compile a single custom op
    header: code of op header, not path

@@ -554,7 +556,12 @@ def compile_custom_op(header, source, op_name, warp=True):
    m = compile_custom_ops([hname, ccname])
    return getattr(m, op_name)

-def compile_custom_ops(filenames, extra_flags="", return_module=False):
+@lock.lock_scope()
+def compile_custom_ops(
+    filenames,
+    extra_flags="",
+    return_module=False,
+    dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
    """Compile custom ops
    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the

@@ -639,8 +646,10 @@ def compile_custom_ops(filenames, extra_flags="", return_module=False):
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
-    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
-        exec(f"import {gen_name}")
+    # unlock scope when initialize
+    with lock.unlock_scope():
+        with jit_utils.import_scope(dlopen_flags):
+            exec(f"import {gen_name}")
    mod = locals()[gen_name]
    if return_module:
        return mod

@@ -796,6 +805,7 @@ dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND

with jit_utils.import_scope(import_flags):
    jit_utils.try_import_jit_utils_core()

jittor_path = find_jittor_path()
check_debug_flags()

@@ -962,3 +972,5 @@ flags.jittor_path = jittor_path
flags.gdb_path = gdb_path
flags.addr2line_path = addr2line_path
flags.has_pybt = has_pybt

+core.set_lock_path(lock.lock_path)
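The new `dlopen_flags` parameter lets callers of `compile_custom_ops` control how the generated module is loaded, which is what the MPI setup above uses. A hedged usage sketch (the source files are hypothetical):

```python
import os
import jittor as jt

# hypothetical custom op sources: an .h/.cc pair per op
mod = jt.compile_custom_ops(
    ["my_op.h", "my_op.cc"],
    return_module=True,
    # like the MPI setup: export symbols globally, skip RTLD_DEEPBIND
    dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
```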
@@ -23,6 +23,7 @@ import jittor as jt

dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
+mpi = jt.compile_extern.mpi

class Worker:
    def __init__(self, target, args, buffer_size):

@@ -69,9 +70,11 @@ class Dataset(object):
    def __len__(self):
        assert self.total_len >= 0
        assert self.batch_size > 0
+        real_len = (self.total_len-1)//mpi.world_size()+1 if mpi \
+            else self.total_len
        if self.drop_last:
-            return self.total_len // self.batch_size
-        return (self.total_len-1) // self.batch_size + 1
+            return real_len // self.batch_size
+        return (real_len-1) // self.batch_size + 1

    def set_attrs(self, **kw):
        '''set attributes of dataset, equivalent to setattr
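Worked numbers for the new `__len__`, using the `total_len=1024` from `test_mpi.py`'s ToyDataset, two MPI ranks, and a batch size of 50 for illustration:

```python
total_len, world_size, batch_size = 1024, 2, 50

real_len = (total_len - 1) // world_size + 1   # 512 samples per rank
assert real_len == 512
# drop_last=False: the 12-sample remainder still forms a batch
assert (real_len - 1) // batch_size + 1 == 11
# drop_last=True: the remainder is dropped
assert real_len // batch_size == 10
```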
@@ -130,8 +133,8 @@ class Dataset(object):
                self.gidc.notify()
            batch = []
            if mp_log_v:
-                print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size))
-            for i in range(cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size)):
+                print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size))
+            for i in range(cid*self.batch_size, min(self.real_len, (cid+1)*self.batch_size)):
                batch.append(self[self.index_list[i]])
            batch = self.collate_batch(batch)
            if mp_log_v:

@@ -157,7 +160,7 @@ class Dataset(object):
            w.buffer.clear()

    def _init_workers(self):
-        self.index_list = mp.Array('i', self.total_len, lock=False)
+        self.index_list = mp.Array('i', self.real_len, lock=False)
        workers = []
        # batch id to worker id
        self.idmap = mp.Array('i', self.batch_len, lock=False)

@@ -174,7 +177,7 @@ class Dataset(object):
                buffer_size=self.buffer_size)
            workers.append(w)
        self.workers = workers
-        self.index_list_numpy = np.ndarray(dtype='int32', shape=self.total_len, buffer=self.index_list)
+        self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)

    def __del__(self):
        if mp_log_v:

@@ -186,6 +189,25 @@ class Dataset(object):
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)
+
+        # scatter index_list for all mpi process
+        # scatter rule:
+        #   [........]
+        #    000111
+        #       222
+        # make sure each process has the same len
+        if mpi:
+            index_list = np.int32(index_list)
+            mpi.broadcast(index_list, 0)
+            real_len = (self.total_len - 1) // mpi.world_size() + 1
+            offset = mpi.world_rank() * real_len
+            if offset + real_len > self.total_len:
+                offset -= offset + real_len - self.total_len
+            index_list = index_list[offset:offset+real_len]
+            self.real_len = real_len
+            assert real_len == len(index_list)
+        else:
+            self.real_len = self.total_len

        self.batch_len = len(self)
        if "batch_len" in os.environ:
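The scatter rule keeps every rank's slice the same length: when `total_len` is not divisible by the world size, the last rank shifts its window back and overlaps its neighbor instead of taking a short slice. A worked example of the offset arithmetic (pure Python, mirroring the added block):

```python
total_len, world_size = 10, 3
real_len = (total_len - 1) // world_size + 1      # 4 per rank

slices = []
for rank in range(world_size):
    offset = rank * real_len
    if offset + real_len > total_len:
        # pull the last slice back in range; it overlaps rank 1's tail
        offset -= offset + real_len - total_len
    slices.append((offset, offset + real_len))

assert slices == [(0, 4), (4, 8), (6, 10)]   # equal lengths on every rank
```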
@@ -15,6 +15,7 @@ from tqdm import tqdm
import numpy as np
from collections.abc import Sequence, Mapping
from PIL import Image
+from .. import lock

def ensure_dir(dir_path):
    if not os.path.isdir(dir_path):

@@ -36,7 +37,7 @@ def _progress():

    return bar_update


+@lock.lock_scope()
def download_url_to_local(url, filename, root_folder, md5):
    ensure_dir(root_folder)
    file_path = os.path.join(root_folder, filename)
@@ -0,0 +1,63 @@
import fcntl
import os
from jittor_utils import cache_path, LOG

class Lock:
    def __init__(self, filename):
        self.handle = open(filename, 'w')
        LOG.v(f'OPEN LOCK path: {filename} PID: {os.getpid()}')
        self.is_locked = False

    def lock(self):
        fcntl.flock(self.handle, fcntl.LOCK_EX)
        self.is_locked = True
        LOG.vv(f'LOCK PID: {os.getpid()}')

    def unlock(self):
        fcntl.flock(self.handle, fcntl.LOCK_UN)
        self.is_locked = False
        LOG.vv(f'UNLOCK PID: {os.getpid()}')

    def __del__(self):
        self.handle.close()


class _base_scope:
    '''base_scope for support @xxx syntax'''
    def __enter__(self): pass
    def __exit__(self, *exc): pass
    def __call__(self, func):
        def inner(*args, **kw):
            with self:
                ret = func(*args, **kw)
            return ret
        return inner

class lock_scope(_base_scope):
    def __enter__(self):
        self.is_locked = jittor_lock.is_locked
        if not self.is_locked:
            jittor_lock.lock()

    def __exit__(self, *exc):
        if not self.is_locked:
            jittor_lock.unlock()

class unlock_scope(_base_scope):
    def __enter__(self):
        self.is_locked = jittor_lock.is_locked
        if self.is_locked:
            jittor_lock.unlock()

    def __exit__(self, *exc):
        if self.is_locked:
            jittor_lock.lock()

lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock"))
if not os.path.exists(lock_path):
    LOG.i("Create lock file:", lock_path)
    try:
        os.mknod(lock_path)
    except:
        pass
jittor_lock = Lock(lock_path)
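Because both scopes consult `is_locked` before acting, `lock_scope` is safe to nest and `unlock_scope` can temporarily drop the lock inside one. A sketch of the intended nesting (using the module's own names):

```python
from jittor import lock

with lock.lock_scope():            # takes the flock
    with lock.lock_scope():        # already held: no second flock call
        with lock.unlock_scope():  # temporarily released, e.g. while a
            pass                   # child process compiles on its own
        # re-acquired here on scope exit
```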
@@ -5,23 +5,33 @@
# ***************************************************************

if __name__ == "__main__":
-    import unittest
+    import unittest, os

    suffix = "__main__.py"
    assert __file__.endswith(suffix)
    test_dir = __file__[:-len(suffix)]
-    import os
+
+    skip_l = int(os.environ.get("test_skip_l", "0"))
+    skip_r = int(os.environ.get("test_skip_r", "1000000"))
+    test_only = None
+    if "test_only" in os.environ:
+        test_only = set(os.environ.get("test_only").split(","))

    test_files = os.listdir(test_dir)
-    for test_file in test_files:
+    test_files = sorted(test_files)
+    suite = unittest.TestSuite()
+
+    for _, test_file in enumerate(test_files):
        if not test_file.startswith("test_"):
            continue
+        if _ < skip_l or _ > skip_r:
+            continue
        test_name = test_file.split(".")[0]
-        exec(f"from . import {test_name}")
-        test_mod = globals()[test_name]
-        print(test_name)
-        for i in dir(test_mod):
-            obj = getattr(test_mod, i)
-            if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
-                globals()[test_name+"_"+i] = obj
+        if test_only and test_name not in test_only:
+            continue

-    unittest.main()
+        print("Add Test", _, test_name)
+        suite.addTest(unittest.defaultTestLoader.loadTestsFromName(
+            "jittor.test."+test_name))

+    unittest.TextTestRunner(verbosity=3).run(suite)
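The rewritten runner reads `test_skip_l`, `test_skip_r`, and `test_only` from the environment, so subsets of the suite can be selected per invocation. A sketch of driving it programmatically:

```python
import os, subprocess, sys

# run only the named tests; variable names match the runner above
env = dict(os.environ, test_only="test_lock,test_mpi")
subprocess.run([sys.executable, "-m", "jittor.test"], env=env, check=True)
```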
@@ -18,6 +18,7 @@ import pickle as pk
skip_this_test = False

try:
+    jt.dirty_fix_pytorch_runtime_error()
    import torch
    from torch.nn import MaxPool2d, Sequential
except:

@@ -45,7 +46,7 @@ def check(jt_model, torch_model, shape, near_data):

@unittest.skipIf(skip_this_test, "No Torch found")
class TestArgPoolOp(unittest.TestCase):
-    @unittest.skipIf(jt.compiler.has_cuda, "No cuda found")
+    @unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
    @jt.flag_scope(use_cuda=1)
    def test_cuda(self):
        jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))

@@ -58,15 +59,18 @@ class TestArgPoolOp(unittest.TestCase):
        check(jt_model, torch_model, [1,1,300,300], True)

    def test_cpu_(self):
-        x = jt.random([32, 128, 157, 300])
+        # x = jt.random([32, 128, 157, 300])
+        x = jt.random([4, 128, 157, 300])
        x = jt.nn.pool(x, 2, "maximum", 0, 2)

    def test_cpu(self):
        jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
        torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1))
-        shape = [64, 64, 300, 300]
+        # shape = [64, 64, 300, 300]
+        shape = [4, 64, 300, 300]
        check(jt_model, torch_model, shape, False)
-        shape = [32, 128, 157, 300]
+        # shape = [32, 128, 157, 300]
+        shape = [4, 128, 157, 300]
        check(jt_model, torch_model, shape, False)
        for i in range(10):
            check(jt_model, torch_model, [1,1,300,300], True)
@@ -11,6 +11,7 @@ import numpy as np

class TestClone(unittest.TestCase):
    def test(self):
+        jt.clean()
        b = a = jt.array(1)
        for i in range(10):
            b = b.clone()
@@ -18,11 +18,12 @@ def test_cuda(use_cuda=1):

@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestCuda(unittest.TestCase):
+    @jt.flag_scope(use_cuda=1)
    def test_cuda_flags(self):
-        with jt.var_scope(use_cuda=1):
-            a = jt.random((10, 10))
-            a.sync()
+        a = jt.random((10, 10))
+        a.sync()

+    @jt.flag_scope(use_cuda=2)
    def test_no_cuda_op(self):
        no_cuda_op = jt.compile_custom_op("""
        struct NoCudaOp : Op {

@@ -49,10 +50,10 @@ class TestCuda(unittest.TestCase):
        """,
        "no_cuda")
        # force use cuda
-        with jt.var_scope(use_cuda=2):
-            a = no_cuda_op([3,4,5], 'float')
-            expect_error(lambda: a())
+        a = no_cuda_op([3,4,5], 'float')
+        expect_error(lambda: a())

+    @jt.flag_scope(use_cuda=1)
    def test_cuda_custom_op(self):
        my_op = jt.compile_custom_op("""
        struct MyCudaOp : Op {

@@ -94,9 +95,8 @@ class TestCuda(unittest.TestCase):
        #endif // JIT
        """,
        "my_cuda")
-        with jt.var_scope(use_cuda=1):
-            a = my_op([3,4,5], 'float')
-            na = a.data
+        a = my_op([3,4,5], 'float')
+        na = a.data
        assert a.shape == [3,4,5] and a.dtype == 'float'
        assert (-na.flatten() == range(3*4*5)).all(), na
@@ -22,7 +22,6 @@ class TestCutt(unittest.TestCase):
    @jt.flag_scope(use_cuda=1)
    def test(self):
        t = cutt_ops.cutt_test("213")
        jt.sync_all(True)
-        print(t.data)
        assert t.data == 123
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,27 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Wenyang Zhou <576825820@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
from pathlib import Path

class TestLock(unittest.TestCase):
    def test(self):
        if os.environ.get('lock_full_test', '0') == '1':
            cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
            assert os.system(f"rm -rf {cache_path}") == 0
            cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example"
        else:
            cmd = f"{sys.executable} -m jittor.test.test_example"
        print("run cmd twice", cmd)
        assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0


if __name__ == "__main__":
    unittest.main()
@@ -11,6 +11,7 @@ import numpy as np
class TestMiscIssue(unittest.TestCase):
    def test_issue4(self):
        try:
+            jt.dirty_fix_pytorch_runtime_error()
            import torch
        except:
            return

@@ -42,6 +43,7 @@ b.sync()

    def test_mkl_conflict1(self):
        try:
+            jt.dirty_fix_pytorch_runtime_error()
            import torch
        except:
            return

@@ -67,6 +69,7 @@ m(torch.rand(*nchw))

    def test_mkl_conflict2(self):
        try:
+            jt.dirty_fix_pytorch_runtime_error()
            import torch
        except:
            return

@@ -126,5 +129,13 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
        assert a.min().data == a.data.min(), (a.min(), a.data.min())
        assert a.max().data == a.data.max(), (a.max(), a.data.max())

+    @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
+    @jt.flag_scope(use_cuda=1)
+    def test_cuda_pow_grad_nan(self):
+        a = jt.float32([1,-1, -1000.1])
+        da = jt.grad(a**2, a)
+        assert np.isnan(da.data).sum()==0, da.data
+

if __name__ == "__main__":
    unittest.main()
@@ -10,26 +10,51 @@ import unittest
import os, sys
import jittor as jt
+import numpy as np
+mpi = jt.compile_extern.mpi

-def main():
-    print("test mpi_test")
-    assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
-    if jt.compile_extern.nccl_ops:
-        print("test test_with_mpi")
-        with jt.flag_scope(use_cuda=1):
-            assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123

-@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
+@unittest.skipIf(mpi is None, "no inside mpirun")
class TestMpi(unittest.TestCase):
-    def test(self):
-        mpi = jt.compile_extern.mpi
-        if mpi.world_size() == 1:
-            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
-            cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
-            print("run cmd", cmd)
-            jt.compiler.run_cmd(cmd)
-        else:
-            main()
+    def test_mpi_test_op(self):
+        assert jt.compile_extern.mpi_ops.mpi_test("").data == 123

+    @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no inccl")
+    @jt.flag_scope(use_cuda=1)
+    def test_nccl_with_mpi(self):
+        assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123

+    def test_mpi_broadcast(self):
+        for i in range(mpi.world_size()):
+            a = np.zeros(100) + mpi.world_rank()
+            mpi.broadcast(a, i)
+            assert (a == i).all()

+    def test_mpi_dataset(self):
+        from jittor.dataset.dataset import Dataset
+        class ToyDataset(Dataset):
+            def __init__(self):
+                super().__init__()
+                self.set_attrs(total_len=1024)

+            def __getitem__(self, index):
+                return index, index*index

+        toy = ToyDataset()
+        offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()

+        for _ in range(2):
+            for i,(a,b) in enumerate(toy):
+                assert (a.data*a.data == b.data).all()
+                c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
+                assert (c==a.data).all()

+@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
+class TestMpiEntry(unittest.TestCase):
+    def test_entry(self):
+        if not jt.compile_extern.inside_mpi():
+            mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
+            cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi -v"
+            print("run cmd:", cmd)
+            assert os.system(cmd)==0, "run cmd failed: "+cmd

if __name__ == "__main__":
    unittest.main()
@@ -133,5 +133,34 @@ class TestOpCompiler(unittest.TestCase):
        expect_error(lambda: jit_precompile(vars, "@if(1)"))
        expect_error(lambda: jit_precompile(vars, "#define OP1(a,b) a+b\n@expand_macro(OP1,1)"))

+    def test_strcmp(self):
+        vars = {"Tx":"float"}
+        check = lambda expr, result: \
+            self.assertEqual(jit_precompile(vars, expr), result)
+        check("@strcmp(aaa,aaa)", "0")
+        check("@strcmp(aaa,bbb)", "-1")
+        check("@strcmp(ccc,bbb)", "1")
+        check("@{@strcmp(aaa,aaa)}", "0")
+        check("@{@strcmp(aaa,bbb)}", "-1")
+        check("@{@strcmp(ccc,bbb)}", "1")
+
+        code = \
+"""@define(T_NCCL,
+    @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
+    @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
+    @if(@strcmp(@Tx,float64)==0, ncclFloat64)
+    @if(@strcmp(@Tx,int64)==0, ncclInt64)
+)
+ncclBcast(..., @T_NCCL, ...)
+"""
+        assert "ncclFloat" in jit_precompile({"Tx":"float"}, code)
+        assert "ncclFloat" in jit_precompile({"Tx":"float32"}, code)
+        assert "ncclFloat64" in jit_precompile({"Tx":"float64"}, code)
+        assert "ncclInt" in jit_precompile({"Tx":"int"}, code)
+        assert "ncclInt" in jit_precompile({"Tx":"int32"}, code)
+        assert "ncclInt64" in jit_precompile({"Tx":"int64"}, code)
+

if __name__ == "__main__":
    unittest.main()
@@ -158,9 +158,10 @@ def find_cache_path():
        r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
            stderr=sp.PIPE)
        assert r.returncode == 0
-        bs = r.stdout.decode()
+        bs = r.stdout.decode().splitlines()
        for b in bs:
            if b.startswith("* "): break

        cache_name = b[2:]
        for c in " (){}": cache_name = cache_name.replace(c, "_")
    except:

@@ -168,10 +169,14 @@ def find_cache_path():
    for name in cache_name.split("/"):
        dirs.insert(-1, name)
    os.environ["cache_name"] = cache_name
+    LOG.v("cache_name", cache_name)
    for d in dirs:
        path = os.path.join(path, d)
        if not os.path.isdir(path):
-            os.mkdir(path)
+            try:
+                os.mkdir(path)
+            except:
+                pass
    assert os.path.isdir(path)
    if path not in sys.path:
        sys.path.append(path)
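The cache path is derived from the current git branch, with shell-unfriendly characters mapped to underscores; `mkdir` is additionally allowed to race against concurrent processes. Example of the sanitization on a detached-HEAD branch line (mirroring the replace loop above):

```python
# b[2:] of a "git branch" line like "* (HEAD detached at fb726ad)"
cache_name = "(HEAD detached at fb726ad)"
for c in " (){}":
    cache_name = cache_name.replace(c, "_")
assert cache_name == "_HEAD_detached_at_fb726ad_"
```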
@@ -54,17 +54,22 @@ struct EventQueue {
    static void worker_caller();

    void run_sync(Func func) {
+        // send work to worker and do something by self
        std::unique_lock<std::mutex> l(mtx);
        this->func = func;
        run_sync_done = false;
+        // send func to worker
        worker.run(worker_caller);
        while (1) {
+            // check self work or worker's status
            cv.wait(l);
            list<Func> ts = move(tasks);
            l.unlock();
+            // do self works
            for (auto func : ts)
                func();
            l.lock();
+            // worker is finished
            if (run_sync_done)
                return;
        }
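`run_sync` parks the caller on the condition variable but lets it drain tasks queued for it while the worker executes `func`. A Python analogue of the loop (threading instead of `std::mutex`, illustrative only):

```python
import threading

cv = threading.Condition()
tasks = []             # callbacks the waiting thread must run itself
run_sync_done = False  # set by the worker when func has finished

def wait_loop():
    global tasks
    cv.acquire()
    while True:
        cv.wait()              # woken for new tasks or for completion
        ts, tasks = tasks, []  # grab self-work while holding the lock
        cv.release()
        for func in ts:        # run it unlocked, as the C++ loop does
            func()
        cv.acquire()
        if run_sync_done:      # worker reported completion
            cv.release()
            return
```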
@@ -0,0 +1,50 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
//     Wenyang Zhou <576825820@qq.com>
//     Dun Liang <randonlang@gmail.com>
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>

#include "lock.h"

namespace jittor {

static int lock_fd = -1;

void set_lock_path(string path) {
    lock_fd = open(path.c_str(), O_RDWR);
    ASSERT(lock_fd >= 0);
    LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid();
}

void lock() {
    ASSERT(lock_fd >= 0);
    struct flock lock = {
        .l_type = F_WRLCK,
        .l_whence = SEEK_SET,
        .l_start = 0,
        .l_len = 0
    };
    ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
    LOGvv << "LOCK Pid:" << getpid();
}

void unlock() {
    ASSERT(lock_fd >= 0);
    struct flock lock = {
        .l_type = F_UNLCK,
        .l_whence = SEEK_SET,
        .l_start = 0,
        .l_len = 0
    };
    ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0);
    LOGvv << "UNLOCK Pid:" << getpid();
}

} // jittor
@@ -0,0 +1,26 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
//     Wenyang Zhou <576825820@qq.com>
//     Dun Liang <randonlang@gmail.com>
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"

namespace jittor {

// @pyjt(set_lock_path)
void set_lock_path(string path);

void lock();

void unlock();

struct lock_guard {
    inline lock_guard() { lock(); }
    inline ~lock_guard() { unlock(); }
};

} // jittor
@@ -14,6 +14,8 @@
#include "misc/str_utils.h"
#include "ops/op_register.h"
#include "ops/array_op.h"
+#include "lock.h"
+#include "opt/expr.h"

namespace jittor {

@@ -103,48 +105,6 @@ int OpCompiler::total_member_count() {
    return member_count;
}

-#define FOR_ALL_UOPS(m) \
-    m(!,3) m(~,3)
-#define FOR_ALL_BOPS(m) \
-    m(*,5) m(/,5) m(%,5) \
-    m(+,6) m(-,6) \
-    m(<<,7) m(>>,7) \
-    m(<,9) m(<=,9) m(>,9) m(>=,9) \
-    m(!=,10) m(==,10) \
-    m(&,11) \
-    m(^,12) \
-    m(|,13) \
-    m(&&,14) \
-    m(||,15)
-
-#define FOR_ALL_OPS(m) FOR_ALL_UOPS(m) FOR_ALL_BOPS(m)
-
-inline bool is_unary_op(const string& op) {
-#define _u(o, _) if (op == #o) return true;
-    FOR_ALL_UOPS(_u);
-    return false;
-}
-
-inline int precedence(const string& op) {
-#define _prior(o, p) if (op == #o) return p;
-    FOR_ALL_OPS(_prior);
-    return 20;
-}
-
-inline bool check_precedence(const string& op1, const string& op2) {
-    if (op1 == op2 && is_unary_op(op1)) return false;
-    return precedence(op1) <= precedence(op2);
-}
-
-inline int64_t calc_op(int64_t a, int64_t b, const string& op) {
-#define _calc_b(o, _) if (op == #o) return a o b;
-    FOR_ALL_BOPS(_calc_b);
-#define _calc_u(o, _) if (op == #o) return o b;
-    FOR_ALL_UOPS(_calc_u);
-    ASSERT(0) << "Unrecognized op " << op;
-    return 0;
-}
-
int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>& vars) {
    if (expr.find("@") != string::npos) {
        string new_expr;

@@ -174,6 +134,22 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
            ASSERT(isvar(expr[j]));
            size_t k=j+1;
            while (k<expr.size() && isvar(expr[k])) k++;
+            if (k<expr.size() && expr[k]=='(') {
+                // syntax @xx(...)
+                //        ij k   l
+                size_t l=k+1;
+                int presum = 1;
+                while (l<expr.size() && presum) {
+                    if (expr[l] == ')')
+                        presum--;
+                    else if (expr[l] == '(')
+                        presum++;
+                    l++;
+                }
+                new_expr += precompile(vars, expr.substr(i, l-i));
+                i = l-1;
+                continue;
+            }
            string var = expr.substr(j, k-j);
            auto iter = vars.find(var);
            ASSERT(iter!=vars.end()) << "Jit var " << var << " not found." << vars;

@@ -184,68 +160,18 @@ int64_t OpCompiler::eval(const string& expr, const unordered_map<string,string>&
    }
    return eval(new_expr, vars);
}
-    vector<int64> values = {0};
-    vector<string> ops;
-    auto pop_values_and_calc_op = [&]() {
-        CHECK(ops.size());
-        auto op = ops.back();
-        ops.pop_back();
-        CHECK(values.size());
-        auto val2 = values.back();
-        values.pop_back();
-        auto val1 = val2;
-        if (!is_unary_op(op)) {
-            CHECK(values.size());
-            val1 = values.back();
-            values.pop_back();
-        }
-        values.push_back(calc_op(val1, val2, op));
-    };
-    for (size_t i=0; i<expr.size(); i++) {
-        if (expr[i] == ' ')
-            continue;
-        if (expr[i] == '(')
-            ops.push_back(string()+expr[i]);
-        else if (isdigit(expr[i])) {
-            int64_t val = 0;
-            while (i<expr.length() && isdigit(expr[i])) {
-                val = val*10 + (expr[i]-'0');
-                i++;
-            }
-            i--;
-            values.push_back(val);
-        } else if (isvar(expr[i])) {
-            auto j=i+1;
-            while (j<expr.size() && isvar(expr[j])) j++;
-            auto var_name = expr.substr(i,j-i);
-            auto iter = vars.find(var_name);
-            ASSERT(iter!=vars.end()) << "Jit var " << var_name << " not found.";
-            try {
-                values.push_back(std::stoll(iter->second));
-            } catch (...) {
-                ASSERT(0) << "'" << iter->second << "' is not integer, expr " << expr;
-            }
-            i = j-1;
-        } else if (expr[i] == ')') {
-            while (ops.size() && ops.back() != "(")
-                pop_values_and_calc_op();
-            ops.pop_back();
-        } else {
-            auto j=i+1;
-            while (j<expr.size() && expr[j] != ' ' &&
-                expr[j] != '!' && expr[j] != '~' &&
-                !isdigit(expr[j]) && !isvar(expr[j]) &&
-                expr[j] != '(' && expr[j] != ')') j++;
-            auto op = expr.substr(i, j-i);
-            while (ops.size() && check_precedence(ops.back(), op))
-                pop_values_and_calc_op();
-            ops.push_back(op);
-            i = j-1;
-        }
-    }
-    while (ops.size())
-        pop_values_and_calc_op();
-    return values.back();
+    auto e = expr::make(expr);
+    e->dfs([&](expr::Expr* s) {
+        if (s->is_sym()) {
+            auto iter = vars.find(s->str);
+            ASSERT(iter!=vars.end()) << "Jit var " << s->str << " not found.";
+            auto e = expr::make(iter->second);
+            s->swap(e.get());
+        }
+    });
+    e = e->eval();
+    ASSERT(e->is(expr::_int));
+    return e->as_int();
}

void load_macros(const string& src, unordered_map<string,string>& macros) {

@@ -587,6 +513,19 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<string,string>& macros)
            i = l-1;
            continue;
        } else
+        if (expr == "strcmp") {
+            // syntax: @strcmp(s1,s2)
+            //          ij     k    l
+            ASSERT(args.size()==2u)
+                << "Jit error: strcmp wrong arguments.";
+            auto s1 = precompile(defs, args[0], macros);
+            auto s2 = precompile(defs, args[1], macros);
+            if (s1<s2) new_src += "-1"; else
+            if (s1==s2) new_src += "0"; else
+                new_src += "1";
+            i = l-1;
+            continue;
+        } else
        if (args.size()) {
            // syntax: @e0(i0,i1,...,in) -> e0p[i0*e0stride0+i1*e0stride1+...]
            int nid=(int)expr.size();

@@ -672,7 +611,7 @@ string OpCompiler::get_jit_src(Op* op) {
        else
            after_include_src += src;
    }
-    ASSERT(file_exist(src_path));
+    ASSERT(file_exist(src_path)) << src_path;
    LOGvvv << "Read from" << src_path;
    string src = read_all(src_path);
    ASSERT(src.size()) << "Source read failed:" << src_path;

@@ -945,6 +884,7 @@ jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
}

jit_op_entry_t OpCompiler::do_compile(Op* op) {
+    jittor::lock_guard lg;
    OpCompiler oc(op);
    string* src = &oc.src;
    string src_after_passes;

@@ -954,8 +894,8 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
        src_after_passes = tm.tune();
        src = &src_after_passes;
    }
-    return oc.compile(op->get_jit_key(), *src);
+    auto ret = oc.compile(op->get_jit_key(), *src);
+    return ret;
}

}
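The hand-rolled precedence evaluator is replaced by `opt/expr`: `eval` now parses the expression, substitutes every symbol from `vars`, constant-folds, and asserts an integer result. A rough Python equivalent of that flow (using `ast` in place of jittor's expr module, illustrative only):

```python
import ast

def jit_eval(expr, vars):
    tree = ast.parse(expr, mode="eval")
    for node in ast.walk(tree):           # every symbol must be a known JIT var
        if isinstance(node, ast.Name):
            assert node.id in vars, f"Jit var {node.id} not found."
    value = eval(compile(tree, "<expr>", "eval"),
                 {"__builtins__": {}},
                 {k: int(v) for k, v in vars.items()})
    assert isinstance(value, int)
    return value

assert jit_eval("stride*2+padding", {"stride": "3", "padding": "1"}) == 7
```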
@@ -13,6 +13,8 @@
namespace jittor {

#ifndef JIT
+static auto make_array = get_op_info("array")
+    .get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_broadcast_to = get_op_info("broadcast_to")
    .get_constructor<VarPtr, Var*, Var*, NanoVector>();
static auto make_binary = get_op_info("binary")

@@ -122,7 +124,9 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    if (v_index == 0) {
        // dout * y * x^(y-1)
        auto d = make_binary(dout, y, ns_multiply);
-        auto ones = make_number(1, dout);
+        // auto ones = make_number(1, dout);
+        int number = 1;
+        auto ones = make_array(&number, 1, ns_int32);
        auto y_1 = make_binary(y, ones, ns_subtract);
        auto x_y_1 = make_binary(x, y_1, ns_pow);
        return make_binary(d, x_y_1, ns_multiply);
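For `x ** y`, the gradient w.r.t. `x` is `dout * y * x**(y-1)`. Building the constant one as an int32 array instead of a float `make_number` presumably keeps the exponent `y - 1` exactly integral, so the pow in the gradient stays NaN-free for negative bases; that is exactly what `test_cuda_pow_grad_nan` above checks:

```python
import jittor as jt

a = jt.float32([1.0, -1.0, -1000.1])
da = jt.grad(a ** 2, a)   # d(x^2)/dx = 2x, finite even for negative x
print(da.data)            # expected [2, -2, -2000.2] with no NaNs
```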
@@ -235,10 +235,12 @@ void ConvTuner::forwardTune(FusedOp* fop) {
            LOGvvvv << "Expr not match" << src_h << expr_h;
            continue;
        }
+        LOGvvvv << "H Expr matched" << src_h << expr_h;
        if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
        auto src_w = expr::make(riop1->indexes[xw]);
        if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
            return;
+        LOGvvvv << "W Expr matched" << src_w << expr_w;
        if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
        int stride_h = rh[0]->as_int();
        int padding_h = -rh[1]->as_int();

@@ -253,7 +255,10 @@ void ConvTuner::forwardTune(FusedOp* fop) {
            continue;
        }
        LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
+
+        if (xformat == "bacd" && dilation_h != 1) {
+            LOGvvvv << "mkl not support bacd dilation, continue";
+            continue;
+        }
        int stride = stride_h;
        int padding = padding_h;
        int dilation = dilation_h;

@@ -363,6 +368,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
            x = riop1->x;
            y = riop2->x;
            bo++;
+            LOGvvvv << "backward_w get stride padding and dilation" << stride << padding << dilation;
        } else if (op->name_ex() == "reindex_reduce.add") {
            auto rop = (ReindexReduceOp*)op;
            if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))

@@ -438,6 +444,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
            w = riop1->x;
            y = riop2->x;
            bo+=2;
+            LOGvvvv << "backward_x get stride padding and dilation" << stride << padding << dilation;
        }

        // TODO: CUDA only support nchw(abcd)